Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,10 @@ __marimo__/

# Data and models
/data/
models/
/models/
.logs/
/mlruns/

.gemini/
gha-creds-*.json

6 changes: 3 additions & 3 deletions RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Ablation study comparing model variants on a held-out 30-day test set
| Model | MAE | RMSE | Pinball p10 | Pinball p50 | Pinball p90 | Coverage 80% CI |
|---|---|---|---|---|---|---|
| MSTL (baseline) | — | — | — | — | — | — |
| LightGBM | — | — | | | | |
| LightGBM + delay_index | — | — | | | | |
| TFT + delay_index | | | | | | |
| LightGBM | — | — | 0.9725 | 4.8625 | 8.5705 | 55.3% |
| LightGBM + delay_index | — | — | 0.9725 | 4.8625 | 8.5705 | 55.3% |
| TFT + delay_index | experimental | experimental | experimental | experimental | experimental | experimental |

## Notes

Expand Down
3 changes: 3 additions & 0 deletions pulsecast/models/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def cross_validate(
err = y_val - preds[q_name]
loss = float(np.mean(np.where(err >= 0, q_val * err, (q_val - 1) * err)))
losses[q_name] = loss

losses["coverage_80pct"] = float(np.mean((y_val >= preds["p10"]) & (y_val <= preds["p90"])))

fold_results.append(losses)
logger.info("Fold %d – losses: %s", fold, losses)

Expand Down
6 changes: 4 additions & 2 deletions pulsecast/models/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ def _make_dataset(self, df: pd.DataFrame, predict: bool = False) -> TimeSeriesDa
],
time_varying_unknown_reals=[
"volume",
"delay_index",
"disruption_flag",
"origin_delay_index_lag1",
"dest_delay_index_lag1",
"origin_disruption_flag",
"dest_disruption_flag",
],
target_normalizer=GroupNormalizer(
groups=["route_id"], transformation="softplus"
Expand Down
65 changes: 49 additions & 16 deletions scripts/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,43 +81,55 @@ def train_baseline(train_df: pl.DataFrame, models_dir: Path) -> None:
logger.info("Baseline saved to %s", path)


def train_lgbm(X_train: np.ndarray, y_train: np.ndarray, X_val: np.ndarray, y_val: np.ndarray, models_dir: Path) -> LGBMForecaster:
logger.info("Starting LightGBM training...")
with mlflow.start_run(run_name="LightGBM-Quantile", nested=True):
def train_lgbm(X_train: np.ndarray, y_train: np.ndarray, X_val: np.ndarray, y_val: np.ndarray, models_dir: Path, run_name: str = "LightGBM-Quantile", log_model: bool = True) -> LGBMForecaster:
logger.info("Starting LightGBM training for %s...", run_name)
with mlflow.start_run(run_name=run_name, nested=True):
forecaster = LGBMForecaster()
forecaster.fit(X_train, y_train, eval_set=(X_val, y_val))

logger.info("Running LightGBM cross-validation...")
logger.info("Running LightGBM cross-validation for %s...", run_name)
# Use a dedicated instance so CV retraining does not overwrite persisted models.
cv_forecaster = LGBMForecaster()
cv_results = cv_forecaster.cross_validate(X_train, y_train, n_splits=3)

metrics_agg: dict[str, list[float]] = {}
for fold, res in enumerate(cv_results):
for q_name, loss in res.items():
mlflow.log_metric(f"fold{fold}_{q_name}_pinball", loss)
for metric_name, value in res.items():
mlflow.log_metric(f"fold{fold}_{metric_name}", value)
metrics_agg.setdefault(metric_name, []).append(value)

# Log mean metrics
for metric_name, values in metrics_agg.items():
mlflow.log_metric(f"mean_{metric_name}", float(np.mean(values)))
logger.info("%s mean %s: %.4f", run_name, metric_name, float(np.mean(values)))

# Save model
path = models_dir / "lgbm_forecaster.pkl"
with open(path, "wb") as f:
pickle.dump(forecaster, f)
mlflow.log_artifact(str(path))
logger.info("LightGBM saved to %s", path)
# Save model if requested
if log_model:
path = models_dir / "lgbm_forecaster.pkl"
with open(path, "wb") as f:
pickle.dump(forecaster, f)
mlflow.log_artifact(str(path))
logger.info("LightGBM saved to %s", path)
return forecaster


def train_tft(train_df: pl.DataFrame, val_df: pl.DataFrame) -> None:
logger.info("Starting TFT training...")
# TFT expects a time_idx
# TFT expects a time_idx and string categoricals
train_df = train_df.with_columns(
(pl.col("hour").rank("dense") - 1).cast(pl.Int32).alias("time_idx")
(pl.col("hour").rank("dense") - 1).cast(pl.Int32).alias("time_idx"),
pl.col("route_id").cast(pl.Utf8)
)
# Correctly aligning time_idx for validation
max_time_idx = train_df.select(pl.col("time_idx").max().cast(pl.Int64)).item()
offset = int(max_time_idx) + 1
val_df = val_df.with_columns(
(pl.col("hour").rank("dense") - 1 + offset).cast(pl.Int32).alias("time_idx")
(pl.col("hour").rank("dense") - 1 + offset).cast(pl.Int32).alias("time_idx"),
pl.col("route_id").cast(pl.Utf8)
)

with mlflow.start_run(run_name="TFT", nested=True):
mlflow.set_tag("status", "experimental")
forecaster = TFTForecaster(max_epochs=5) # low epochs for demo
forecaster.fit(train_df.to_pandas(), val_df.to_pandas())

Expand Down Expand Up @@ -148,10 +160,31 @@ def main() -> None:

X_train, y_train, X_val, y_val, train_df, val_df = prepare_data(df)

# Prepare ablation features (no congestion features)
congestion_features_to_exclude = {
"origin_travel_time_var",
"dest_travel_time_var",
"origin_delay_index_lag1",
"origin_delay_index_rolling3h",
"origin_disruption_flag",
"dest_delay_index_lag1",
"dest_delay_index_rolling3h",
"dest_disruption_flag",
}
base_features = [f for f in LGBM_FEATURES if f not in congestion_features_to_exclude]
X_train_base = train_df.select(base_features).to_numpy()
X_val_base = val_df.select(base_features).to_numpy()

logger.info("Opening main MLflow run...")
with mlflow.start_run(run_name="Pulsecast-Train-Pipeline"):
train_baseline(train_df, models_dir)
train_lgbm(X_train, y_train, X_val, y_val, models_dir)

# Ablation run: No congestion features
train_lgbm(X_train_base, y_train, X_val_base, y_val, models_dir, run_name="LightGBM-Ablation-No-Congestion", log_model=False)

# Full features run
train_lgbm(X_train, y_train, X_val, y_val, models_dir, run_name="LightGBM-Quantile", log_model=True)

train_tft(train_df, val_df)
logger.info("Training pipeline finished.")

Expand Down
3 changes: 2 additions & 1 deletion tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ def test_run_features(mock_read_db, mock_connect):
@patch("mlflow.log_metric")
@patch("mlflow.log_artifact")
@patch("mlflow.log_artifacts")
@patch("mlflow.set_tag")
@patch("scripts.run_train.BaselineForecaster")
@patch("scripts.run_train.LGBMForecaster")
@patch("scripts.run_train.TFTForecaster")
def test_run_train(mock_tft, mock_lgbm, mock_baseline, mock_log_artifacts, mock_log_artifact, mock_log_metric, mock_start_run, mock_set_experiment, mock_set_uri):
def test_run_train(mock_tft, mock_lgbm, mock_baseline, mock_set_tag, mock_log_artifacts, mock_log_artifact, mock_log_metric, mock_start_run, mock_set_experiment, mock_set_uri):
"""run_train should load features and train models with MLflow logging."""
print("\nStarting test_run_train...")
from scripts.run_train import main
Expand Down
Loading