diff --git a/mussel/utils/feature_extract.py b/mussel/utils/feature_extract.py index 6b897356..94703674 100644 --- a/mussel/utils/feature_extract.py +++ b/mussel/utils/feature_extract.py @@ -681,6 +681,7 @@ def _apply_slide_aggregation( gpu_device_ids: Optional[List[int]] = None, coords: Optional[np.ndarray] = None, patch_size: Optional[int] = None, + slide_model=None, ) -> np.ndarray: """Apply slide-level aggregation to patch features. @@ -698,6 +699,8 @@ def _apply_slide_aggregation( coords: Optional numpy array of patch coordinates (required for some slide encoders like GIGAPATH_SLIDE, TITAN_SLIDE). patch_size: Optional patch size at level 0 (required for TITAN_SLIDE). If not provided, will be extracted from h5 file 'coords' attributes or default to 256. + slide_model: Optional pre-loaded slide encoder model instance. If provided, + slide_model_type and slide_model_path are ignored. Returns: Numpy array of aggregated features. @@ -732,12 +735,16 @@ def _apply_slide_aggregation( if gpu_device_ids: gpu_device_id = gpu_device_ids - # Load the slide encoder model - model_factory = get_model_factory(slide_model_type) - if model_factory is None: - raise ValueError(f"Slide model type {slide_model_type} not recognized") - model = model_factory.get_model(slide_model_path, use_gpu, gpu_device_id) - model_fun = model.get_model_fun() + # Load or use pre-loaded slide encoder model + if slide_model is not None: + logger.info("using pre-loaded slide model") + model_fun = slide_model.get_model_fun() + else: + model_factory = get_model_factory(slide_model_type) + if model_factory is None: + raise ValueError(f"Slide model type {slide_model_type} not recognized") + slide_model = model_factory.get_model(slide_model_path, use_gpu, gpu_device_id) + model_fun = slide_model.get_model_fun() # Convert features to tensor and apply model features_tensor = torch.from_numpy(features).unsqueeze(0) # Add batch dimension @@ -831,6 +838,8 @@ def get_features( slide_model_type: 
Optional[ModelType] = None, slide_model_path: Optional[str] = None, aggregation_method: str = "identity", + model=None, + slide_model=None, ) -> tuple[np.ndarray, np.ndarray]: """Extract features from whole slide image tiles. @@ -841,7 +850,8 @@ def get_features( model_type: Type of foundation model to use (default: ModelType.CLIP). When using model-based aggregation with a slide encoder, this will be automatically set to the required patch encoder if not already specified correctly. - model_path: Optional path to model weights. + model_path: Optional path to model weights. — useful for batch workflows where the model + is loaded once and reused across many slides. batch_size: Batch size for feature extraction (default: 64). use_gpu: Whether to use GPU for inference (default: True). gpu_device_id: GPU device ID to use. @@ -855,16 +865,30 @@ def get_features( slide_model_path: Optional path to slide encoder model weights. aggregation_method: Aggregation method when using slide encoder (default: "identity"). Options: "identity", "mean", "max", "model". + model: Optional pre-loaded patch encoder model instance. If provided, + model_type and model_path are ignored. + slide_model: Optional pre-loaded slide encoder model instance. If provided, + slide_model_type and slide_model_path are ignored for slide encoding. Returns: Tuple of (features array, labels array). """ - logger.info("loading model checkpoint") + # Validate pre-loaded model interfaces up-front. 
+    if model is not None: +        if not callable(getattr(model, "get_preprocessing_fun", None)) or not callable( +            getattr(model, "get_model_fun", None) +        ): +            raise TypeError( +                "model must provide callable get_preprocessing_fun() and get_model_fun() methods" +            ) +    if slide_model is not None: +        if not callable(getattr(slide_model, "get_model_fun", None)): +            raise TypeError("slide_model must provide a callable get_model_fun() method") if gpu_device_ids: gpu_device_id = gpu_device_ids -    # Auto-set aggregation_method to "model" if slide_model_type is specified +    # Auto-set aggregation_method to "model" when use_slide_encoder and slide_model_type are both given, so pre-loaded slide models also trigger slide encoding. if ( use_slide_encoder and slide_model_type is not None @@ -876,26 +900,28 @@ def get_features( ) aggregation_method = "model" -    # Auto-infer patch encoder from slide encoder if using model-based aggregation -    if ( -        use_slide_encoder -        and aggregation_method == "model" -        and slide_model_type is not None -    ): -        required_patch_encoder = get_required_patch_encoder(slide_model_type) -        if model_type != required_patch_encoder: -            logger.info( -                f"Auto-selecting patch encoder {required_patch_encoder} " -                f"as required by slide encoder {slide_model_type}" -            ) -            model_type = required_patch_encoder -        # Validate compatibility +    # Validate patch/slide encoder compatibility regardless of how models are provided. +    if use_slide_encoder and aggregation_method == "model" and slide_model_type is not None: +        if model is None: +            # Auto-infer required patch encoder when loading from disk. 
+ required_patch_encoder = get_required_patch_encoder(slide_model_type) + if model_type != required_patch_encoder: + logger.info( + f"Auto-selecting patch encoder {required_patch_encoder} " + f"as required by slide encoder {slide_model_type}" + ) + model_type = required_patch_encoder validate_slide_encoder_compatibility(model_type, slide_model_type) - model_factory = get_model_factory(model_type) - if model_factory is None: - raise ValueError("model not recognized") - model = model_factory.get_model(model_path, use_gpu, gpu_device_id) + # Load patch encoder from disk, or use the pre-loaded instance. + if model is None: + logger.info("loading model checkpoint") + model_factory = get_model_factory(model_type) + if model_factory is None: + raise ValueError("model not recognized") + model = model_factory.get_model(model_path, use_gpu, gpu_device_id) + else: + logger.info("using pre-loaded model") preprocessing = model.get_preprocessing_fun() dataset = WholeSlideImageTileCoordDataset( @@ -943,6 +969,7 @@ def get_features( gpu_device_ids=gpu_device_ids, coords=coords, patch_size=patch_size, + slide_model=slide_model, ) return features, labels diff --git a/tests/mussel/utils/test_feature_extract.py b/tests/mussel/utils/test_feature_extract.py index 76f55ae4..81271e02 100644 --- a/tests/mussel/utils/test_feature_extract.py +++ b/tests/mussel/utils/test_feature_extract.py @@ -1,175 +1,172 @@ -"""Tests for feature extraction utilities, particularly batch processing.""" +"""Tests for get_features pre-loaded model parameters.""" -import os -import tempfile -from pathlib import Path from unittest.mock import MagicMock, patch -import h5py import numpy as np import pytest -import torch from mussel.models import ModelType -from mussel.utils.feature_extract import extract_patch_features_batch - - -def create_mock_h5_file(h5_path, num_patches=10): - """Create a mock HDF5 file with patch coordinates.""" - coords = np.array([[i * 256, i * 256] for i in range(num_patches)]) - - with 
h5py.File(h5_path, "w") as f: - coords_dset = f.create_dataset("coords", data=coords) - coords_dset.attrs["patch_size"] = 256 - coords_dset.attrs["patch_level"] = 0 - coords_dset.attrs["patch_size_to_resize_to_for_desired_mpp"] = 224 - - -def test_extract_patch_features_batch_basic(tmp_path, use_gpu, num_workers): - """Test basic batch extraction of patch features from multiple slides.""" - # Create mock input files - num_slides = 3 - patch_h5_paths = [] - slide_paths = [] - output_h5_paths = [] - - for i in range(num_slides): - # Create mock patch coordinates file - patch_h5_path = tmp_path / f"slide{i}_coords.h5" - create_mock_h5_file(patch_h5_path, num_patches=5) - patch_h5_paths.append(str(patch_h5_path)) - - # Mock slide path (doesn't need to exist due to mocking) - slide_paths.append(f"slide{i}.svs") - - # Output path - output_h5_paths.append(str(tmp_path / f"slide{i}_features.h5")) - - # Mock the model, dataset, and process_dataset to avoid loading actual slides - with ( - patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, - patch( - "mussel.utils.feature_extract.WholeSlideImageH5Dataset" - ) as mock_dataset_class, - patch("mussel.utils.feature_extract.process_dataset") as mock_process, - ): - # Mock model - mock_model = MagicMock() - mock_model_fun = MagicMock( - side_effect=lambda x: torch.randn(len(x), 2048) # Return batch of features +from mussel.utils.feature_extract import get_features + + +def _make_mock_model(feature_dim=384): + """Return a mock model with the required interface.""" + mock = MagicMock() + mock.get_preprocessing_fun.return_value = None + mock.get_model_fun.return_value = MagicMock( + side_effect=lambda x: __import__('torch').randn(len(x), feature_dim) + ) + return mock + + +def _base_patches(): + """Context managers that stub out all I/O in get_features.""" + return [ + patch("mussel.utils.feature_extract.get_model_factory"), + patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), + 
patch("mussel.utils.feature_extract.process_dataset"), + patch("mussel.utils.feature_extract._make_dataloader"), + ] + + +# -- model= parameter ------------------------------------------------------- + +def test_model_factory_not_called_when_model_provided(): + """get_model_factory must not be called when a pre-loaded model is given.""" + coords = np.zeros((10, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + mock_model = _make_mock_model() + + with patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_proc.return_value = MagicMock( + features=np.zeros((10, 384)), labels=np.zeros(10) ) - mock_model.get_model_fun.return_value = mock_model_fun - mock_model.get_preprocessing_fun.return_value = None - mock_factory.return_value = MagicMock( - get_model=MagicMock(return_value=mock_model) + get_features(coords, "slide.svs", attrs, model=mock_model) + + mock_factory.assert_not_called() + + +def test_model_factory_called_when_model_not_provided(): + """get_model_factory must be called when no pre-loaded model is given.""" + coords = np.zeros((10, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + + with patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_model = _make_mock_model() + mock_factory.return_value = MagicMock(get_model=MagicMock(return_value=mock_model)) + mock_proc.return_value = MagicMock( + features=np.zeros((10, 384)), labels=np.zeros(10) ) + get_features(coords, 
"slide.svs", attrs, model_type=ModelType.CTRANSPATH) - # Mock dataset - mock_dataset = MagicMock() - mock_dataset.__len__.return_value = 5 - mock_dataset_class.return_value = mock_dataset - - # Call the batch extraction function - result_paths = extract_patch_features_batch( - patch_h5_paths=patch_h5_paths, - slide_paths=slide_paths, - output_h5_paths=output_h5_paths, - model_type=ModelType.RESNET50, - model_path=None, - batch_size=32, - use_gpu=use_gpu, - num_workers=num_workers, + mock_factory.assert_called_once() + + +def test_positional_args_unchanged(): + """Existing positional call pattern must still work after adding model= at end.""" + coords = np.zeros((5, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + + with patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_model = _make_mock_model() + mock_factory.return_value = MagicMock(get_model=MagicMock(return_value=mock_model)) + mock_proc.return_value = MagicMock( + features=np.zeros((5, 384)), labels=np.zeros(5) ) + # Classic positional call: (coords, slide_path, attrs, model_type, model_path, batch_size) + # batch_size must not be interpreted as model= + features, labels = get_features( + coords, "slide.svs", attrs, ModelType.CTRANSPATH, None, 32 + ) + assert features.shape == (5, 384) - # Verify model was loaded only once - assert mock_factory.call_count == 1, "Model factory should be called only once" - # Verify dataset was created for each slide - assert ( - mock_dataset_class.call_count == num_slides - ), f"Dataset should be created {num_slides} times, got {mock_dataset_class.call_count}" +def test_model_invalid_interface_raises_type_error(): + """Passing an object without the required methods must raise 
TypeError immediately.""" + coords = np.zeros((5, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + bad_model = object() - # Verify process_dataset was called for each slide - assert ( - mock_process.call_count == num_slides - ), f"process_dataset should be called {num_slides} times, got {mock_process.call_count}" + with pytest.raises(TypeError, match="get_preprocessing_fun"): + get_features(coords, "slide.svs", attrs, model=bad_model) - # Verify result paths match input - assert result_paths == output_h5_paths +# -- slide_model= parameter ------------------------------------------------ -def test_extract_patch_features_batch_empty_list(): - """Test that batch extraction handles empty input gracefully.""" - result = extract_patch_features_batch( - patch_h5_paths=[], - slide_paths=[], - output_h5_paths=[], - model_type=ModelType.RESNET50, +def test_slide_model_factory_not_called_when_slide_model_provided(): + """get_model_factory must not be loaded for the slide encoder when slide_model is given.""" + coords = np.zeros((10, 2), dtype=np.int32) + attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, + "patch_size_to_resize_to_for_desired_mpp": 224} + mock_patch_model = _make_mock_model() + mock_slide_model = MagicMock() + mock_slide_model.get_model_fun.return_value = MagicMock( + return_value=__import__('torch').zeros(1, 512) ) - assert result == [], "Should return empty list for empty input" - - -def test_extract_patch_features_batch_single_slide(tmp_path, use_gpu, num_workers): - """Test batch extraction with a single slide (edge case).""" - # Create mock input file - patch_h5_path = tmp_path / "slide_coords.h5" - create_mock_h5_file(patch_h5_path, num_patches=10) - - patch_h5_paths = [str(patch_h5_path)] - slide_paths = ["slide.svs"] - output_h5_paths = [str(tmp_path / "slide_features.h5")] - - with ( - patch("mussel.utils.feature_extract.get_model_factory") as mock_factory, - 
patch( - "mussel.utils.feature_extract.WholeSlideImageH5Dataset" - ) as mock_dataset_class, - patch("mussel.utils.feature_extract.process_dataset") as mock_process, - ): - # Mock model - mock_model = MagicMock() - mock_model_fun = MagicMock(side_effect=lambda x: torch.randn(len(x), 2048)) - mock_model.get_model_fun.return_value = mock_model_fun - mock_model.get_preprocessing_fun.return_value = None - mock_factory.return_value = MagicMock( - get_model=MagicMock(return_value=mock_model) - ) + call_log = [] + def factory_side_effect(model_type): + call_log.append(model_type) + m = MagicMock() + m.get_model.return_value = mock_patch_model + return m - # Mock dataset - mock_dataset = MagicMock() - mock_dataset.__len__.return_value = 10 - mock_dataset_class.return_value = mock_dataset - - result_paths = extract_patch_features_batch( - patch_h5_paths=patch_h5_paths, - slide_paths=slide_paths, - output_h5_paths=output_h5_paths, - model_type=ModelType.RESNET50, - model_path=None, - batch_size=32, - use_gpu=use_gpu, - num_workers=num_workers, + with patch("mussel.utils.feature_extract.get_model_factory", side_effect=factory_side_effect), patch("mussel.utils.feature_extract.validate_slide_encoder_compatibility"), patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: + mock_proc.return_value = MagicMock( + features=np.zeros((10, 384)), labels=np.zeros(10) + ) + get_features( + coords, "slide.svs", attrs, + model=mock_patch_model, + use_slide_encoder=True, + slide_model_type=ModelType.GIGAPATH_SLIDE, + aggregation_method="model", + slide_model=mock_slide_model, ) - # Even with one slide, model should be loaded only once - assert mock_factory.call_count == 1 - assert result_paths == output_h5_paths + # Only the patch encoder factory may be called (for auto-infer check), not the slide encoder + slide_encoder_calls = [t for t in call_log if t 
== ModelType.GIGAPATH_SLIDE] +    assert len(slide_encoder_calls) == 0, "Slide encoder factory must not be called when slide_model is provided" -def test_extract_patch_features_batch_model_reuse(): -    """ -    Document that batch processing reuses the model across slides. +def test_slide_model_invalid_interface_raises_type_error(): +    """Passing a slide_model without get_model_fun must raise TypeError.""" +    coords = np.zeros((5, 2), dtype=np.int32) +    attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, +             "patch_size_to_resize_to_for_desired_mpp": 224} +    bad_slide_model = object() -    This is the key benefit of batch processing: -    - Old approach: Load model N times for N slides -    - New approach: Load model 1 time for N slides +    with pytest.raises(TypeError, match="get_model_fun"): +        get_features(coords, "slide.svs", attrs, slide_model=bad_slide_model) + + +# -- compatibility validation ------------------------------------------------ + +def test_compatibility_validated_with_preloaded_patch_model(): +    """validate_slide_encoder_compatibility must be called even with pre-loaded patch model.""" +    coords = np.zeros((5, 2), dtype=np.int32) +    attrs = {"patch_size": 256, "patch_level": 0, "mpp": 0.5, +             "patch_size_to_resize_to_for_desired_mpp": 224} +    mock_patch_model = _make_mock_model() + +    with patch("mussel.utils.feature_extract.validate_slide_encoder_compatibility") as mock_validate, patch("mussel.utils.feature_extract.WholeSlideImageTileCoordDataset"), patch("mussel.utils.feature_extract._make_dataloader"), patch("mussel.utils.feature_extract.process_dataset") as mock_proc: +        mock_proc.return_value = MagicMock( +            features=np.zeros((5, 384)), labels=np.zeros(5) +        ) +        mock_slide = MagicMock() +        mock_slide.get_model_fun.return_value = MagicMock(return_value=__import__('torch').zeros(1, 512)) +        get_features( +            coords, "slide.svs", attrs, +            model_type=ModelType.GIGAPATH, +            model=mock_patch_model, +            use_slide_encoder=True, +            slide_model_type=ModelType.GIGAPATH_SLIDE, + 
aggregation_method="model", + slide_model=mock_slide, + ) - For 100 slides with 2s model load time: - - Old: 100 * 2s = 200s wasted on model loading - - New: 1 * 2s = 2s for model loading - - Savings: 198s (99% reduction in model loading time) - """ - # This is a documentation test to highlight the key benefit - # In production, this translates to significant time savings - pass + mock_validate.assert_called_once_with(ModelType.GIGAPATH, ModelType.GIGAPATH_SLIDE)