Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 54 additions & 27 deletions mussel/utils/feature_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ def _apply_slide_aggregation(
gpu_device_ids: Optional[List[int]] = None,
coords: Optional[np.ndarray] = None,
patch_size: Optional[int] = None,
slide_model=None,
) -> np.ndarray:
"""Apply slide-level aggregation to patch features.

Expand All @@ -698,6 +699,8 @@ def _apply_slide_aggregation(
coords: Optional numpy array of patch coordinates (required for some slide encoders like GIGAPATH_SLIDE, TITAN_SLIDE).
patch_size: Optional patch size at level 0 (required for TITAN_SLIDE).
If not provided, will be extracted from h5 file 'coords' attributes or default to 256.
slide_model: Optional pre-loaded slide encoder model instance. If provided,
slide_model_type and slide_model_path are ignored.

Returns:
Numpy array of aggregated features.
Expand Down Expand Up @@ -732,12 +735,16 @@ def _apply_slide_aggregation(
if gpu_device_ids:
gpu_device_id = gpu_device_ids

# Load the slide encoder model
model_factory = get_model_factory(slide_model_type)
if model_factory is None:
raise ValueError(f"Slide model type {slide_model_type} not recognized")
model = model_factory.get_model(slide_model_path, use_gpu, gpu_device_id)
model_fun = model.get_model_fun()
# Load or use pre-loaded slide encoder model
if slide_model is not None:
logger.info("using pre-loaded slide model")
model_fun = slide_model.get_model_fun()
else:
model_factory = get_model_factory(slide_model_type)
if model_factory is None:
raise ValueError(f"Slide model type {slide_model_type} not recognized")
slide_model = model_factory.get_model(slide_model_path, use_gpu, gpu_device_id)
model_fun = slide_model.get_model_fun()

# Convert features to tensor and apply model
features_tensor = torch.from_numpy(features).unsqueeze(0) # Add batch dimension
Expand Down Expand Up @@ -831,6 +838,8 @@ def get_features(
slide_model_type: Optional[ModelType] = None,
slide_model_path: Optional[str] = None,
aggregation_method: str = "identity",
model=None,
slide_model=None,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract features from whole slide image tiles.

Expand All @@ -841,7 +850,8 @@ def get_features(
model_type: Type of foundation model to use (default: ModelType.CLIP).
When using model-based aggregation with a slide encoder, this will be automatically
set to the required patch encoder if not already specified correctly.
model_path: Optional path to model weights.
model_path: Optional path to model weights. Not used when a pre-loaded ``model`` is passed,
which is useful for batch workflows where the model is loaded once and reused across many slides.
batch_size: Batch size for feature extraction (default: 64).
use_gpu: Whether to use GPU for inference (default: True).
gpu_device_id: GPU device ID to use.
Expand All @@ -855,16 +865,30 @@ def get_features(
slide_model_path: Optional path to slide encoder model weights.
aggregation_method: Aggregation method when using slide encoder (default: "identity").
Options: "identity", "mean", "max", "model".
model: Optional pre-loaded patch encoder model instance. If provided,
model_type and model_path are ignored.
slide_model: Optional pre-loaded slide encoder model instance. If provided,
slide_model_type and slide_model_path are ignored for slide encoding.

Returns:
Tuple of (features array, labels array).
"""
logger.info("loading model checkpoint")
# Validate pre-loaded model interfaces up-front.
if model is not None:
if not callable(getattr(model, "get_preprocessing_fun", None)) or not callable(
getattr(model, "get_model_fun", None)
):
raise TypeError(
"model must provide callable get_preprocessing_fun() and get_model_fun() methods"
)
if slide_model is not None:
if not callable(getattr(slide_model, "get_model_fun", None)):
raise TypeError("slide_model must provide a callable get_model_fun() method")

if gpu_device_ids:
gpu_device_id = gpu_device_ids

# Auto-set aggregation_method to "model" if slide_model_type is specified
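# Auto-set aggregation_method to "model" when the slide encoder is enabled and slide_model_type is given.
# NOTE(review): a pre-loaded slide_model passed without slide_model_type does not appear to satisfy this check — confirm.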
if (
use_slide_encoder
and slide_model_type is not None
Expand All @@ -876,26 +900,28 @@ def get_features(
)
aggregation_method = "model"

# Auto-infer patch encoder from slide encoder if using model-based aggregation
if (
use_slide_encoder
and aggregation_method == "model"
and slide_model_type is not None
):
required_patch_encoder = get_required_patch_encoder(slide_model_type)
if model_type != required_patch_encoder:
logger.info(
f"Auto-selecting patch encoder {required_patch_encoder} "
f"as required by slide encoder {slide_model_type}"
)
model_type = required_patch_encoder
# Validate compatibility
# Validate patch/slide encoder compatibility regardless of how models are provided.
if use_slide_encoder and aggregation_method == "model" and slide_model_type is not None:
if model is None:
# Auto-infer required patch encoder when loading from disk.
required_patch_encoder = get_required_patch_encoder(slide_model_type)
if model_type != required_patch_encoder:
logger.info(
f"Auto-selecting patch encoder {required_patch_encoder} "
f"as required by slide encoder {slide_model_type}"
)
model_type = required_patch_encoder
validate_slide_encoder_compatibility(model_type, slide_model_type)

model_factory = get_model_factory(model_type)
if model_factory is None:
raise ValueError("model not recognized")
model = model_factory.get_model(model_path, use_gpu, gpu_device_id)
# Load patch encoder from disk, or use the pre-loaded instance.
if model is None:
logger.info("loading model checkpoint")
model_factory = get_model_factory(model_type)
if model_factory is None:
raise ValueError("model not recognized")
model = model_factory.get_model(model_path, use_gpu, gpu_device_id)
else:
logger.info("using pre-loaded model")
preprocessing = model.get_preprocessing_fun()

dataset = WholeSlideImageTileCoordDataset(
Expand Down Expand Up @@ -943,6 +969,7 @@ def get_features(
gpu_device_ids=gpu_device_ids,
coords=coords,
patch_size=patch_size,
slide_model=slide_model,
)

return features, labels
Expand Down
Loading
Loading