Multi-site Optimal-transport Shift Alignment with Interval Calibration
MOSAIC is a Python package for harmonizing clinical tabular data collected across multiple sites. It combines 1-D optimal transport for distribution alignment, anchor regression for domain-robust prediction, and weighted conformal inference for uncertainty quantification. The three components can be used independently or chained through a single pipeline.
The package was developed for multi-center IVF (in vitro fertilization) outcome prediction, but the methods are general and apply to any multi-site clinical or biomedical dataset with batch effects.
MOSAIC has three tiers, each usable on its own:
| Tier | Class | What it does |
|---|---|---|
| 1. Harmonization | OTHarmonizer | Per-feature quantile-based optimal transport mapping to a reference distribution. Reduces cross-center distribution shift while preserving within-center rank order. |
| 2. Robust learning | AnchorEstimator | Wraps any sklearn estimator with anchor regression (via anchorboosting) or V-REx reweighting. Penalizes predictions that rely on center-specific patterns. |
| 3. Uncertainty | ConformalCalibrator | Split conformal prediction with optional covariate-shift correction (Tibshirani et al., NeurIPS 2019). Produces prediction intervals (regression) or prediction sets (classification) with finite-sample coverage guarantees. |
MOSAICPipeline chains all three into a single fit / predict interface.
Core (OT harmonization + anchor regression + conformal):
pip install mosaic-harmonize
With all optional dependencies (LightGBM, anchorboosting, MAPIE, matplotlib):
pip install mosaic-harmonize[full]
Individual extras: boosting, conformal, viz. For development: dev.
Requires Python 3.9+.
from mosaic import MOSAICPipeline
from lightgbm import LGBMRegressor
pipe = MOSAICPipeline(
    harmonizer="ot",
    robust_learner="anchor",
    uncertainty="weighted_conformal",
    base_estimator=LGBMRegressor(),
)
# center_ids: array of site labels, one per row
pipe.fit(X_train, y_train, center_ids=train_centers)
result = pipe.predict(X_test, center_id="new_hospital")
print(result.prediction) # point predictions
print(result.lower, result.upper) # 90% prediction intervals
from mosaic import OTHarmonizer, AnchorEstimator, ConformalCalibrator
# Tier 1: align distributions
ot = OTHarmonizer(n_quantiles=1000, reference="global")
X_harmonized = ot.fit_transform(X_train, center_ids=train_centers)
# Inspect shift reduction
print(ot.wasserstein_distances())
print(ot.feature_shift_report())
# Tier 2: train a domain-robust model
anchor = AnchorEstimator(base_estimator=LGBMRegressor(), task_type="regression")
anchor.fit(X_harmonized, y_train, anchors=train_centers)
print(f"Best gamma: {anchor.best_gamma_}")
print(f"Cross-center stability: {anchor.stability_score_:.3f}")
# Tier 3: calibrate with conformal prediction
cal = ConformalCalibrator(method="weighted", alpha=0.10)
cal.calibrate(anchor, X_cal, y_cal, X_test=X_test)
result = cal.predict(X_test)
print(f"Interval widths: {(result.upper - result.lower).mean():.2f}")
pipe.save("model.mosaic")
pipe = MOSAICPipeline.load("model.mosaic")
pipe.register_center("hospital_B", X_new_center)
result = pipe.predict(X_query, center_id="hospital_B")
OTHarmonizer parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_quantiles | int | 1000 | Number of quantile points for the OT map |
| features | list[str] or None | None | Columns to harmonize (None = all numeric) |
| reference | str | "global" | Reference distribution: "global" or a center name |
| min_samples | int | 50 | Minimum non-null samples to build a map |
Methods: fit(X, center_ids), transform(X, center_id=..., center_ids=...), fit_transform(X, center_ids), wasserstein_distances(), feature_shift_report().
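The per-feature mapping can be pictured as quantile matching: a value is located on its own center's quantile grid and moved to the same position on the reference grid. The sketch below illustrates that idea with the standard library only. It is not OTHarmonizer's implementation; the function names quantile_grid and ot_map, the clamping of out-of-range values, and the toy data are all assumptions made for this example.

```python
from bisect import bisect_right


def quantile_grid(values, n_quantiles):
    """Return n_quantiles (>= 2) evenly spaced empirical quantiles, linearly interpolated."""
    s = sorted(values)
    grid = []
    for i in range(n_quantiles):
        pos = i * (len(s) - 1) / (n_quantiles - 1)
        lo = int(pos)
        hi = min(lo + 1, len(s) - 1)
        grid.append(s[lo] + (pos - lo) * (s[hi] - s[lo]))
    return grid


def ot_map(x, source_q, reference_q):
    """Move x from its position on the source grid to the same position on the reference grid."""
    # Values outside the fitted range are clamped (a choice made for this sketch).
    if x <= source_q[0]:
        return reference_q[0]
    if x >= source_q[-1]:
        return reference_q[-1]
    j = bisect_right(source_q, x) - 1  # source_q[j] <= x < source_q[j + 1]
    t = (x - source_q[j]) / (source_q[j + 1] - source_q[j])
    return reference_q[j] + t * (reference_q[j + 1] - reference_q[j])


site_values = [10, 12, 14, 16, 18]   # toy feature from one center
reference_values = [0, 1, 2, 3, 4]   # toy reference distribution
src_q = quantile_grid(site_values, n_quantiles=5)
ref_q = quantile_grid(reference_values, n_quantiles=5)
harmonized = [ot_map(v, src_q, ref_q) for v in site_values]
print(harmonized)
```

Because the map is monotone, the ordering of values within a center is unchanged, which is the rank-preservation property described for Tier 1.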
AnchorEstimator parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| base_estimator | sklearn estimator or None | None | Base learner (None = Ridge/LogisticRegression) |
| gammas | list[float] or None | [1.5, 3.0, 7.0] | Anchor penalty strengths to search |
| task_type | str | "auto" | "auto", "regression", "binary", or "multiclass" |
| n_vrex_rounds | int | 5 | V-REx reweighting iterations (fallback mode) |
Methods: fit(X, y, anchors, X_val=None, y_val=None), predict(X), predict_proba(X). Properties: best_gamma_, stability_score_.
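To give a feel for what the gamma penalty does, here is a minimal sketch of linear anchor regression for a single feature with center labels as anchors. It is not what AnchorEstimator runs (that delegates to anchorboosting or V-REx reweighting); it only illustrates the objective, which weights the part of the residual explained by the anchors by gamma. The function names and toy data are assumptions, and both x and y are assumed to be mean-centered so no intercept is needed.

```python
from collections import defaultdict
from math import sqrt


def group_means(values, groups):
    """Project onto one-hot center indicators: each entry becomes its center mean."""
    total, count = defaultdict(float), defaultdict(int)
    for v, g in zip(values, groups):
        total[g] += v
        count[g] += 1
    return [total[g] / count[g] for g in groups]


def anchor_transform(values, groups, gamma):
    """Apply W = I - (1 - sqrt(gamma)) * P_A, where P_A is the center-mean projection."""
    means = group_means(values, groups)
    return [v - (1.0 - sqrt(gamma)) * m for v, m in zip(values, means)]


def anchor_slope(x, y, groups, gamma):
    """Least-squares slope (no intercept) fitted on the anchor-transformed data."""
    xt = anchor_transform(x, groups, gamma)
    yt = anchor_transform(y, groups, gamma)
    return sum(a * b for a, b in zip(xt, yt)) / sum(a * a for a in xt)


# Toy data: within each center the slope is 1, between center means it is 3.
centers = ["A", "A", "A", "B", "B", "B"]
x = [-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]
y = [-7.0, -6.0, -5.0, 5.0, 6.0, 7.0]
for gamma in (0.0, 1.0, 4.0):
    print(gamma, anchor_slope(x, y, centers, gamma))
```

With gamma = 1 the transform is the identity and the fit is ordinary least squares; gamma = 0 uses only within-center variation; larger values increasingly penalize residuals that are correlated with the center labels.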
ConformalCalibrator parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| method | str | "weighted" | "weighted", "standard", or "lac" |
| alpha | float | 0.10 | Miscoverage level (0.10 = 90% target coverage) |
Methods: calibrate(model, X_cal, y_cal, X_test=None), predict(X_test) returning ConformalResult.
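The difference between the "standard" and "weighted" methods for regression comes down to how the calibration residuals are turned into an interval half-width. The sketch below shows both under simple assumptions: scores are absolute residuals, and the weights are likelihood-ratio estimates (test density over calibration density) that the caller supplies. The function name conformal_quantile and the toy numbers are hypothetical, and this is not ConformalCalibrator's code.

```python
import math


def conformal_quantile(scores, alpha, weights=None, test_weight=1.0):
    """Level-(1 - alpha) quantile of the scores plus a point mass at +inf.

    weights=None gives every point weight 1, which is standard split
    conformal. Otherwise weights are caller-supplied likelihood ratios.
    """
    if weights is None:
        weights = [1.0] * len(scores)
    total = sum(weights) + test_weight  # test point carries the mass at +inf
    threshold = (1.0 - alpha) * total - 1e-12  # small tolerance for float round-off
    cumulative = 0.0
    for score, w in sorted(zip(scores, weights)):
        cumulative += w
        if cumulative >= threshold:
            return score
    return math.inf  # too few calibration points for this alpha


cal_scores = [1.0, 2.0, 3.0, 4.0]  # |y - prediction| on a toy calibration set
q_std = conformal_quantile(cal_scores, alpha=0.5)
q_wtd = conformal_quantile(cal_scores, alpha=0.5, weights=[1.0, 1.0, 1.0, 5.0])

prediction = 10.0
print((prediction - q_std, prediction + q_std))  # standard interval
print((prediction - q_wtd, prediction + q_wtd))  # weighted interval
```

In this toy case the calibration point with the largest residual is the one most similar to the test distribution, so the weighted interval is wider; with other weights it can just as well be narrower.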
MOSAICPipeline parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| harmonizer | str or None | "ot" | "ot" or None |
| robust_learner | str or None | "anchor" | "anchor" or None |
| uncertainty | str or None | "weighted_conformal" | "weighted_conformal", "standard", "lac", or None |
| base_estimator | sklearn estimator | None | Base learner passed to AnchorEstimator |
Methods: fit(X_train, y_train, center_ids, X_cal=None, y_cal=None), predict(X, center_id=..., center_ids=...), register_center(name, X_new), diagnose(X, center_id), save(path), load(path).
MOSAIC was evaluated on a multi-center IVF dataset (334K rows, 5 centers, 15 prediction targets). Full results are in benchmarks/results/.
| Target | Baseline R² | +OT | +OT+Anchor | Full MOSAIC |
|---|---|---|---|---|
| HCG_Day_E2 | -0.665 | 0.127 | 0.229 | 0.229 |
| egg_num | 0.210 | 0.205 | 0.352 | 0.352 |
| HCG_Day_Endo | -0.775 | -0.177 | 0.120 | 0.120 |
OT harmonization alone corrects much of the distribution shift (HCG_Day_E2: R² from -0.67 to 0.13). Adding anchor regression yields further gains on regression targets (egg_num: R² from 0.21 to 0.35).
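A negative R² means the model predicts worse than a constant equal to the mean of the test targets. One way this can happen is a systematic offset between centers, as the following sketch with made-up numbers shows. It uses the usual definition R² = 1 - SS_res / SS_tot and is only an illustration, not the benchmark code.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((y - p) ** 2 for y, p in zip(y_true, y_pred))
    ss_tot = sum((y - mean_y) ** 2 for y in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
perfect = list(y_true)                # predictions on the right scale
offset = [v + 3.0 for v in y_true]    # same ranking, shifted by a constant

print(r2_score(y_true, perfect))
print(r2_score(y_true, offset))
```

The shifted predictions rank every case correctly yet score far below zero, which is why removing a site-level offset can move R² by a large amount.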
On an external test center unseen during training, MOSAIC reduces the validation-to-test performance gap by 42-76% for high-shift features (HCG_Day_E2: 75%, HCG_Day_P: 52%, HCG_Day_Endo: 72%).
All 11 regression targets achieve 81-93% empirical coverage at the 90% nominal level. Weighted conformal consistently produces narrower intervals than standard split conformal at comparable coverage.
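Empirical coverage and interval width can be checked on any held-out set with a few lines. The helper below is a hypothetical sketch (the name coverage_and_width and the toy values are not part of MOSAIC); it counts how often the true value falls inside its interval and averages the interval widths.

```python
def coverage_and_width(y_true, lower, upper):
    """Return (fraction of targets inside their interval, mean interval width)."""
    n = len(y_true)
    covered = sum(lo <= y <= hi for y, lo, hi in zip(y_true, lower, upper))
    mean_width = sum(hi - lo for lo, hi in zip(lower, upper)) / n
    return covered / n, mean_width


y_true = [1.0, 2.0, 3.0, 10.0]
lower = [0.0, 1.0, 2.0, 3.0]
upper = [2.0, 3.0, 4.0, 5.0]
coverage, width = coverage_and_width(y_true, lower, upper)
print(f"coverage={coverage:.2f}, mean width={width:.2f}")
```

Comparing two methods is only meaningful when both numbers are reported together, since narrower intervals are easy to obtain by giving up coverage.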
| Feature | No harmonization | Z-score | ComBat | MOSAIC (OT) |
|---|---|---|---|---|
| HCG_Day_E2 (R²) | -0.665 | 0.069 | -0.383 | 0.127 |
| Clinical_pregnancy (AUC) | 0.838 | 0.837 | 0.837 | 0.839 |
| total_Gn (R²) | -0.121 | -0.327 | -0.034 | -0.022 |
MOSAIC outperforms Z-score and ComBat on high-shift features while maintaining comparable performance on low-shift targets.
If you use MOSAIC in your research, please cite:
@article{chen2026mosaic,
title={MOSAIC: Multi-site Optimal-transport Shift Alignment with Interval
Calibration for Clinical Data Harmonization},
author={Chen, Peigen},
journal={npj Digital Medicine},
year={2026},
note={Manuscript in preparation}
}
Apache-2.0. See LICENSE for details.