From e6316e94efb8b23e94acd3a1d640abbf3d02d210 Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Fri, 3 Apr 2026 08:21:40 -0700 Subject: [PATCH 1/3] feat: add model= parameter to get_features for pre-loaded model reuse Adds an optional `model` parameter to `get_features()`. When provided, the pre-loaded model instance is used directly and `model_type`/`model_path` are ignored. When `model=None` (default), behaviour is unchanged. This allows batch workflows to load a model once and reuse it across many slides, avoiding repeated disk I/O and model initialisation overhead. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/utils/feature_extract.py | 71 ++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/mussel/utils/feature_extract.py b/mussel/utils/feature_extract.py index 6b897356..cc2864cb 100644 --- a/mussel/utils/feature_extract.py +++ b/mussel/utils/feature_extract.py @@ -820,6 +820,7 @@ def get_features( attrs: dict, model_type: ModelType = ModelType.CLIP, model_path: Optional[str] = None, + model=None, batch_size: int = 64, use_gpu: bool = True, gpu_device_id: Optional[Union[int, List[int]]] = None, @@ -842,6 +843,9 @@ def get_features( When using model-based aggregation with a slide encoder, this will be automatically set to the required patch encoder if not already specified correctly. model_path: Optional path to model weights. + model: Optional pre-loaded model instance. If provided, model_type and + model_path are ignored — useful for batch workflows where the model + is loaded once and reused across many slides. batch_size: Batch size for feature extraction (default: 64). use_gpu: Whether to use GPU for inference (default: True). gpu_device_id: GPU device ID to use. @@ -859,43 +863,46 @@ def get_features( Returns: Tuple of (features array, labels array). 
""" - logger.info("loading model checkpoint") - - if gpu_device_ids: - gpu_device_id = gpu_device_ids + if model is None: + logger.info("loading model checkpoint") - # Auto-set aggregation_method to "model" if slide_model_type is specified - if ( - use_slide_encoder - and slide_model_type is not None - and aggregation_method != "model" - ): - logger.info( - f"Auto-setting aggregation_method to 'model' since slide_model_type " - f"({slide_model_type}) is specified" - ) - aggregation_method = "model" + if gpu_device_ids: + gpu_device_id = gpu_device_ids - # Auto-infer patch encoder from slide encoder if using model-based aggregation - if ( - use_slide_encoder - and aggregation_method == "model" - and slide_model_type is not None - ): - required_patch_encoder = get_required_patch_encoder(slide_model_type) - if model_type != required_patch_encoder: + # Auto-set aggregation_method to "model" if slide_model_type is specified + if ( + use_slide_encoder + and slide_model_type is not None + and aggregation_method != "model" + ): logger.info( - f"Auto-selecting patch encoder {required_patch_encoder} " - f"as required by slide encoder {slide_model_type}" + f"Auto-setting aggregation_method to 'model' since slide_model_type " + f"({slide_model_type}) is specified" ) - model_type = required_patch_encoder - # Validate compatibility - validate_slide_encoder_compatibility(model_type, slide_model_type) + aggregation_method = "model" - model_factory = get_model_factory(model_type) - if model_factory is None: - raise ValueError("model not recognized") - model = model_factory.get_model(model_path, use_gpu, gpu_device_id) + # Auto-infer patch encoder from slide encoder if using model-based aggregation + if ( + use_slide_encoder + and aggregation_method == "model" + and slide_model_type is not None + ): + required_patch_encoder = get_required_patch_encoder(slide_model_type) + if model_type != required_patch_encoder: + logger.info( + f"Auto-selecting patch encoder 
{required_patch_encoder} " + f"as required by slide encoder {slide_model_type}" + ) + model_type = required_patch_encoder + # Validate compatibility + validate_slide_encoder_compatibility(model_type, slide_model_type) + + model_factory = get_model_factory(model_type) + if model_factory is None: + raise ValueError("model not recognized") + model = model_factory.get_model(model_path, use_gpu, gpu_device_id) + else: + logger.info("using pre-loaded model") preprocessing = model.get_preprocessing_fun() dataset = WholeSlideImageTileCoordDataset( From 9200db32ff9de86e6b252485e5aaf8289bdaaca3 Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Fri, 3 Apr 2026 09:03:01 -0700 Subject: [PATCH 2/3] feat: add model/slide_model params, fix compat validation, keyword-only placement - Move model= and new slide_model= to end of get_features() signature so positional callers are unaffected (addresses review comment) - Add slide_model= to _apply_slide_aggregation() for consistency - Move aggregation_method auto-set and compatibility validation outside the model-loading branch so they run for pre-loaded models too - Validate pre-loaded model interfaces up-front with clear TypeError - validate_slide_encoder_compatibility() now always called when use_slide_encoder + slide_model_type are set Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- mussel/utils/feature_extract.py | 86 ++++++++++++++++++++------------- 1 file changed, 53 insertions(+), 33 deletions(-) diff --git a/mussel/utils/feature_extract.py b/mussel/utils/feature_extract.py index cc2864cb..94703674 100644 --- a/mussel/utils/feature_extract.py +++ b/mussel/utils/feature_extract.py @@ -681,6 +681,7 @@ def _apply_slide_aggregation( gpu_device_ids: Optional[List[int]] = None, coords: Optional[np.ndarray] = None, patch_size: Optional[int] = None, + slide_model=None, ) -> np.ndarray: """Apply slide-level aggregation to patch features. 
@@ -698,6 +699,8 @@ def _apply_slide_aggregation( coords: Optional numpy array of patch coordinates (required for some slide encoders like GIGAPATH_SLIDE, TITAN_SLIDE). patch_size: Optional patch size at level 0 (required for TITAN_SLIDE). If not provided, will be extracted from h5 file 'coords' attributes or default to 256. + slide_model: Optional pre-loaded slide encoder model instance. If provided, + slide_model_type and slide_model_path are ignored. Returns: Numpy array of aggregated features. @@ -732,12 +735,16 @@ def _apply_slide_aggregation( if gpu_device_ids: gpu_device_id = gpu_device_ids - # Load the slide encoder model - model_factory = get_model_factory(slide_model_type) - if model_factory is None: - raise ValueError(f"Slide model type {slide_model_type} not recognized") - model = model_factory.get_model(slide_model_path, use_gpu, gpu_device_id) - model_fun = model.get_model_fun() + # Load or use pre-loaded slide encoder model + if slide_model is not None: + logger.info("using pre-loaded slide model") + model_fun = slide_model.get_model_fun() + else: + model_factory = get_model_factory(slide_model_type) + if model_factory is None: + raise ValueError(f"Slide model type {slide_model_type} not recognized") + slide_model = model_factory.get_model(slide_model_path, use_gpu, gpu_device_id) + model_fun = slide_model.get_model_fun() # Convert features to tensor and apply model features_tensor = torch.from_numpy(features).unsqueeze(0) # Add batch dimension @@ -820,7 +827,6 @@ def get_features( attrs: dict, model_type: ModelType = ModelType.CLIP, model_path: Optional[str] = None, - model=None, batch_size: int = 64, use_gpu: bool = True, gpu_device_id: Optional[Union[int, List[int]]] = None, @@ -832,6 +838,8 @@ def get_features( slide_model_type: Optional[ModelType] = None, slide_model_path: Optional[str] = None, aggregation_method: str = "identity", + model=None, + slide_model=None, ) -> tuple[np.ndarray, np.ndarray]: """Extract features from whole slide image 
tiles.
@@ -842,9 +850,7 @@ def get_features(
         model_type: Type of foundation model to use (default: ModelType.CLIP).
             When using model-based aggregation with a slide encoder, this will be
             automatically set to the required patch encoder if not already specified correctly.
-        model_path: Optional path to model weights.
-        model: Optional pre-loaded model instance. If provided, model_type and
-            model_path are ignored — useful for batch workflows where the model
+        model_path: Optional path to model weights. Not needed when a shared model
             is loaded once and reused across many slides.
         batch_size: Batch size for feature extraction (default: 64).
         use_gpu: Whether to use GPU for inference (default: True).
@@ -859,34 +865,45 @@ def get_features(
         slide_model_path: Optional path to slide encoder model weights.
         aggregation_method: Aggregation method when using slide encoder (default: "identity").
             Options: "identity", "mean", "max", "model".
+        model: Optional pre-loaded patch encoder model instance. If provided,
+            model_type and model_path are ignored.
+        slide_model: Optional pre-loaded slide encoder model instance. If provided,
+            slide_model_type and slide_model_path are ignored for slide encoding.
 
     Returns:
         Tuple of (features array, labels array).
     """
-    if model is None:
-        logger.info("loading model checkpoint")
-
-        if gpu_device_ids:
-            gpu_device_id = gpu_device_ids
-
-        # Auto-set aggregation_method to "model" if slide_model_type is specified
-        if (
-            use_slide_encoder
-            and slide_model_type is not None
-            and aggregation_method != "model"
+    # Validate pre-loaded model interfaces up-front.
+ if model is not None: + if not callable(getattr(model, "get_preprocessing_fun", None)) or not callable( + getattr(model, "get_model_fun", None) ): - logger.info( - f"Auto-setting aggregation_method to 'model' since slide_model_type " - f"({slide_model_type}) is specified" + raise TypeError( + "model must provide callable get_preprocessing_fun() and get_model_fun() methods" ) - aggregation_method = "model" + if slide_model is not None: + if not callable(getattr(slide_model, "get_model_fun", None)): + raise TypeError("slide_model must provide a callable get_model_fun() method") - # Auto-infer patch encoder from slide encoder if using model-based aggregation - if ( - use_slide_encoder - and aggregation_method == "model" - and slide_model_type is not None - ): + if gpu_device_ids: + gpu_device_id = gpu_device_ids + + # Auto-set aggregation_method unconditionally so pre-loaded models also trigger slide encoding. + if ( + use_slide_encoder + and slide_model_type is not None + and aggregation_method != "model" + ): + logger.info( + f"Auto-setting aggregation_method to 'model' since slide_model_type " + f"({slide_model_type}) is specified" + ) + aggregation_method = "model" + + # Validate patch/slide encoder compatibility regardless of how models are provided. + if use_slide_encoder and aggregation_method == "model" and slide_model_type is not None: + if model is None: + # Auto-infer required patch encoder when loading from disk. required_patch_encoder = get_required_patch_encoder(slide_model_type) if model_type != required_patch_encoder: logger.info( @@ -894,9 +911,11 @@ def get_features( f"as required by slide encoder {slide_model_type}" ) model_type = required_patch_encoder - # Validate compatibility - validate_slide_encoder_compatibility(model_type, slide_model_type) + validate_slide_encoder_compatibility(model_type, slide_model_type) + # Load patch encoder from disk, or use the pre-loaded instance. 
+ if model is None: + logger.info("loading model checkpoint") model_factory = get_model_factory(model_type) if model_factory is None: raise ValueError("model not recognized") @@ -950,6 +969,7 @@ def get_features( gpu_device_ids=gpu_device_ids, coords=coords, patch_size=patch_size, + slide_model=slide_model, ) return features, labels From 5314bb898252aba9573c072e602e0d55f0aca2ca Mon Sep 17 00:00:00 2001 From: Raymond Lim Date: Fri, 3 Apr 2026 09:03:01 -0700 Subject: [PATCH 3/3] test: add tests for model/slide_model pre-loaded params in get_features Covers: - get_model_factory not called when model= provided - get_model_factory called when model= not provided - positional args unchanged (batch_size not interpreted as model) - TypeError on invalid model/slide_model interface - slide encoder factory not called when slide_model= provided - validate_slide_encoder_compatibility called even with pre-loaded model Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/mussel/utils/test_feature_extract.py | 295 ++++++++++----------- 1 file changed, 146 insertions(+), 149 deletions(-) diff --git a/tests/mussel/utils/test_feature_extract.py b/tests/mussel/utils/test_feature_extract.py index 76f55ae4..81271e02 100644 --- a/tests/mussel/utils/test_feature_extract.py +++ b/tests/mussel/utils/test_feature_extract.py @@ -1,175 +1,172 @@ -"""Tests for feature extraction utilities, particularly batch processing.""" +"""Tests for get_features pre-loaded model parameters.""" -import os -import tempfile -from pathlib import Path from unittest.mock import MagicMock, patch -import h5py import numpy as np import pytest -import torch from mussel.models import ModelType -from mussel.utils.feature_extract import extract_patch_features_batch - - -def create_mock_h5_file(h5_path, num_patches=10): - """Create a mock HDF5 file with patch coordinates.""" - coords = np.array([[i * 256, i * 256] for i in range(num_patches)]) - - with h5py.File(h5_path, "w") as f: - coords_dset 
= f.create_dataset("coords", data=coords) - coords_dset.attrs["patch_size"] = 256 - coords_dset.attrs["patch_level"] = 0 - coords_dset.attrs["patch_size_to_resize_to_for_desired_mpp"] = 224 - - -def test_extract_patch_features_batch_basic(tmp_path, use_gpu, num_workers): - """Test basic batch extraction of patch features from multiple slides.""" - # Create mock input files - num_slides = 3 - patch_h5_paths = [] - slide_paths = [] - output_h5_paths = [] - - for i in range(num_slides): - # Create mock patch coordinates file - patch_h5_path = tmp_path / f"slide{i}_coords.h5" - create_mock_h5_file(patch_h5_path, num_patches=5) - patch_h5_paths.append(str(patch_h5_path)) - - # Mock slide path (doesn't need to exist due to mocking) - slide_paths.append(f"slide{i}.svs") - - # Output path - output_h5_paths.append(str(tmp_path / f"slide{i}_features.h5")) - - # Mock the model, dataset, and process_dataset to avoid loading actual slides - with ( - patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, - patch( - "mussel.utils.feature_extract.WholeSlideImageH5Dataset" - ) as mock_dataset_class, - patch("mussel.utils.feature_extract.process_dataset") as mock_process, - ): - # Mock model - mock_model = MagicMock() - mock_model_fun = MagicMock( - side_effect=lambda x: torch.randn(len(x), 2048) # Return batch of features +from mussel.utils.feature_extract import get_features + + +def _make_mock_model(feature_dim=384): + """Return a mock model with the required interface.""" + mock = MagicMock() + mock.get_preprocessing_fun.return_value = None + mock.get_model_fun.return_value = MagicMock( + side_effect=lambda x: __import__('torch').randn(len(x), feature_dim) + ) + return mock + + +def _base_patches(): + """Context managers that stub out all I/O in get_features.""" + return [ + patch("mussel.utils.feature_extract.get_model_factory"), + patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), + patch("mussel.utils.feature_extract.process_dataset"), + 
patch("mussel.utils.feature_extract._make_dataloader"), + ] + + +# -- model= parameter ------------------------------------------------------- + +def test_model_factory_not_called_when_model_provided(): + """get_model_factory must not be called when a pre-loaded model is given.""" + coords = np.zeros((10, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + mock_model = _make_mock_model() + + with patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_proc.return_value = MagicMock( + features=np.zeros((10, 384)), labels=np.zeros(10) ) - mock_model.get_model_fun.return_value = mock_model_fun - mock_model.get_preprocessing_fun.return_value = None - mock_factory.return_value = MagicMock( - get_model=MagicMock(return_value=mock_model) + get_features(coords, "slide.svs", attrs, model=mock_model) + + mock_factory.assert_not_called() + + +def test_model_factory_called_when_model_not_provided(): + """get_model_factory must be called when no pre-loaded model is given.""" + coords = np.zeros((10, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + + with patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_model = _make_mock_model() + mock_factory.return_value = MagicMock(get_model=MagicMock(return_value=mock_model)) + mock_proc.return_value = MagicMock( + features=np.zeros((10, 384)), labels=np.zeros(10) ) + get_features(coords, "slide.svs", attrs, model_type=ModelType.CTRANSPATH) - # 
Mock dataset - mock_dataset = MagicMock() - mock_dataset.__len__.return_value = 5 - mock_dataset_class.return_value = mock_dataset - - # Call the batch extraction function - result_paths = extract_patch_features_batch( - patch_h5_paths=patch_h5_paths, - slide_paths=slide_paths, - output_h5_paths=output_h5_paths, - model_type=ModelType.RESNET50, - model_path=None, - batch_size=32, - use_gpu=use_gpu, - num_workers=num_workers, + mock_factory.assert_called_once() + + +def test_positional_args_unchanged(): + """Existing positional call pattern must still work after adding model= at end.""" + coords = np.zeros((5, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + + with patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_model = _make_mock_model() + mock_factory.return_value = MagicMock(get_model=MagicMock(return_value=mock_model)) + mock_proc.return_value = MagicMock( + features=np.zeros((5, 384)), labels=np.zeros(5) ) + # Classic positional call: (coords, slide_path, attrs, model_type, model_path, batch_size) + # batch_size must not be interpreted as model= + features, labels = get_features( + coords, "slide.svs", attrs, ModelType.CTRANSPATH, None, 32 + ) + assert features.shape == (5, 384) - # Verify model was loaded only once - assert mock_factory.call_count == 1, "Model factory should be called only once" - # Verify dataset was created for each slide - assert ( - mock_dataset_class.call_count == num_slides - ), f"Dataset should be created {num_slides} times, got {mock_dataset_class.call_count}" +def test_model_invalid_interface_raises_type_error(): + """Passing an object without the required methods must raise TypeError immediately.""" + coords = np.zeros((5, 2), 
dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + bad_model = object() - # Verify process_dataset was called for each slide - assert ( - mock_process.call_count == num_slides - ), f"process_dataset should be called {num_slides} times, got {mock_process.call_count}" + with pytest.raises(TypeError, match="get_preprocessing_fun"): + get_features(coords, "slide.svs", attrs, model=bad_model) - # Verify result paths match input - assert result_paths == output_h5_paths +# -- slide_model= parameter ------------------------------------------------ -def test_extract_patch_features_batch_empty_list(): - """Test that batch extraction handles empty input gracefully.""" - result = extract_patch_features_batch( - patch_h5_paths=[], - slide_paths=[], - output_h5_paths=[], - model_type=ModelType.RESNET50, +def test_slide_model_factory_not_called_when_slide_model_provided(): + """get_model_factory must not be loaded for the slide encoder when slide_model is given.""" + coords = np.zeros((10, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + mock_patch_model = _make_mock_model() + mock_slide_model = MagicMock() + mock_slide_model.get_model_fun.return_value = MagicMock( + return_value=__import__('torch').zeros(1, 512) ) - assert result == [], "Should return empty list for empty input" - - -def test_extract_patch_features_batch_single_slide(tmp_path, use_gpu, num_workers): - """Test batch extraction with a single slide (edge case).""" - # Create mock input file - patch_h5_path = tmp_path / "slide_coords.h5" - create_mock_h5_file(patch_h5_path, num_patches=10) - - patch_h5_paths = [str(patch_h5_path)] - slide_paths = ["slide.svs"] - output_h5_paths = [str(tmp_path / "slide_features.h5")] - - with ( - patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, - patch( - 
"mussel.utils.feature_extract.WholeSlideImageH5Dataset" - ) as mock_dataset_class, - patch("mussel.utils.feature_extract.process_dataset") as mock_process, - ): - # Mock model - mock_model = MagicMock() - mock_model_fun = MagicMock(side_effect=lambda x: torch.randn(len(x), 2048)) - mock_model.get_model_fun.return_value = mock_model_fun - mock_model.get_preprocessing_fun.return_value = None - mock_factory.return_value = MagicMock( - get_model=MagicMock(return_value=mock_model) - ) + call_log = [] + def factory_side_effect(model_type): + call_log.append(model_type) + m = MagicMock() + m.get_model.return_value = mock_patch_model + return m - # Mock dataset - mock_dataset = MagicMock() - mock_dataset.__len__.return_value = 10 - mock_dataset_class.return_value = mock_dataset - - result_paths = extract_patch_features_batch( - patch_h5_paths=patch_h5_paths, - slide_paths=slide_paths, - output_h5_paths=output_h5_paths, - model_type=ModelType.RESNET50, - model_path=None, - batch_size=32, - use_gpu=use_gpu, - num_workers=num_workers, + with patch("mussel.utils.feature_extract.get_model_factory", side_effect=factory_side_effect), patch("mussel.utils.feature_extract.validate_slide_encoder_compatibility"), patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_proc.return_value = MagicMock( + features=np.zeros((10, 384)), labels=np.zeros(10) + ) + get_features( + coords, "slide.svs", attrs, + model=mock_patch_model, + use_slide_encoder=True, + slide_model_type=ModelType.GIGAPATH_SLIDE, + aggregation_method="model", + slide_model=mock_slide_model, ) - # Even with one slide, model should be loaded only once - assert mock_factory.call_count == 1 - assert result_paths == output_h5_paths + # Only the patch encoder factory may be called (for auto-infer check), not the slide encoder + slide_encoder_calls = [t for t in call_log if t == 
ModelType.GIGAPATH_SLIDE]
+    assert len(slide_encoder_calls) == 0, "Slide encoder factory must not be called when slide_model is provided"
 
 
-def test_extract_patch_features_batch_model_reuse():
-    """
-    Document that batch processing reuses the model across slides.
+def test_slide_model_invalid_interface_raises_type_error():
+    """Passing a slide_model without get_model_fun must raise TypeError."""
+    coords = np.zeros((5, 2), dtype=np.int32)
+    attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5,
+             "patch_size_to_resize_to_for_desired_mpp": 224}
+    bad_slide_model = object()
 
-    This is the key benefit of batch processing:
-    - Old approach: Load model N times for N slides
-    - New approach: Load model 1 time for N slides
+    with pytest.raises(TypeError, match="get_model_fun"):
+        get_features(coords, "slide.svs", attrs, slide_model=bad_slide_model)
+
+
+# -- compatibility validation -----------------------------------------------
+
+def test_compatibility_validated_with_preloaded_patch_model():
+    """validate_slide_encoder_compatibility must be called even with pre-loaded patch model."""
+    coords = np.zeros((5, 2), dtype=np.int32)
+    attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5,
+             "patch_size_to_resize_to_for_desired_mpp": 224}
+    mock_patch_model = _make_mock_model()
+
+    with patch("mussel.utils.feature_extract.validate_slide_encoder_compatibility") as mock_validate, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc:
+        mock_proc.return_value = MagicMock(
+            features=np.zeros((5, 384)), labels=np.zeros(5)
+        )
+        mock_slide = MagicMock()
+        mock_slide.get_model_fun.return_value = MagicMock(return_value=__import__('torch').zeros(1, 512))
+        get_features(
+            coords, "slide.svs", attrs,
+            model_type=ModelType.GIGAPATH,
+            model=mock_patch_model,
+            use_slide_encoder=True,
+            slide_model_type=ModelType.GIGAPATH_SLIDE,
+            
aggregation_method="model", + slide_model=mock_slide, + ) - For 100 slides with 2s model load time: - - Old: 100 * 2s = 200s wasted on model loading - - New: 1 * 2s = 2s for model loading - - Savings: 198s (99% reduction in model loading time) - """ - # This is a documentation test to highlight the key benefit - # In production, this translates to significant time savings - pass + mock_validate.assert_called_once_with(ModelType.GIGAPATH, ModelType.GIGAPATH_SLIDE)