Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev-doc/AI.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ The runtime handles two special cases transparently:
|-------|------|----------|
| All | `DT_AI_OPT_ALL` | Default and fastest; works for most models |
| Basic | `DT_AI_OPT_BASIC` | constant folding + redundancy elimination only. Required for SAM2 decoder (aggressive optimization breaks shape inference on dynamic dims) |
| Disabled | `DT_AI_OPT_DISABLED` | reserved for future use |
| Disabled | `DT_AI_OPT_DISABLED` | Disables graph transforms entirely. Use to work around ORT regressions such as the SAM 2.1 encoder use-after-free (UAF) during session initialization |

### Symbolic Dimension Overrides

Expand Down
57 changes: 34 additions & 23 deletions src/common/ai/segmentation.c
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,37 @@ dt_seg_context_t *dt_seg_load(dt_ai_environment_t *env, const char *model_id)
if(!env || !model_id)
return NULL;

// detect model type up front so we can choose safe ORT settings before the
// first session is created.
const dt_ai_model_info_t *minfo
= dt_ai_get_model_info_by_id(env, model_id);
const char *arch = minfo ? minfo->arch : "";

dt_seg_model_type_t model_type;
if(strcmp(arch, "sam2") == 0)
model_type = DT_SEG_MODEL_SAM;
else if(strcmp(arch, "segnext") == 0)
model_type = DT_SEG_MODEL_SEGNEXT;
else
{
dt_print(DT_DEBUG_AI,
"[segmentation] unknown arch '%s' for %s",
arch, model_id);
return NULL;
}

// ORT 1.24.x can hit a heap-use-after-free during graph optimization while
// constant-folding the SAM 2.1 encoder. Load that encoder with optimization
// disabled so session creation stays on the safe path. SegNext and other
// non-SAM encoders keep the normal optimization level.
const dt_ai_opt_level_t encoder_opt
= (model_type == DT_SEG_MODEL_SAM) ? DT_AI_OPT_DISABLED : DT_AI_OPT_ALL;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too aggressive; it results in a performance regression for all models, not just the affected encoder.


// provider is resolved from the environment (read from config at init time),
// passing AUTO lets dt_ai_load_model resolve it
// passing AUTO lets dt_ai_load_model_ext resolve it
dt_ai_context_t *encoder
= dt_ai_load_model(env, model_id, "encoder.onnx", DT_AI_PROVIDER_AUTO);
= dt_ai_load_model_ext(env, model_id, "encoder.onnx", DT_AI_PROVIDER_AUTO,
encoder_opt, NULL, 0);
if(!encoder)
{
dt_print(DT_DEBUG_AI, "[segmentation] failed to load encoder for %s", model_id);
Expand Down Expand Up @@ -329,23 +356,7 @@ dt_seg_context_t *dt_seg_load(dt_ai_environment_t *env, const char *model_id)
ctx->enc_order[0], ctx->enc_order[1], ctx->enc_order[2], ctx->enc_order[3],
ctx->n_enc_outputs);

// detect model type from arch field in model registry
const dt_ai_model_info_t *minfo
= dt_ai_get_model_info_by_id(env, model_id);
const char *arch = minfo ? minfo->arch : "";

if(strcmp(arch, "sam2") == 0)
ctx->model_type = DT_SEG_MODEL_SAM;
else if(strcmp(arch, "segnext") == 0)
ctx->model_type = DT_SEG_MODEL_SEGNEXT;
else
{
dt_print(DT_DEBUG_AI,
"[segmentation] unknown arch '%s' for %s",
arch, model_id);
dt_seg_free(ctx);
return NULL;
}
ctx->model_type = model_type;

// SAM requires external ImageNet normalization; SegNext bakes it into the encoder
ctx->normalize = (ctx->model_type == DT_SEG_MODEL_SAM);
Expand Down Expand Up @@ -571,15 +582,15 @@ void dt_seg_warmup_decoder(dt_seg_context_t *ctx)
.shape = has_mask_shape, .ndim = 1};

int64_t masks_shape[4] = {1, nm, dec_h, dec_w};
int64_t iou_shape[2] = {1, nm};
int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim};
float iou_buf[MAX_NUM_MASKS];

dt_ai_tensor_t outputs[3];
int n_out;

if(is_sam)
{
int64_t iou_shape[2] = {1, nm};
int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim};
const int dec_outputs = dt_ai_get_output_count(ctx->decoder);

outputs[0] = (dt_ai_tensor_t){
Expand Down Expand Up @@ -848,14 +859,15 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx,
dt_ai_tensor_t dec_outputs[3];
int n_dec_out;
int64_t masks_shape[4] = {1, nm, dec_h, dec_w};
int64_t iou_shape[2] = {1, nm};
int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim};

float iou_pred[MAX_NUM_MASKS];
float *low_res = NULL;

if(is_sam)
{
// SAM: masks [1,N,H,W] + iou [1,N], optionally low_res [1,N,pm,pm]
int64_t iou_shape[2] = {1, nm};
const int dec_out_count = dt_ai_get_output_count(ctx->decoder);

dec_outputs[0] = (dt_ai_tensor_t){
Expand All @@ -876,7 +888,6 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx,
g_free(masks);
return NULL;
}
int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim};
dec_outputs[2] = (dt_ai_tensor_t){
.data = low_res, .type = DT_AI_FLOAT, .shape = low_res_shape, .ndim = 4};
n_dec_out = 3;
Expand Down