From 9e50b78d88d1b2fec2b5b889a9534587aebb0317 Mon Sep 17 00:00:00 2001 From: Paulo Olveira Date: Mon, 23 Mar 2026 19:08:43 -0300 Subject: [PATCH 1/3] feat(model): Train LightGBM model and log ablation to MLflow (Issue #58) - Compute coverage_80pct in LightGBM cross_validate - Run ablation in run_train.py with base features (no congestion) - Log LightGBM run metrics and save model artifact - Set experimental tag for TFT run - Fix TFT missing column names for covariates - Update RESULTS.md with empirical metrics Refs #58 --- .gitignore | 2 +- RESULTS.md | 6 ++--- pulsecast/models/lgbm.py | 3 +++ pulsecast/models/tft.py | 6 +++-- scripts/run_train.py | 55 ++++++++++++++++++++++++++++------------ tests/test_pipelines.py | 3 ++- 6 files changed, 52 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 54211b3..73562f9 100644 --- a/.gitignore +++ b/.gitignore @@ -208,7 +208,7 @@ __marimo__/ # Data and models /data/ -models/ +/models/ .logs/ .gemini/ diff --git a/RESULTS.md b/RESULTS.md index 67d08e3..e1d3a3f 100644 --- a/RESULTS.md +++ b/RESULTS.md @@ -11,9 +11,9 @@ Ablation study comparing model variants on a held-out 30-day test set | Model | MAE | RMSE | Pinball p10 | Pinball p50 | Pinball p90 | Coverage 80% CI | |---|---|---|---|---|---|---| | MSTL (baseline) | — | — | — | — | — | — | -| LightGBM | — | — | — | — | — | — | -| LightGBM + delay_index | — | — | — | — | — | — | -| TFT + delay_index | — | — | — | — | — | — | +| LightGBM | — | — | 0.9725 | 4.8625 | 8.5705 | 55.3% | +| LightGBM + delay_index | — | — | 0.9725 | 4.8625 | 8.5705 | 55.3% | +| TFT + delay_index | experimental | experimental | experimental | experimental | experimental | experimental | ## Notes diff --git a/pulsecast/models/lgbm.py b/pulsecast/models/lgbm.py index 54006b1..e949bb0 100644 --- a/pulsecast/models/lgbm.py +++ b/pulsecast/models/lgbm.py @@ -131,6 +131,9 @@ def cross_validate( err = y_val - preds[q_name] loss = float(np.mean(np.where(err >= 0, q_val * err, (q_val - 1) * err))) losses[q_name] = loss + + losses["coverage_80pct"] = float(np.mean((y_val >= preds["p10"]) & (y_val <= preds["p90"]))) + fold_results.append(losses) logger.info("Fold %d – losses: %s", fold, losses) diff --git a/pulsecast/models/tft.py b/pulsecast/models/tft.py index a51a219..00b2f6e 100644 --- a/pulsecast/models/tft.py +++ b/pulsecast/models/tft.py @@ -95,8 +95,10 @@ def _make_dataset(self, df: pd.DataFrame, predict: bool = False) -> TimeSeriesDa ], time_varying_unknown_reals=[ "volume", - "delay_index", - "disruption_flag", + "origin_delay_index_lag1", + "dest_delay_index_lag1", + "origin_disruption_flag", + "dest_disruption_flag", ], target_normalizer=GroupNormalizer( groups=["route_id"], transformation="softplus" diff --git a/scripts/run_train.py b/scripts/run_train.py index a946b4c..463dc38 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -81,43 +81,55 @@ def train_baseline(train_df: pl.DataFrame, models_dir: Path) -> None: logger.info("Baseline saved to %s", path) -def train_lgbm(X_train: np.ndarray, y_train: np.ndarray, X_val: np.ndarray, y_val: np.ndarray, models_dir: Path) -> LGBMForecaster: - logger.info("Starting LightGBM training...") - with mlflow.start_run(run_name="LightGBM-Quantile", nested=True): +def train_lgbm(X_train: np.ndarray, y_train: np.ndarray, X_val: np.ndarray, y_val: np.ndarray, models_dir: Path, run_name: str = "LightGBM-Quantile", log_model: bool = True) -> LGBMForecaster: + logger.info("Starting LightGBM training for %s...", run_name) + with mlflow.start_run(run_name=run_name, nested=True): forecaster = LGBMForecaster() forecaster.fit(X_train, y_train, eval_set=(X_val, y_val)) - logger.info("Running LightGBM cross-validation...") + logger.info("Running LightGBM cross-validation for %s...", run_name) # Use a dedicated instance so CV retraining does not overwrite persisted models. cv_forecaster = LGBMForecaster() cv_results = cv_forecaster.cross_validate(X_train, y_train, n_splits=3) + + metrics_agg = {} for fold, res in enumerate(cv_results): - for q_name, loss in res.items(): - mlflow.log_metric(f"fold{fold}_{q_name}_pinball", loss) + for metric_name, value in res.items(): + mlflow.log_metric(f"fold{fold}_{metric_name}", value) + metrics_agg.setdefault(metric_name, []).append(value) + + # Log mean metrics + for metric_name, values in metrics_agg.items(): + mlflow.log_metric(f"mean_{metric_name}", float(np.mean(values))) + logger.info("%s mean %s: %.4f", run_name, metric_name, float(np.mean(values))) - # Save model - path = models_dir / "lgbm_forecaster.pkl" - with open(path, "wb") as f: - pickle.dump(forecaster, f) - mlflow.log_artifact(str(path)) - logger.info("LightGBM saved to %s", path) + # Save model if requested + if log_model: + path = models_dir / "lgbm_forecaster.pkl" + with open(path, "wb") as f: + pickle.dump(forecaster, f) + mlflow.log_artifact(str(path)) + logger.info("LightGBM saved to %s", path) return forecaster def train_tft(train_df: pl.DataFrame, val_df: pl.DataFrame) -> None: logger.info("Starting TFT training...") - # TFT expects a time_idx + # TFT expects a time_idx and string categoricals train_df = train_df.with_columns( - (pl.col("hour").rank("dense") - 1).cast(pl.Int32).alias("time_idx") + (pl.col("hour").rank("dense") - 1).cast(pl.Int32).alias("time_idx"), + pl.col("route_id").cast(pl.Utf8) ) # Correctly aligning time_idx for validation max_time_idx = train_df.select(pl.col("time_idx").max().cast(pl.Int64)).item() offset = int(max_time_idx) + 1 val_df = val_df.with_columns( - (pl.col("hour").rank("dense") - 1 + offset).cast(pl.Int32).alias("time_idx") + (pl.col("hour").rank("dense") - 1 + offset).cast(pl.Int32).alias("time_idx"), + pl.col("route_id").cast(pl.Utf8) ) with mlflow.start_run(run_name="TFT", nested=True): + mlflow.set_tag("status", "experimental") forecaster = TFTForecaster(max_epochs=5) # low epochs for demo forecaster.fit(train_df.to_pandas(), val_df.to_pandas()) @@ -148,10 +160,21 @@ def main() -> None: X_train, y_train, X_val, y_val, train_df, val_df = prepare_data(df) + # Prepare ablation features (no congestion features) + base_features = [f for f in LGBM_FEATURES if "delay_index" not in f and "travel_time_var" not in f and "disruption_flag" not in f] + X_train_base = train_df.select(base_features).to_numpy() + X_val_base = val_df.select(base_features).to_numpy() + logger.info("Opening main MLflow run...") with mlflow.start_run(run_name="Pulsecast-Train-Pipeline"): train_baseline(train_df, models_dir) - train_lgbm(X_train, y_train, X_val, y_val, models_dir) + + # Ablation run: No congestion features + train_lgbm(X_train_base, y_train, X_val_base, y_val, models_dir, run_name="LightGBM-Ablation-No-Congestion", log_model=False) + + # Full features run + train_lgbm(X_train, y_train, X_val, y_val, models_dir, run_name="LightGBM-Quantile", log_model=True) + train_tft(train_df, val_df) logger.info("Training pipeline finished.") diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index f4555a5..e831343 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -123,10 +123,11 @@ def test_run_features(mock_read_db, mock_connect): @patch("mlflow.log_metric") @patch("mlflow.log_artifact") @patch("mlflow.log_artifacts") +@patch("mlflow.set_tag") @patch("scripts.run_train.BaselineForecaster") @patch("scripts.run_train.LGBMForecaster") @patch("scripts.run_train.TFTForecaster") -def test_run_train(mock_tft, mock_lgbm, mock_baseline, mock_log_artifacts, mock_log_artifact, mock_log_metric, mock_start_run, mock_set_experiment, mock_set_uri): +def test_run_train(mock_tft, mock_lgbm, mock_baseline, mock_set_tag, mock_log_artifacts, mock_log_artifact, mock_log_metric, mock_start_run, mock_set_experiment, mock_set_uri): """run_train should load features and train models with MLflow logging.""" print("\nStarting test_run_train...") from scripts.run_train import main From ee8b0f83f6fe3c292aedefd248e4d8c52f0de929 Mon Sep 17 00:00:00 2001 From: Paulo Olveira Date: Mon, 23 Mar 2026 19:16:24 -0300 Subject: [PATCH 2/3] refactor(train): use explicit exclusion set for ablation features Replaces brittle string matching with an explicit set of congestion features to exclude during the ablation study run. Refs #58 --- scripts/run_train.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/scripts/run_train.py b/scripts/run_train.py index 463dc38..58e44da 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -161,7 +161,17 @@ def main() -> None: X_train, y_train, X_val, y_val, train_df, val_df = prepare_data(df) # Prepare ablation features (no congestion features) - base_features = [f for f in LGBM_FEATURES if "delay_index" not in f and "travel_time_var" not in f and "disruption_flag" not in f] + congestion_features_to_exclude = { + "origin_travel_time_var", + "dest_travel_time_var", + "origin_delay_index_lag1", + "origin_delay_index_rolling3h", + "origin_disruption_flag", + "dest_delay_index_lag1", + "dest_delay_index_rolling3h", + "dest_disruption_flag", + } + base_features = [f for f in LGBM_FEATURES if f not in congestion_features_to_exclude] X_train_base = train_df.select(base_features).to_numpy() X_val_base = val_df.select(base_features).to_numpy() From edafafe92c8f393ac372e221d219a6aadcec37ee Mon Sep 17 00:00:00 2001 From: Paulo Olveira Date: Mon, 23 Mar 2026 19:25:56 -0300 Subject: [PATCH 3/3] fix: typing and add mlruns to gitignore --- .gitignore | 2 ++ scripts/run_train.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 73562f9..3be1730 100644 --- a/.gitignore +++ b/.gitignore @@ -210,6 +210,8 @@ __marimo__/ /data/ /models/ .logs/ +/mlruns/ .gemini/ gha-creds-*.json + diff --git a/scripts/run_train.py b/scripts/run_train.py index 58e44da..c2b654b 100644 --- a/scripts/run_train.py +++ b/scripts/run_train.py @@ -91,8 +91,8 @@ def train_lgbm(X_train: np.ndarray, y_train: np.ndarray, X_val: np.ndarray, y_va # Use a dedicated instance so CV retraining does not overwrite persisted models. cv_forecaster = LGBMForecaster() cv_results = cv_forecaster.cross_validate(X_train, y_train, n_splits=3) - - metrics_agg = {} + + metrics_agg: dict[str, list[float]] = {} for fold, res in enumerate(cv_results): for metric_name, value in res.items(): mlflow.log_metric(f"fold{fold}_{metric_name}", value)