From 2ecc6452aa60884ba907b3ae3352eb30801209ef Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Wed, 8 Apr 2026 16:57:07 -0600 Subject: [PATCH] Fix SAM2 segmentation ASan crashes Disable ORT graph optimization for the SAM2.1 encoder to avoid the ONNX Runtime 1.24.x session-init UAF, and keep dynamic decoder output shape buffers alive across dt_ai_run calls so ORT can write resolved dims safely during warmup and decode. Co-authored-by: codex@openai.com --- dev-doc/AI.md | 2 +- src/common/ai/segmentation.c | 57 +++++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/dev-doc/AI.md b/dev-doc/AI.md index 5ec549b44ff6..fb21831e6f76 100644 --- a/dev-doc/AI.md +++ b/dev-doc/AI.md @@ -120,7 +120,7 @@ The runtime handles two special cases transparently: |-------|------|----------| | All | `DT_AI_OPT_ALL` | default, fastest. works for most models | | Basic | `DT_AI_OPT_BASIC` | constant folding + redundancy elimination only. Required for SAM2 decoder (aggressive optimization breaks shape inference on dynamic dims) | -| Disabled | `DT_AI_OPT_DISABLED` | reserved for future use | +| Disabled | `DT_AI_OPT_DISABLED` | disables graph transforms entirely. Use for ORT regressions like the SAM2.1 encoder UAF during session initialization | ### Symbolic Dimension Overrides diff --git a/src/common/ai/segmentation.c b/src/common/ai/segmentation.c index 61fc4c710b08..3cc2876f54db 100644 --- a/src/common/ai/segmentation.c +++ b/src/common/ai/segmentation.c @@ -211,10 +211,37 @@ dt_seg_context_t *dt_seg_load(dt_ai_environment_t *env, const char *model_id) if(!env || !model_id) return NULL; + // detect model type up front so we can choose safe ORT settings before the + // first session is created. + const dt_ai_model_info_t *minfo + = dt_ai_get_model_info_by_id(env, model_id); + const char *arch = minfo ? minfo->arch : ""; + + dt_seg_model_type_t model_type; + if(strcmp(arch, "sam2") == 0) + model_type = DT_SEG_MODEL_SAM; + else if(strcmp(arch, "segnext") == 0) + model_type = DT_SEG_MODEL_SEGNEXT; + else + { + dt_print(DT_DEBUG_AI, + "[segmentation] unknown arch '%s' for %s", + arch, model_id); + return NULL; + } + + // ORT 1.24.x can hit a heap-use-after-free during graph optimization while + // constant-folding the SAM 2.1 encoder. Load that encoder with optimization + // disabled so session creation stays on the safe path. SegNext and other + // non-SAM encoders keep the normal optimization level. + const dt_ai_opt_level_t encoder_opt + = (model_type == DT_SEG_MODEL_SAM) ? DT_AI_OPT_DISABLED : DT_AI_OPT_ALL; + // provider is resolved from the environment (read from config at init time), - // passing AUTO lets dt_ai_load_model resolve it + // passing AUTO lets dt_ai_load_model_ext resolve it dt_ai_context_t *encoder - = dt_ai_load_model(env, model_id, "encoder.onnx", DT_AI_PROVIDER_AUTO); + = dt_ai_load_model_ext(env, model_id, "encoder.onnx", DT_AI_PROVIDER_AUTO, + encoder_opt, NULL, 0); if(!encoder) { dt_print(DT_DEBUG_AI, "[segmentation] failed to load encoder for %s", model_id); @@ -329,23 +356,7 @@ dt_seg_context_t *dt_seg_load(dt_ai_environment_t *env, const char *model_id) ctx->enc_order[0], ctx->enc_order[1], ctx->enc_order[2], ctx->enc_order[3], ctx->n_enc_outputs); - // detect model type from arch field in model registry - const dt_ai_model_info_t *minfo - = dt_ai_get_model_info_by_id(env, model_id); - const char *arch = minfo ? minfo->arch : ""; - - if(strcmp(arch, "sam2") == 0) - ctx->model_type = DT_SEG_MODEL_SAM; - else if(strcmp(arch, "segnext") == 0) - ctx->model_type = DT_SEG_MODEL_SEGNEXT; - else - { - dt_print(DT_DEBUG_AI, - "[segmentation] unknown arch '%s' for %s", - arch, model_id); - dt_seg_free(ctx); - return NULL; - } + ctx->model_type = model_type; // SAM requires external ImageNet normalization; SegNext bakes it into the encoder ctx->normalize = (ctx->model_type == DT_SEG_MODEL_SAM); @@ -571,6 +582,8 @@ void dt_seg_warmup_decoder(dt_seg_context_t *ctx) .shape = has_mask_shape, .ndim = 1}; int64_t masks_shape[4] = {1, nm, dec_h, dec_w}; + int64_t iou_shape[2] = {1, nm}; + int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim}; float iou_buf[MAX_NUM_MASKS]; dt_ai_tensor_t outputs[3]; @@ -578,8 +591,6 @@ void dt_seg_warmup_decoder(dt_seg_context_t *ctx) if(is_sam) { - int64_t iou_shape[2] = {1, nm}; - int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim}; const int dec_outputs = dt_ai_get_output_count(ctx->decoder); outputs[0] = (dt_ai_tensor_t){ @@ -848,6 +859,8 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx, dt_ai_tensor_t dec_outputs[3]; int n_dec_out; int64_t masks_shape[4] = {1, nm, dec_h, dec_w}; + int64_t iou_shape[2] = {1, nm}; + int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim}; float iou_pred[MAX_NUM_MASKS]; float *low_res = NULL; @@ -855,7 +868,6 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx, if(is_sam) { // SAM: masks [1,N,H,W] + iou [1,N], optionally low_res [1,N,pm,pm] - int64_t iou_shape[2] = {1, nm}; const int dec_out_count = dt_ai_get_output_count(ctx->decoder); dec_outputs[0] = (dt_ai_tensor_t){ @@ -876,7 +888,6 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx, g_free(masks); return NULL; } - int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim}; dec_outputs[2] = (dt_ai_tensor_t){ .data = low_res, .type = DT_AI_FLOAT, .shape = low_res_shape, .ndim = 4}; n_dec_out = 3;