chenpg2/mosaic-harmonize

MOSAIC

Multi-site Optimal-transport Shift Alignment with Interval Calibration


MOSAIC is a Python package for harmonizing clinical tabular data collected across multiple sites. It combines 1-D optimal transport for distribution alignment, anchor regression for domain-robust prediction, and weighted conformal inference for uncertainty quantification. The three components can be used independently or chained through a single pipeline.

The package was developed for multi-center IVF (in vitro fertilization) outcome prediction, but the methods are general and apply to any multi-site clinical or biomedical dataset with batch effects.

Overview

MOSAIC has three tiers, each usable on its own:

| Tier | Class | What it does |
|---|---|---|
| 1. Harmonization | OTHarmonizer | Per-feature quantile-based optimal transport mapping to a reference distribution. Reduces cross-center distribution shift while preserving within-center rank order. |
| 2. Robust learning | AnchorEstimator | Wraps any sklearn estimator with anchor regression (via anchorboosting) or V-REx reweighting. Penalizes predictions that rely on center-specific patterns. |
| 3. Uncertainty | ConformalCalibrator | Split conformal prediction with optional covariate-shift correction (Tibshirani et al., NeurIPS 2019). Produces prediction intervals (regression) or prediction sets (classification) with finite-sample coverage guarantees. |

MOSAICPipeline chains all three into a single fit / predict interface.

Installation

Core (OT harmonization + anchor regression + conformal):

pip install mosaic-harmonize

With all optional dependencies (LightGBM, anchorboosting, MAPIE, matplotlib):

pip install mosaic-harmonize[full]

Individual extras can also be installed on their own: boosting, conformal, viz. For development, install the dev extra.

Requires Python 3.9+.

Quick start

Full pipeline

from mosaic import MOSAICPipeline
from lightgbm import LGBMRegressor

pipe = MOSAICPipeline(
    harmonizer="ot",
    robust_learner="anchor",
    uncertainty="weighted_conformal",
    base_estimator=LGBMRegressor(),
)

# center_ids: array of site labels, one per row
pipe.fit(X_train, y_train, center_ids=train_centers)

result = pipe.predict(X_test, center_id="new_hospital")
print(result.prediction)       # point predictions
print(result.lower, result.upper)  # 90% prediction intervals

Individual components

from lightgbm import LGBMRegressor
from mosaic import OTHarmonizer, AnchorEstimator, ConformalCalibrator

# Tier 1: align distributions
ot = OTHarmonizer(n_quantiles=1000, reference="global")
X_harmonized = ot.fit_transform(X_train, center_ids=train_centers)

# Inspect shift reduction
print(ot.wasserstein_distances())
print(ot.feature_shift_report())

# Tier 2: train a domain-robust model
anchor = AnchorEstimator(base_estimator=LGBMRegressor(), task_type="regression")
anchor.fit(X_harmonized, y_train, anchors=train_centers)
print(f"Best gamma: {anchor.best_gamma_}")
print(f"Cross-center stability: {anchor.stability_score_:.3f}")

# Tier 3: calibrate with conformal prediction
cal = ConformalCalibrator(method="weighted", alpha=0.10)
cal.calibrate(anchor, X_cal, y_cal, X_test=X_test)
result = cal.predict(X_test)
print(f"Interval widths: {(result.upper - result.lower).mean():.2f}")

Save and load

pipe.save("model.mosaic")
pipe = MOSAICPipeline.load("model.mosaic")

Register a new center at inference time

pipe.register_center("hospital_B", X_new_center)
result = pipe.predict(X_query, center_id="hospital_B")

API reference

OTHarmonizer

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_quantiles | int | 1000 | Number of quantile points for the OT map |
| features | list[str] or None | None | Columns to harmonize (None = all numeric) |
| reference | str | "global" | Reference distribution: "global" or a center name |
| min_samples | int | 50 | Minimum non-null samples to build a map |

Methods: fit(X, center_ids), transform(X, center_id=..., center_ids=...), fit_transform(X, center_ids), wasserstein_distances(), feature_shift_report().
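To make the idea of a quantile-based 1-D transport map concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the package's implementation: the helper names (fit_quantile_map, apply_quantile_map, wasserstein_1d) are hypothetical, the data are synthetic, and it handles a single continuous feature from a single center.

```python
import numpy as np

def fit_quantile_map(source, reference, n_quantiles=1000):
    """Return matched quantile grids that define a monotone 1-D map."""
    probs = np.linspace(0.0, 1.0, n_quantiles)
    return np.quantile(source, probs), np.quantile(reference, probs)

def apply_quantile_map(x, source_q, reference_q):
    """Send each value to the reference quantile at the same probability level.

    Uses piecewise-linear interpolation between grid points. Assumes continuous
    data so that source_q is increasing (heavy ties would need extra handling).
    """
    return np.interp(x, source_q, reference_q)

def wasserstein_1d(a, b):
    """W1 distance between two equal-size samples (mean gap of sorted values)."""
    return float(np.mean(np.abs(np.sort(a) - np.sort(b))))

rng = np.random.default_rng(0)
site = rng.normal(loc=5.0, scale=2.0, size=5000)       # synthetic shifted center
reference = rng.normal(loc=0.0, scale=1.0, size=5000)  # synthetic reference

source_q, reference_q = fit_quantile_map(site, reference)
mapped = apply_quantile_map(site, source_q, reference_q)

w1_before = wasserstein_1d(site, reference)
w1_after = wasserstein_1d(mapped, reference)
print(f"W1 before: {w1_before:.3f}, after: {w1_after:.3f}")
```

Because the map is monotone, the ordering of values within the center is unchanged; only their marginal distribution moves toward the reference.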

AnchorEstimator

| Parameter | Type | Default | Description |
|---|---|---|---|
| base_estimator | sklearn estimator or None | None | Base learner (None = Ridge/LogisticRegression) |
| gammas | list[float] or None | [1.5, 3.0, 7.0] | Anchor penalty strengths to search |
| task_type | str | "auto" | "auto", "regression", "binary", or "multiclass" |
| n_vrex_rounds | int | 5 | V-REx reweighting iterations (fallback mode) |

Methods: fit(X, y, anchors, X_val=None, y_val=None), predict(X), predict_proba(X). Properties: best_gamma_, stability_score_.
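For intuition about what the gamma penalty controls, the sketch below implements plain linear anchor regression with center labels as anchors. It is a simplified stand-in, not what AnchorEstimator or anchorboosting does internally: the function names and the synthetic data are assumptions made for illustration. The residual is split into a between-center part (its per-center mean) and a within-center part, and gamma scales the weight on the between-center part. Setting gamma = 1 recovers ordinary least squares.

```python
import numpy as np

def anchor_transform(X, y, groups, gamma):
    """Apply W = I - (1 - sqrt(gamma)) * P_A to centered X and y.

    P_A projects onto the center indicators, i.e. it replaces each row
    by the mean of its center.
    """
    X = X - X.mean(axis=0)
    y = y - y.mean()
    scale = 1.0 - np.sqrt(gamma)
    X_t, y_t = X.copy(), y.copy()
    for g in np.unique(groups):
        mask = groups == g
        X_t[mask] -= scale * X[mask].mean(axis=0)
        y_t[mask] -= scale * y[mask].mean()
    return X_t, y_t

def anchor_regression(X, y, groups, gamma):
    """Least squares on the transformed data gives the anchor estimate."""
    X_t, y_t = anchor_transform(X, y, groups, gamma)
    coef, *_ = np.linalg.lstsq(X_t, y_t, rcond=None)
    return coef

def between_group_residual(X, y, groups, coef):
    """Squared norm of the per-center mean residual (the penalized term)."""
    resid = (y - y.mean()) - (X - X.mean(axis=0)) @ coef
    return float(sum((groups == g).sum() * resid[groups == g].mean() ** 2
                     for g in np.unique(groups)))

# Synthetic data: 4 centers, one feature carries a center-level offset.
rng = np.random.default_rng(0)
groups = np.repeat(np.arange(4), 250)
site_shift = rng.normal(size=4)[groups]
X = rng.normal(size=(1000, 2))
X[:, 1] += 2.0 * site_shift
y = X[:, 0] + site_shift + 0.1 * rng.normal(size=1000)

for gamma in (1.0, 1.5, 3.0, 7.0):
    coef = anchor_regression(X, y, groups, gamma)
    print(gamma, coef.round(3), round(between_group_residual(X, y, groups, coef), 4))
```

Larger gamma values never increase the between-center residual term, which is why a small grid of gammas is a natural thing to search over.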

ConformalCalibrator

| Parameter | Type | Default | Description |
|---|---|---|---|
| method | str | "weighted" | "weighted", "standard", or "lac" |
| alpha | float | 0.10 | Miscoverage level (0.10 = 90% target coverage) |

Methods: calibrate(model, X_cal, y_cal, X_test=None), predict(X_test) returning ConformalResult.
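As a reference point for what alpha means, here is a minimal sketch of standard split conformal prediction for regression, written in plain NumPy. It does not use or reproduce ConformalCalibrator. The function name conformal_halfwidth, the synthetic data, and the straight-line predictor are all assumptions chosen to keep the example self-contained.

```python
import numpy as np

def conformal_halfwidth(abs_residuals, alpha=0.10):
    """Finite-sample quantile of calibration scores for split conformal."""
    n = len(abs_residuals)
    k = int(np.ceil((n + 1) * (1 - alpha)))   # rank of the score to use
    if k > n:                                  # too few calibration points
        return np.inf
    return float(np.sort(abs_residuals)[k - 1])

rng = np.random.default_rng(0)

def simulate(n):
    """Synthetic regression data drawn from one fixed distribution."""
    x = rng.uniform(-2, 2, size=n)
    y = 3.0 * x + rng.normal(scale=1.0, size=n)
    return x, y

x_train, y_train = simulate(500)
x_cal, y_cal = simulate(1000)
x_test, y_test = simulate(5000)

# Stand-in point predictor: a least-squares line fitted on the training split.
slope, intercept = np.polyfit(x_train, y_train, deg=1)

def predict(x):
    return slope * x + intercept

# Calibrate on held-out data, then build symmetric intervals on the test split.
q = conformal_halfwidth(np.abs(y_cal - predict(x_cal)), alpha=0.10)
lower, upper = predict(x_test) - q, predict(x_test) + q
coverage = np.mean((y_test >= lower) & (y_test <= upper))
print(f"half-width: {q:.3f}, empirical coverage: {coverage:.3f}")
```

The coverage guarantee behind this construction relies on calibration and test points being exchangeable, which is the assumption the weighted variant relaxes under covariate shift.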

MOSAICPipeline

| Parameter | Type | Default | Description |
|---|---|---|---|
| harmonizer | str or None | "ot" | "ot" or None |
| robust_learner | str or None | "anchor" | "anchor" or None |
| uncertainty | str or None | "weighted_conformal" | "weighted_conformal", "standard", "lac", or None |
| base_estimator | sklearn estimator | None | Base learner passed to AnchorEstimator |

Methods: fit(X_train, y_train, center_ids, X_cal=None, y_cal=None), predict(X, center_id=..., center_ids=...), register_center(name, X_new), diagnose(X, center_id), save(path), load(path).

Benchmarks

Evaluated on a multi-center IVF dataset (334K rows, 5 centers, 15 prediction targets). Full results in benchmarks/results/.

Ablation (Exp 1): each tier adds value

| Target | Baseline R² | +OT | +OT+Anchor | Full MOSAIC |
|---|---|---|---|---|
| HCG_Day_E2 | -0.665 | 0.127 | 0.229 | 0.229 |
| egg_num | 0.210 | 0.205 | 0.352 | 0.352 |
| HCG_Day_Endo | -0.775 | -0.177 | 0.120 | 0.120 |

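OT corrects distribution shift (HCG_Day_E2: R² from -0.67 to 0.13). Anchor regression adds further gains for regression targets (egg_num: 0.21 to 0.35). The Full MOSAIC column equals +OT+Anchor, as expected, since the conformal tier adds prediction intervals rather than changing the point predictions that R² measures.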
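Negative R² values can look surprising, so a short worked example may help. R² is 1 minus the ratio of residual to total sum of squares, so it drops below zero whenever predictions are worse than always predicting the target mean. The synthetic numbers below are unrelated to the benchmark data; they only show that a constant offset between sites is enough to push an otherwise reasonable predictor below zero.

```python
import numpy as np

def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

rng = np.random.default_rng(0)
y_true = rng.normal(loc=10.0, scale=1.0, size=2000)    # synthetic target

# Noisy but unbiased predictions.
unbiased = y_true + rng.normal(scale=0.5, size=2000)
# The same predictions with a constant site-level offset added.
offset = unbiased + 1.5

r2_unbiased = r2_score(y_true, unbiased)
r2_offset = r2_score(y_true, offset)
print(f"unbiased: {r2_unbiased:.2f}, with offset: {r2_offset:.2f}")
```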

Cross-center generalization gap (Exp 2)

On an external test center unseen during training, MOSAIC reduces the validation-to-test performance gap by 42-76% for high-shift features (HCG_Day_E2: 75%, HCG_Day_P: 52%, HCG_Day_Endo: 72%).

Conformal coverage (Exp 3)

All 11 regression targets achieve 81-93% empirical coverage at the 90% nominal level. Weighted conformal consistently produces narrower intervals than standard split conformal at comparable coverage.
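The weighted variant differs from standard split conformal only in how the score quantile is computed. The sketch below follows the weighted-quantile construction of Tibshirani et al. in plain NumPy. It is not the package's code: the function name, the synthetic scores, and the toy weights are assumptions, and in practice the weights would come from an estimated density ratio between test and calibration covariates.

```python
import numpy as np

def weighted_conformal_quantile(scores, cal_weights, test_weight, alpha=0.10):
    """(1 - alpha) quantile of the weighted calibration score distribution.

    The test point contributes mass `test_weight` at +infinity, so the
    result is infinite when the calibration mass alone cannot reach 1 - alpha.
    """
    order = np.argsort(scores)
    sorted_scores = scores[order]
    cum_mass = np.cumsum(cal_weights[order]) / (cal_weights.sum() + test_weight)
    idx = np.searchsorted(cum_mass, 1 - alpha, side="left")
    if idx >= len(sorted_scores):          # remaining mass sits at +infinity
        return np.inf
    return float(sorted_scores[idx])

rng = np.random.default_rng(0)
scores = np.abs(rng.normal(size=1000))     # synthetic calibration scores

# Equal weights reduce to the standard split conformal quantile.
uniform_q = weighted_conformal_quantile(scores, np.ones(1000), 1.0)

# Hypothetical weights that favor calibration rows with small scores,
# as if the test site resembled the easier part of the calibration set.
toy_weights = np.exp(-scores)
weighted_q = weighted_conformal_quantile(scores, toy_weights, toy_weights.mean())
print(f"uniform: {uniform_q:.3f}, weighted: {weighted_q:.3f}")
```

Whether the weighted quantile ends up smaller or larger than the unweighted one depends entirely on which calibration rows the weights emphasize.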

Comparison with existing methods (Exp 5)

| Target (metric) | No harmonization | Z-score | ComBat | MOSAIC (OT) |
|---|---|---|---|---|
| HCG_Day_E2 (R²) | -0.665 | 0.069 | -0.383 | 0.127 |
| Clinical_pregnancy (AUC) | 0.838 | 0.837 | 0.837 | 0.839 |
| total_Gn (R²) | -0.121 | -0.327 | -0.034 | -0.022 |

MOSAIC outperforms Z-score and ComBat on high-shift features while maintaining comparable performance on low-shift targets.
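One general difference between per-center z-scoring and quantile-based mapping is that z-scoring matches only the mean and variance, while quantile mapping matches the whole marginal distribution. The toy sketch below shows this on synthetic skewed data. It is not a reproduction of the benchmark, does not cover ComBat, and uses hypothetical helper names.

```python
import numpy as np

def zscore(x):
    """Standardize to zero mean and unit variance."""
    return (x - x.mean()) / x.std()

def quantile_map(x, reference, n_quantiles=1000):
    """Map x onto the reference distribution by matching quantiles."""
    probs = np.linspace(0.0, 1.0, n_quantiles)
    return np.interp(x, np.quantile(x, probs), np.quantile(reference, probs))

def wasserstein_1d(a, b):
    """W1 distance between two equal-size samples."""
    return float(np.mean(np.abs(np.sort(a) - np.sort(b))))

rng = np.random.default_rng(0)
site = rng.lognormal(mean=0.0, sigma=1.0, size=5000)   # skewed synthetic center
reference = rng.normal(size=5000)                      # symmetric synthetic reference

# Residual distance to the reference after each kind of alignment.
w1_zscore = wasserstein_1d(zscore(site), zscore(reference))
w1_quantile = wasserstein_1d(quantile_map(site, reference), reference)
print(f"after z-score: {w1_zscore:.3f}, after quantile map: {w1_quantile:.3f}")
```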

Citation

If you use MOSAIC in your research, please cite:

@article{chen2026mosaic,
  title={MOSAIC: Multi-site Optimal-transport Shift Alignment with Interval
         Calibration for Clinical Data Harmonization},
  author={Chen, Peigen},
  journal={npj Digital Medicine},
  year={2026},
  note={Manuscript in preparation}
}

License

Apache-2.0. See LICENSE for details.
