feat: add model= parameter to get_features for pre-loaded model reuse#114
feat: add model= parameter to get_features for pre-loaded model reuse#114
Conversation
Adds an optional `model` parameter to `get_features()`. When provided, the pre-loaded model instance is used directly and `model_type`/`model_path` are ignored. When `model=None` (default), behaviour is unchanged. This allows batch workflows to load a model once and reuse it across many slides, avoiding repeated disk I/O and model initialisation overhead. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
This PR adds support for reusing a pre-loaded feature-extractor model by introducing an optional model parameter to get_features(), enabling batch workflows to avoid repeated model loads.
Changes:
- Add
model=Noneparameter toget_features()to allow passing a pre-loaded model instance. - Conditionalize model-loading logic so it only runs when
model is None. - Update
get_features()docstring to document the new parameter.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
mussel/utils/feature_extract.py
Outdated
| # Auto-set aggregation_method to "model" if slide_model_type is specified | ||
| if ( | ||
| use_slide_encoder | ||
| and slide_model_type is not None | ||
| and aggregation_method != "model" | ||
| ): | ||
| logger.info( | ||
| f"Auto-selecting patch encoder {required_patch_encoder} " | ||
| f"as required by slide encoder {slide_model_type}" | ||
| f"Auto-setting aggregation_method to 'model' since slide_model_type " | ||
| f"({slide_model_type}) is specified" | ||
| ) | ||
| model_type = required_patch_encoder | ||
| # Validate compatibility | ||
| validate_slide_encoder_compatibility(model_type, slide_model_type) | ||
| aggregation_method = "model" | ||
|
|
There was a problem hiding this comment.
aggregation_method auto-setting based on slide_model_type is now only executed when model is None. If a caller passes a pre-loaded model with use_slide_encoder=True and slide_model_type set, the function will no longer auto-switch aggregation_method to "model", changing behavior compared to the non-preloaded path. Consider running this auto-setting logic regardless of whether the patch model is preloaded.
There was a problem hiding this comment.
Fixed — the aggregation_method auto-set block is now outside the if model is None: guard and runs unconditionally.
mussel/utils/feature_extract.py
Outdated
| # Auto-infer patch encoder from slide encoder if using model-based aggregation | ||
| if ( | ||
| use_slide_encoder | ||
| and aggregation_method == "model" | ||
| and slide_model_type is not None | ||
| ): | ||
| required_patch_encoder = get_required_patch_encoder(slide_model_type) | ||
| if model_type != required_patch_encoder: | ||
| logger.info( | ||
| f"Auto-selecting patch encoder {required_patch_encoder} " | ||
| f"as required by slide encoder {slide_model_type}" | ||
| ) | ||
| model_type = required_patch_encoder | ||
| # Validate compatibility | ||
| validate_slide_encoder_compatibility(model_type, slide_model_type) | ||
|
|
There was a problem hiding this comment.
Slide-encoder compatibility validation and required patch-encoder inference are also skipped when a pre-loaded model is provided, which can allow an incompatible patch encoder to slip through (leading to harder-to-debug failures later). Consider still calling validate_slide_encoder_compatibility(...) when use_slide_encoder + slide_model_type indicates a required patch encoder, even if you don't auto-load/override the patch model.
There was a problem hiding this comment.
Fixed — validate_slide_encoder_compatibility is now called unconditionally whenever use_slide_encoder + aggregation_method="model" + slide_model_type are set, regardless of whether the patch model is pre-loaded. When loading from disk we still auto-infer model_type first; with a pre-loaded model we validate against the caller-supplied model_type.
mussel/utils/feature_extract.py
Outdated
| if model is None: | ||
| logger.info("loading model checkpoint") | ||
|
|
||
| # Auto-set aggregation_method to "model" if slide_model_type is specified | ||
| if ( | ||
| use_slide_encoder | ||
| and slide_model_type is not None | ||
| and aggregation_method != "model" | ||
| ): | ||
| logger.info( | ||
| f"Auto-setting aggregation_method to 'model' since slide_model_type " | ||
| f"({slide_model_type}) is specified" | ||
| ) | ||
| aggregation_method = "model" | ||
| if gpu_device_ids: | ||
| gpu_device_id = gpu_device_ids | ||
|
|
||
| # Auto-infer patch encoder from slide encoder if using model-based aggregation | ||
| if ( | ||
| use_slide_encoder | ||
| and aggregation_method == "model" | ||
| and slide_model_type is not None | ||
| ): | ||
| required_patch_encoder = get_required_patch_encoder(slide_model_type) | ||
| if model_type != required_patch_encoder: | ||
| # Auto-set aggregation_method to "model" if slide_model_type is specified | ||
| if ( | ||
| use_slide_encoder | ||
| and slide_model_type is not None | ||
| and aggregation_method != "model" | ||
| ): | ||
| logger.info( | ||
| f"Auto-selecting patch encoder {required_patch_encoder} " | ||
| f"as required by slide encoder {slide_model_type}" | ||
| f"Auto-setting aggregation_method to 'model' since slide_model_type " | ||
| f"({slide_model_type}) is specified" | ||
| ) | ||
| model_type = required_patch_encoder | ||
| # Validate compatibility | ||
| validate_slide_encoder_compatibility(model_type, slide_model_type) | ||
| aggregation_method = "model" | ||
|
|
||
| model_factory = get_model_factory(model_type) | ||
| if model_factory is None: | ||
| raise ValueError("model not recognized") | ||
| model = model_factory.get_model(model_path, use_gpu, gpu_device_id) | ||
| # Auto-infer patch encoder from slide encoder if using model-based aggregation | ||
| if ( | ||
| use_slide_encoder | ||
| and aggregation_method == "model" | ||
| and slide_model_type is not None | ||
| ): | ||
| required_patch_encoder = get_required_patch_encoder(slide_model_type) | ||
| if model_type != required_patch_encoder: | ||
| logger.info( | ||
| f"Auto-selecting patch encoder {required_patch_encoder} " | ||
| f"as required by slide encoder {slide_model_type}" | ||
| ) | ||
| model_type = required_patch_encoder | ||
| # Validate compatibility | ||
| validate_slide_encoder_compatibility(model_type, slide_model_type) | ||
|
|
||
| model_factory = get_model_factory(model_type) | ||
| if model_factory is None: | ||
| raise ValueError("model not recognized") | ||
| model = model_factory.get_model(model_path, use_gpu, gpu_device_id) | ||
| else: | ||
| logger.info("using pre-loaded model") | ||
| preprocessing = model.get_preprocessing_fun() |
There was a problem hiding this comment.
There are existing unit tests for mussel.utils.feature_extract (e.g., tests/mussel/utils/test_feature_extract.py), but none cover the new model reuse path in get_features(). Please add tests that (1) ensure get_model_factory(...).get_model(...) is not called when model is provided and (2) confirm positional argument behavior remains correct for existing call patterns.
There was a problem hiding this comment.
Added tests/mussel/utils/test_feature_extract.py with 7 new tests covering: get_model_factory not called when model= provided; get_model_factory called when it isn't; positional args unchanged (batch_size not mis-interpreted); TypeError on bad model/slide_model interface; slide encoder factory not called when slide_model= provided; validate_slide_encoder_compatibility called even with pre-loaded patch model.
…ly placement - Move model= and new slide_model= to end of get_features() signature so positional callers are unaffected (addresses review comment) - Add slide_model= to _apply_slide_aggregation() for consistency - Move aggregation_method auto-set and compatibility validation outside the model-loading branch so they run for pre-loaded models too - Validate pre-loaded model interfaces up-front with clear TypeError - validate_slide_encoder_compatibility() now always called when use_slide_encoder + slide_model_type are set Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Covers: - get_model_factory not called when model= provided - get_model_factory called when model= not provided - positional args unchanged (batch_size not interpreted as model) - TypeError on invalid model/slide_model interface - slide encoder factory not called when slide_model= provided - validate_slide_encoder_compatibility called even with pre-loaded model Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Summary
Adds an optional
modelparameter toget_features(). When provided, the pre-loaded model instance is used directly andmodel_type/model_pathare ignored. Whenmodel=None(default), behaviour is completely unchanged.Motivation
Batch workflows that process many slides need to load a feature extractor once and reuse it. Without this parameter,
get_features()reloads the model from disk on every call — for 10 slides with CTransPath + Optimus that's 20 unnecessary model loads.The parameter was present in the
mosaic-devbranch but was dropped when v1.3.0 was cut frommain, breaking downstream callers (mosaic) that relied on it.Changes
mussel/utils/feature_extract.py—get_features():model=Noneparameter aftermodel_pathif model is None:blockelse: logger.info("using pre-loaded model")branchpreprocessing = model.get_preprocessing_fun()outside the conditional (needed in both paths)Backward compatibility
Default is
None, so all existing call-sites are unaffected.Example usage