This document provides comprehensive documentation for developers who want to use or extend the HoVer-NeXt inference pipeline programmatically.
pip install -e .See requirements.txt for the complete list. Key dependencies:
- PyTorch >= 2.1.1
- OpenSlide for WSI support
- Zarr for efficient array storage
- segmentation-models-pytorch
- timm (PyTorch Image Models)
from inference.inference import inference_main, get_inference_setup
from inference.post_process import post_process_main
# Setup parameters
params = {
"cp": "lizard_convnextv2_large",
"input": "/path/to/slide.svs",
"output_dir": "results/",
"tta": 4,
"batch_size": 64,
"tile_size": 256,
"overlap": 0.96875,
# ... other parameters
}
# Get model and augmenter
params, models, augmenter, color_aug_fn = get_inference_setup(params)
# Run inference
params, z = inference_main(params, models, augmenter, color_aug_fn)
# Post-process results
z_pp = post_process_main(params, z)Purpose: Main inference pipeline for tile-based prediction on whole slide images, arrays, and standard images.
Key Functions:
inference_main(params, models, augmenter, color_aug_fn): Run inference on a single inputget_inference_setup(params): Initialize models and augmentation modules
Supported Input Types:
- WSI formats: All OpenSlide-supported formats (.svs, .ndpi, .mrxs, etc.)
- CZI format via pylibCZIrw
- Numpy arrays (.npy)
- Standard images (.jpg, .png, .jpeg, .bmp)
Purpose: Convert raw model predictions into final instance segmentation maps.
Key Functions:
post_process_main(params, z): Main post-processing pipeline- Performs tile stitching, overlap resolution, and watershed segmentation
Output Formats:
- Instance segmentation map (zarr compressed)
- Class lookup dictionary (JSON)
- QuPath-compatible TSV files
- Optional GeoJSON polygons
Purpose: Data loading and preprocessing utilities.
Key Classes:
WholeSlideDataset: Dataset for WSI tile extractionNpyDataset: Dataset for numpy array inputsImageDataset: Dataset for standard image formatsczi_wrapper: Wrapper for CZI file support
Key Functions:
normalize_min_max(x, mi, ma): Min-max normalizationcenter_crop(t, croph, cropw): Center cropping utility
Purpose: Color augmentation for histopathology images.
Key Classes:
HedNormalizeTorch: Stain augmentation in HED color spaceRgb2Hed: RGB to HED conversion moduleHed2Rgb: HED to RGB conversion moduleGaussianNoise: Gaussian noise augmentation
Key Functions:
color_augmentations(train, sigma, bias, s, rank): Create augmentation pipeline
Purpose: Geometric transformations for test-time augmentation.
Key Class:
SpatialAugmenter: Reversible geometric augmentation module
Supported Transformations:
- Mirror (horizontal/vertical flipping)
- Translation
- Scaling
- Rotation (90° and arbitrary angles)
- Shearing
- Elastic deformation
Purpose: Model architecture implementation.
Key Classes:
MultiHeadModel: Complete multi-head U-Net modelTimmEncoderFixed: ConvNeXt encoder wrapperUnetDecoder: U-Net decoder with skip connections
Key Functions:
get_model(enc, out_channels_cls, out_channels_inst, pretrained): Create modelload_checkpoint(model, cp_path, device): Load model weights
Purpose: Visualization and export utilities.
Key Functions:
create_geojson(polygons, classids, lookup, params): Export to GeoJSONcreate_tsvs(pcls_out, params): Export centroid TSVs for QuPathcont(x, offset): Extract contour from binary mask
Run inference on a single input file.
Parameters:
params(dict): Configuration parametersmodels(list): List of model instancesaugmenter(SpatialAugmenter): Geometric augmentation modulecolor_aug_fn(torch.nn.Sequential): Color augmentation pipeline
Returns:
params(dict): Updated parametersz(tuple): (instance_predictions, class_predictions) as zarr stores
Example:
params, z = inference_main(params, models, augmenter, color_aug_fn)Post-process raw predictions into instance segmentation.
Parameters:
params(dict): Configuration parametersz(tuple or None): Zarr stores from inference
Returns:
z_pp(zarr.Array): Final instance segmentation map
Example:
z_pp = post_process_main(params, z)Initialize models and augmentation modules.
Parameters:
params(dict): Configuration with 'cp', 'tta', etc.
Returns:
params(dict): Updated parametersmodels(list): List of loaded modelsaugmenter(SpatialAugmenter): Geometric augmentercolor_aug_fn(torch.nn.Sequential): Color augmenter
Example:
params, models, augmenter, color_aug = get_inference_setup(params)Dataset for extracting tiles from whole slide images.
Parameters:
slide_path(str): Path to WSI filetile_size(int): Size of tiles to extract (default: 256)padding_factor(float): Overlap between tiles (default: 0.96875)ratio_object_thresh(float): Minimum tissue ratio (default: 0.3)min_tiss(float): Minimum tissue percentage (default: 0.1)
Methods:
__len__(): Number of tiles__getitem__(idx): Get tile and metadata
Example:
dataset = WholeSlideDataset(
"slide.svs",
tile_size=256,
padding_factor=0.96875
)
dataloader = DataLoader(dataset, batch_size=64, num_workers=16)To add support for a new image format:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, file_path, tile_size, **kwargs):
# Initialize your reader
self.reader = CustomReader(file_path)
self.tile_size = tile_size
# Calculate tiles
def __len__(self):
return self.num_tiles
def __getitem__(self, idx):
# Extract tile
tile = self.reader.read_tile(idx)
# Return (tile, tile_coords)
return tile, coordsOptimize post-processing for your specific use case:
from inference.post_process_utils import get_pp_params
# Load default parameters
params = get_pp_params(params, optimize=True)
# Customize thresholds
params['best_fg_thresh_cl'] = [0.45] * num_classes # Foreground threshold
params['best_seediness_thresh_cl'] = [0.55] * num_classes # Seed threshold
# Run with custom parameters
z_pp = post_process_main(params, z)import glob
from pathlib import Path
# Get list of files
files = glob.glob("/path/to/slides/*.svs")
for file_path in files:
print(f"Processing {file_path}")
# Update parameters for this file
params['p'] = file_path
params['output_dir'] = f"results/{Path(file_path).stem}/"
# Run pipeline
params, z = inference_main(params, models, augmenter, color_aug_fn)
z_pp = post_process_main(params, z)
# Clean up
if not params['keep_raw']:
# Remove intermediate files
pass# Configure TTA views
params['tta'] = 8 # More views = more robust but slower
# Custom augmentation parameters
from inference.constants import TTA_AUG_PARAMS
custom_aug_params = TTA_AUG_PARAMS.copy()
custom_aug_params['rotate']['prob'] = 1.0 # Always rotate
custom_aug_params['mirror']['prob'] = 1.0 # Always mirror
augmenter = SpatialAugmenter(custom_aug_params)- Train your model using the training repository
- Save checkpoint with this structure:
{
'model_state_dict': model.state_dict(),
# Optional: other training info
}- Create a config.toml file:
[model]
encoder = "convnextv2_large.fcmae_ft_in22k_in1k"
out_channels_cls = 8
out_channels_inst = 5
[dataset]
pannuke = false
mpp = 0.5- Use in inference:
params['cp'] = "/path/to/your/checkpoint/"Add your own metric optimization:
# In post_process_utils.py, extend get_pp_params()
def get_pp_params(params, optimize=True):
# ... existing code ...
if params['metric'] == 'custom':
# Load your optimized parameters
params['best_fg_thresh_cl'] = [0.50] * num_classes
params['best_seediness_thresh_cl'] = [0.60] * num_classes
# ... other parameters
return paramsimport tifffile
# Load results
pinst = zarr.open(os.path.join(output_dir, "pinst_pp.zip"))
# Save as TIFF
tifffile.imwrite(
"instance_map.tif",
np.array(pinst),
compression='lzw'
)from skimage.segmentation import find_boundaries
# Load instance map
pinst = zarr.open(os.path.join(output_dir, "pinst_pp.zip"))
# Extract boundaries
boundaries = find_boundaries(pinst[:], mode='inner')
# Save as binary mask
cv2.imwrite("boundaries.png", boundaries.astype(np.uint8) * 255)# Reduce batch size if running out of memory
params['batch_size'] = 32
# Increase tiling for post-processing
params['pp_tiling'] = 16 # Higher = less memory but slower# Set workers based on CPU cores
import multiprocessing
n_cores = multiprocessing.cpu_count()
params['inf_workers'] = n_cores # Inference dataloader
params['pp_workers'] = n_cores - 1 # Post-processing# Cache WSI to local storage for faster access
params['cache'] = "/tmp/cache/"-
CUDA Out of Memory
- Reduce
batch_size - Increase
pp_tiling - Use smaller model variant (tiny instead of large)
- Reduce
-
Slow Inference
- Check GPU is being used
- Increase
inf_workers - Use faster storage (SSD vs HDD)
-
Missing Tissue Detection
- Adjust
ratio_object_threshandmin_tiss - Check WSI thumbnail quality
- Adjust
For issues and questions:
- GitHub Issues: https://github.com/pathology-data-mining/hover_next_inference/issues
- Documentation: https://github.com/pathology-data-mining/hover_next_inference#readme
- Paper: https://openreview.net/pdf?id=3vmB43oqIO
If you use this code, please cite:
@inproceedings{baumann2024hover,
title={HoVer-NeXt: A Fast Nuclei Segmentation and Classification Pipeline for Next Generation Histopathology},
author={Baumann, Elias and Dislich, Bastian and Rumberger, Josef Lorenz and Nagtegaal, Iris D and Martinez, Maria Rodriguez and Zlobec, Inti},
booktitle={Medical Imaging with Deep Learning},
year={2024}
}