From 54955a9062afa4449c5e42993c476ead81284d62 Mon Sep 17 00:00:00 2001 From: Max Feenstra Date: Fri, 3 Oct 2025 13:24:44 -0400 Subject: [PATCH 01/50] Minimal replacement of Hydra --- cents/utils/config_loader.py | 66 ++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 cents/utils/config_loader.py diff --git a/cents/utils/config_loader.py b/cents/utils/config_loader.py new file mode 100644 index 0000000..7ebd1f1 --- /dev/null +++ b/cents/utils/config_loader.py @@ -0,0 +1,66 @@ +import os +from pathlib import Path +from typing import Any, List, Union + +from omegaconf import OmegaConf + + +Config = Union[dict, Any] + + +def load_yaml(path: Union[str, Path]) -> Config: + """ + Load a YAML file into an OmegaConf object. + """ + return OmegaConf.load(str(path)) + + +def deep_merge(*cfgs: Config) -> Config: + """ + Deep-merge multiple OmegaConf configs into one. + Later arguments override earlier ones. + """ + cfg = OmegaConf.create({}) + for c in cfgs: + if c is None: + continue + cfg = OmegaConf.merge(cfg, c) + return cfg + + +def _coerce_scalar(value: str): + v = value.strip() + if v.lower() in ("true", "false"): + return v.lower() == "true" + try: + return int(v) + except ValueError: + try: + return float(v) + except ValueError: + return v + + +def apply_overrides(cfg: Config, overrides: List[str]) -> Config: + """ + Apply dot-path overrides like ["trainer.max_epochs=10", "dataset.normalize=False"]. + Works on OmegaConf configs. 
+ """ + if not overrides: + return cfg + # materialize to a plain container then recreate, to ensure mutability + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)) + for s in overrides: + if "=" not in s: + continue + key, val = s.split("=", 1) + parts = key.split(".") + cur = cfg + for p in parts[:-1]: + if p not in cur or cur[p] is None: + cur[p] = {} + cur = cur[p] + cur[parts[-1]] = _coerce_scalar(val) + return cfg + + From 98275aa415d41ec1da66eb0cd849d1b20a3def49 Mon Sep 17 00:00:00 2001 From: Max Feenstra Date: Fri, 3 Oct 2025 13:25:28 -0400 Subject: [PATCH 02/50] Remove dependency on Hydra --- cents/config/config.yaml | 24 --------------- cents/config/trainer/acgan.yaml | 2 +- cents/data_generator.py | 28 ++++++++--------- cents/datasets/pecanstreet.py | 9 +++--- cents/datasets/timeseries_dataset.py | 45 +++++++++++++++++---------- cents/models/normalizer.py | 2 +- cents/trainer.py | 39 ++++++++++++++++------- scripts/eval_pretrained.py | 46 +++++++++++++++------------- scripts/train.py | 7 +++-- tests/conftest.py | 25 ++++++--------- 10 files changed, 112 insertions(+), 115 deletions(-) diff --git a/cents/config/config.yaml b/cents/config/config.yaml index eb69f62..e9a3532 100644 --- a/cents/config/config.yaml +++ b/cents/config/config.yaml @@ -1,31 +1,7 @@ -defaults: - - model: null - - dataset: pecanstreet - - evaluator: default - - trainer: null - - _self_ - device: auto job_name: ${model.name}_${dataset.name}_${dataset.user_group} run_dir: outputs/${job_name}/${now:%Y-%m-%d_%H-%M-%S} model_ckpt: null -hydra: - job_logging: - version: 1 - formatters: - simple: - format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - handlers: - console: - class: logging.StreamHandler - formatter: simple - level: INFO - root: - handlers: [console] - level: INFO - run: - dir: ${run_dir} - wandb: enabled: false diff --git a/cents/config/trainer/acgan.yaml b/cents/config/trainer/acgan.yaml index 3034158..4795c3e 100644 --- 
a/cents/config/trainer/acgan.yaml +++ b/cents/config/trainer/acgan.yaml @@ -2,7 +2,7 @@ precision: "16-mixed" accelerator: auto devices: auto strategy: ddp_find_unused_parameters_true -max_epochs: 5000 +max_epochs: 1000 batch_size: 1024 sampling_batch_size: 4096 gradient_accumulate_every: 1 diff --git a/cents/data_generator.py b/cents/data_generator.py index a77aa18..e0d613b 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -7,8 +7,7 @@ import pytorch_lightning as pl import torch from huggingface_hub import hf_hub_download -from hydra import compose, initialize_config_dir -from omegaconf import DictConfig, ListConfig +from omegaconf import DictConfig, ListConfig, OmegaConf import cents.models from cents.datasets.utils import convert_generated_data_to_df @@ -19,6 +18,7 @@ get_normalizer_training_config, parse_dims_from_name, ) +from cents.utils.config_loader import load_yaml, apply_overrides PKG_ROOT = Path(__file__).resolve().parent CONF_DIR = PKG_ROOT / "config" @@ -76,23 +76,21 @@ def __init__( def _default_cfg(self) -> DictConfig: """ - Load the default Hydra config for this model_name. - - Returns: - Composed DictConfig from 'config/config.yaml'. + Build a minimal default config (model + dataset) without Hydra. 
""" - # Extract dimensions from model name dims = parse_dims_from_name(self.model_name) time_series_dims = int(dims) - with initialize_config_dir(str(CONF_DIR), version_base=None): - return compose( - config_name="config", - overrides=[ - f"model={self.model_type}", - f"dataset.time_series_dims={time_series_dims}", - ], - ) + model_cfg = load_yaml(CONF_DIR / "model" / f"{self.model_type}.yaml") + dataset_cfg = load_yaml(CONF_DIR / "dataset" / "default.yaml") + dataset_cfg = apply_overrides( + dataset_cfg, [f"time_series_dims={time_series_dims}"] + ) + + cfg = OmegaConf.create({}) + cfg.model = model_cfg + cfg.dataset = dataset_cfg + return cfg def set_dataset_spec( self, dataset_cfg: DictConfig, ctx_codes: Dict[str, Dict[int, str]] diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index 866921d..e25d627 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd -from hydra import compose, initialize_config_dir from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides from cents.datasets.timeseries_dataset import TimeSeriesDataset @@ -51,10 +51,9 @@ def __init__( FileNotFoundError: If required CSV files are missing. 
""" if cfg is None: - with initialize_config_dir( - config_dir=os.path.join(ROOT_DIR, "config/dataset"), version_base=None - ): - cfg = compose(config_name="pecanstreet", overrides=overrides) + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "pecanstreet.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) self.cfg = cfg self.name = cfg.name diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 4f70776..6ce08fd 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -8,13 +8,13 @@ import pandas as pd import pytorch_lightning as pl import torch -from hydra import compose, initialize_config_dir from pytorch_lightning.callbacks import ModelCheckpoint from sklearn.cluster import KMeans from torch.utils.data import DataLoader, Dataset from cents.datasets.utils import encode_context_variables from cents.models.normalizer import Normalizer +from cents.utils.config_loader import load_yaml, apply_overrides from cents.utils.utils import _ckpt_name, get_normalizer_training_config ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -63,20 +63,18 @@ def __init__( # Load dataset-level config if not already set if not hasattr(self, "cfg"): - with initialize_config_dir( - config_dir=os.path.join(ROOT_DIR, "config", "dataset"), - version_base=None, - ): - overrides = [ - f"seq_len={seq_len}", - f"time_series_dims={len(self.time_series_column_names)}", - ] - cfg = compose(config_name="default", overrides=overrides) - cfg.time_series_columns = self.time_series_column_names - self.numeric_context_bins = cfg.numeric_context_bins - context_vars = self._get_context_var_dict(data) - cfg.context_vars = context_vars - self.cfg = cfg + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "default.yaml")) + # dynamic overrides for required fields + dyn = [ + f"seq_len={seq_len}", + f"time_series_dims={len(self.time_series_column_names)}", + ] + cfg = 
apply_overrides(cfg, dyn) + cfg.time_series_columns = self.time_series_column_names + self.numeric_context_bins = cfg.numeric_context_bins + context_vars = self._get_context_var_dict(data) + cfg.context_vars = context_vars + self.cfg = cfg self.numeric_context_bins = self.cfg.numeric_context_bins if not hasattr(self, "threshold"): @@ -149,8 +147,21 @@ def __getitem__(self, idx: int): } return timeseries, context_vars_dict + def __getstate__(self): + """ + Make the dataset picklable for DataLoader worker processes by dropping + non-picklable attributes (e.g., attached Lightning normalizer module). + + Note: Training workers only need access to the preprocessed `self.data`. + The normalizer is not used during batching, so it is safe to omit here. + """ + state = self.__dict__.copy() + if state.get("_normalizer", None) is not None: + state["_normalizer"] = None + return state + def get_train_dataloader( - self, batch_size: int, shuffle: bool = True, num_workers: int = 4 + self, batch_size: int, shuffle: bool = True, num_workers: int = 9 ) -> DataLoader: """ Create a PyTorch DataLoader for training. @@ -164,7 +175,7 @@ def get_train_dataloader( DataLoader: Configured data loader. 
""" return DataLoader( - self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers + self, batch_size=batch_size, shuffle=shuffle, num_workers=9 ) def split_timeseries(self, df: pd.DataFrame) -> pd.DataFrame: diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 39afd4b..ea9efb4 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -234,7 +234,7 @@ def train_dataloader(self): ds, batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, - num_workers=1, + num_workers=0, ) def _compute_group_stats(self) -> dict: diff --git a/cents/trainer.py b/cents/trainer.py index 70dc8ed..76abe0f 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -3,7 +3,7 @@ import pytorch_lightning as pl import wandb -from hydra import compose, initialize_config_dir +from datetime import datetime from omegaconf import DictConfig, OmegaConf from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loggers import WandbLogger @@ -13,6 +13,7 @@ from cents.eval.eval import Evaluator from cents.models.registry import get_model_cls from cents.utils.utils import get_normalizer_training_config +from cents.utils.config_loader import load_yaml, apply_overrides PKG_ROOT = Path(__file__).resolve().parent CONF_DIR = PKG_ROOT / "config" @@ -135,22 +136,38 @@ def evaluate(self, **kwargs) -> Dict: def _compose_cfg(self, ov: List[str]) -> DictConfig: """ - Compose the full Hydra configuration by merging defaults, - dataset-specific config, and any user overrides. + Compose configuration by loading YAMLs and applying overrides. - Args: - ov: List of Hydra-style overrides. - - Returns: - OmegaConf DictConfig. 
+ Structure: + cfg.model <- config/model/{model_type}.yaml + cfg.trainer <- config/trainer/{model_type}.yaml + cfg.dataset <- provided dataset.cfg (if any) """ - base_ov = [f"model={self.model_type}", f"trainer={self.model_type}"] - with initialize_config_dir(str(CONF_DIR), version_base=None): - cfg = compose(config_name="config", overrides=base_ov + ov) + model_cfg = load_yaml(CONF_DIR / "model" / f"{self.model_type}.yaml") + trainer_cfg = load_yaml(CONF_DIR / "trainer" / f"{self.model_type}.yaml") + + cfg = OmegaConf.create({}) + cfg.model = model_cfg + cfg.trainer = trainer_cfg + if self.dataset is not None: cfg.dataset = OmegaConf.create( OmegaConf.to_container(self.dataset.cfg, resolve=True) ) + + cfg = apply_overrides(cfg, ov) + + # Ensure required top-level fields exist without Hydra + if not hasattr(cfg, "device"): + cfg.device = "auto" + if not hasattr(cfg, "job_name"): + ds_name = getattr(cfg, "dataset", {}).get("name", "dataset") if isinstance(getattr(cfg, "dataset", {}), dict) else getattr(cfg.dataset, "name", "dataset") + ds_group = getattr(cfg, "dataset", {}).get("user_group", "all") if isinstance(getattr(cfg, "dataset", {}), dict) else getattr(cfg.dataset, "user_group", "all") + model_name = getattr(cfg.model, "name", self.model_type) + cfg.job_name = f"{model_name}_{ds_name}_{ds_group}" + if not hasattr(cfg, "run_dir") or not cfg.run_dir: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + cfg.run_dir = str(PKG_ROOT / "outputs" / cfg.job_name / timestamp) return cfg def _instantiate_model(self): diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 5b4f0ea..12606b6 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,24 +1,19 @@ import logging from datetime import datetime -from pathlib import Path -from hydra import compose, initialize_config_dir +import wandb from omegaconf import OmegaConf -import wandb from cents.data_generator import DataGenerator from cents.datasets.pecanstreet import 
PecanStreetDataset from cents.eval.eval import Evaluator +from cents.utils.config_loader import load_yaml + MODEL_KEY = "Watts_1_2D" -OVERRIDES = [ - "dataset.user_group=pv_users", - "dataset.time_series_dims=2", - "evaluator.eval_disentanglement=False", - "wandb.enabled=True", - "wandb.project=cents", - "wandb.entity=michael-fuest-technical-university-of-munich", - f"wandb.name=eval_{MODEL_KEY}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_dim2", +DATASET_OVERRIDES = [ + "user_group=pv_users", + "time_series_dims=2", ] @@ -31,21 +26,28 @@ def main() -> None: wandb.init( project="cents", name=f"{MODEL_KEY}-eval-only-run_{datetime.now().strftime('%Y%m%d-%H%M%S')}", - entity="michael-fuest-technical-university-of-munich", + entity="pmfeen-massachusetts-institute-of-technology", ) - CONF_DIR = Path(__file__).resolve().parents[1] / "cents" / "config" - with initialize_config_dir(str(CONF_DIR), version_base=None): - cfg = compose(config_name="config", overrides=[f"model=acgan"] + OVERRIDES) - - ds_overrides = [ - o.split("dataset.")[1] for o in OVERRIDES if o.startswith("dataset.") - ] - dataset = PecanStreetDataset(overrides=ds_overrides) - cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + # Dataset with simple overrides (no Hydra) + dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES) + + # Build a minimal cfg for evaluator and generator + eval_cfg = load_yaml("cents/config/evaluator/default.yaml") + top_cfg = load_yaml("cents/config/config.yaml") + cfg = OmegaConf.create({}) + cfg.evaluator = eval_cfg + cfg.wandb = top_cfg.get("wandb", {}) + cfg.device = top_cfg.get("device", "auto") + cfg.model = OmegaConf.create({"name": MODEL_KEY}) + cfg.dataset = OmegaConf.create( + OmegaConf.to_container(dataset.cfg, resolve=True) + ) # Use the fixed checkpoint with DataGenerator - gen = DataGenerator(MODEL_KEY, cfg=cfg) + gen = DataGenerator(MODEL_KEY) + gen.set_dataset_spec(cfg.dataset, dataset.get_context_var_codes()) + results = 
Evaluator(cfg, dataset).evaluate_model(data_generator=gen) print(results) diff --git a/scripts/train.py b/scripts/train.py index 8bada3f..7773137 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,18 +5,19 @@ def main() -> None: - MODEL_NAME = "acgan" + MODEL_NAME = "diffusion_ts" CR_LOSS_WEIGHT = 0.1 TC_LOSS_WEIGHT = 0.1 - dataset = PecanStreetDataset(overrides=["user_group=all", "time_series_dims=2"]) + dataset = PecanStreetDataset(overrides=["user_group=all"]) trainer_overrides = [ "trainer.max_epochs=5000", "trainer.strategy=ddp_find_unused_parameters_true", + "trainer.accelerator=cpu", "trainer.eval_after_training=True", "wandb.enabled=True", "wandb.project=cents", - "wandb.entity=michael-fuest-technical-university-of-munich", + "wandb.entity=pmfeen-massachusetts-institute-of-technology", f"model.context_reconstruction_loss_weight={CR_LOSS_WEIGHT}", f"model.tc_loss_weight={TC_LOSS_WEIGHT}", f"wandb.name=training_dai_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_L{CR_LOSS_WEIGHT}_TC_{TC_LOSS_WEIGHT}_dim2", diff --git a/tests/conftest.py b/tests/conftest.py index e4af31a..ea5af39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd import pytest -from hydra import compose, initialize_config_dir from omegaconf import DictConfig, OmegaConf +from cents.utils.config_loader import load_yaml from cents.datasets.timeseries_dataset import TimeSeriesDataset from cents.trainer import Trainer @@ -135,32 +135,25 @@ def normalized_dataset_2d(raw_df_2d): def load_top_level_config() -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs") - with initialize_config_dir(config_dir=config_dir, version_base=None): - cfg = compose(config_name="test_config", overrides=[]) - return cfg + path = os.path.join(ROOT_DIR, "tests", "test_configs", "test_config.yaml") + return load_yaml(path) def load_dataset_config(case: str) -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", 
"test_configs", "dataset") - with initialize_config_dir(config_dir=config_dir, version_base=None): - ds_cfg = compose(config_name=case, overrides=[]) + path = os.path.join(ROOT_DIR, "tests", "test_configs", "dataset", f"{case}.yaml") + ds_cfg = load_yaml(path) OmegaConf.set_struct(ds_cfg, False) return ds_cfg def load_model_config(model_type: str) -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs", "model") - with initialize_config_dir(config_dir=config_dir, version_base=None): - model_cfg = compose(config_name=model_type, overrides=[]) - return model_cfg + path = os.path.join(ROOT_DIR, "tests", "test_configs", "model", f"{model_type}.yaml") + return load_yaml(path) def load_trainer_config(trainer_name: str) -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs", "trainer") - with initialize_config_dir(config_dir=config_dir, version_base=None): - trainer_cfg = compose(config_name=trainer_name, overrides=[]) - return trainer_cfg + path = os.path.join(ROOT_DIR, "tests", "test_configs", "trainer", f"{trainer_name}.yaml") + return load_yaml(path) @pytest.fixture From be3e37c07a83a1ef930ce44e88631e67d0726748 Mon Sep 17 00:00:00 2001 From: Max Feenstra Date: Tue, 7 Oct 2025 11:27:14 -0400 Subject: [PATCH 03/50] Added commercial energy dataset; WIP --- .gitignore | 1 + cents/config/dataset/commercial.yaml | 26 ++++++ cents/datasets/commercial.py | 131 +++++++++++++++++++++++++++ scripts/train.py | 5 +- 4 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 cents/config/dataset/commercial.yaml create mode 100644 cents/datasets/commercial.py diff --git a/.gitignore b/.gitignore index f1316a3..be2515a 100644 --- a/.gitignore +++ b/.gitignore @@ -108,6 +108,7 @@ ENV/ # Repository Specific cents/data/* cents/data/pecanstreet/* +cents/data/commercial/* cents/data/custom/ .DS_Store .ipynb_checkpoints diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml new file mode 100644 index 
0000000..9d4dea0 --- /dev/null +++ b/cents/config/dataset/commercial.yaml @@ -0,0 +1,26 @@ +name: commercial +geography: null +normalize: True +scale: True +use_learned_normalizer: True +threshold: 8 +seq_len: 24 +time_series_dims: 1 +shuffle: True +path: "./data/commercial/csv" +time_series_columns: ["energy_meter"] +data_columns: ["dataid","energy_meter","timestamp"] +metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sub_primaryspaceusage", "sqm", "sqft", "yearbuilt"] +numeric_context_bins: 5 + +context_vars: # for each desired context variable, add the name and number of categories + year: 2 + month: 12 + weekday: 7 + site_id: 19 + primaryspaceusage: 17 + sub_primaryspaceusage: 105 + sqm: 5 + sqft: 5 + yearbuilt: 5 + diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py new file mode 100644 index 0000000..af5631e --- /dev/null +++ b/cents/datasets/commercial.py @@ -0,0 +1,131 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class CommercialDataset(TimeSeriesDataset): + def __init__(self, cfg: DictConfig = None, + overrides: Optional[List[str]] = None): + + """ + Initializes the commercial energy dataset. + Original Dataset: https://github.com/buds-lab/building-data-genome-project-2 + Note: This uses the clean version of their data, which already has some preprocessing done. + Args: + cfg (Optional[DictConfig]): Override Hydra config; if None, + load from `config/dataset/commercial.yaml`. + overrides (Optional[List[str]]): Override Hydra config; if None, + load from `config/dataset/commercial.yaml` and apply overrides. 
+ """ + + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "commercial.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + self.cfg.time_series_columns = ["energy_meter"] + self.geography = cfg.geography + + self._load_data() + self.time_series_column_names = cfg.time_series_columns + self.time_series_dims = len(self.time_series_column_names) + + super().__init__( + data=self.data, + time_series_column_names=self.cfg.time_series_columns, + seq_len=self.cfg.seq_len, + context_var_column_names=self.cfg.context_vars, + normalize=self.cfg.normalize, + scale=self.cfg.scale, + ) + + def _load_data(self): + """ + Loads in metadata and data for the commercial energy dataset. + """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + base_path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + + metapath = os.path.join(base_path, "metadata.csv") + if not os.path.exists(metapath): + raise FileNotFoundError(f"Metadata file not found at {metapath}") + metadata = pd.read_csv(metapath)[self.cfg.metadata_columns] + + data_path = os.path.join(base_path, "electricity_cleaned.csv") + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found at {data_path}") + data = pd.read_csv(data_path) + + data = data.melt( + id_vars="timestamp", # keep timestamp as is + var_name="dataid", # old column names (id1, id2, etc.) 
become values in this column + value_name="energy_meter" # their values go here + ) + + data['site_id'] = data['dataid'].str.split('_').str[0] + + if self.geography: + if self.geography not in metadata["site_id"].unique(): + raise ValueError(f"Geography {self.geography} not found in metadata") + data = data[data["site_id"] == self.geography] + metadata = metadata[metadata["site_id"] == self.geography] + + self.data = data + self.metadata = metadata + + + def _preprocess_data(self, data): + """ + Creates sequences of seq_len and merges metadata. Removes any sequences with missing data. + + Args: + data (pd.DataFrame): Raw DataFrame including 'energy_meter' values. + + Returns: + pd.DataFrame: Metadata columns, datetime, year, month, weekday, date_day, and array-valued 'energy_meter'. + """ + data = data.copy() + + data['datetime'] = pd.to_datetime(data['timestamp']) + data = data.dropna(subset=['energy_meter']) # any NaN makes the day shorter -> filtered by size later + data = data.sort_values(by=["dataid", "datetime"]) + + # grouped = data.groupby(["dataid", "datetime", "year", "month", "date_day", "weekday"])["energy_meter"].apply(np.array).reset_index() + grouped = data.groupby(['dataid', pd.Grouper(key='datetime', freq='D')])['energy_meter'] + grouped = grouped.apply(np.asarray).reset_index(name='energy_meter') + + ## Just gonna remove any sequence with any missing values + grouped = grouped[grouped["energy_meter"].apply(len) == self.cfg.seq_len].reset_index(drop=True) + # grouped = grouped[grouped["energy_meter"].apply(lambda x: not np.isnan(x).any())] + + grouped['year'] = grouped['datetime'].dt.year + grouped["month"] = grouped["datetime"].dt.month + grouped["weekday"] = grouped["datetime"].dt.day_name() + grouped["date_day"] = grouped["datetime"].dt.day + + if grouped["energy_meter"].apply(lambda x: np.isnan(x).any()).any(): + raise ValueError("NaN values remain in grouped energy_meter sequences after filtering.") + + + merged = pd.merge(grouped, 
self.metadata, how="left", left_on="dataid", right_on="building_id").drop(columns=["building_id"]) + merged.sort_values(by=["dataid", "datetime"], inplace=True) + # merged = self._handle_missing_data(merged) + + return merged + + def _handle_missing_data(self, merged): + raise NotImplementedError + + \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index 7773137..efbbf32 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,14 +1,15 @@ from datetime import datetime from cents.datasets.pecanstreet import PecanStreetDataset +from cents.datasets.commercial import CommercialDataset from cents.trainer import Trainer def main() -> None: - MODEL_NAME = "diffusion_ts" + MODEL_NAME = "acgan" CR_LOSS_WEIGHT = 0.1 TC_LOSS_WEIGHT = 0.1 - dataset = PecanStreetDataset(overrides=["user_group=all"]) + dataset = CommercialDataset(overrides=[]) trainer_overrides = [ "trainer.max_epochs=5000", From 13cde3499c3192a8ef6366d5196aec9e1ebce7ad Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 13 Oct 2025 13:23:35 -0400 Subject: [PATCH 04/50] Fixed training for PecanStreet on GPU, fixed indexing errors in CommercialEnergy --- cents/config/dataset/commercial.yaml | 3 +- cents/config/model/diffusion_ts.yaml | 2 +- cents/config/trainer/acgan.yaml | 4 +- cents/config/trainer/diffusion_ts.yaml | 14 +-- cents/config/trainer/normalizer.yaml | 2 +- cents/datasets/commercial.py | 49 ++++++++- cents/datasets/pecanstreet.py | 2 + cents/datasets/timeseries_dataset.py | 134 +++++++++++++++++++++++-- cents/datasets/utils.py | 9 +- cents/models/context.py | 11 ++ cents/models/normalizer.py | 11 ++ cents/trainer.py | 6 +- scripts/train.py | 30 +++++- 13 files changed, 244 insertions(+), 33 deletions(-) diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 9d4dea0..52a558a 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -12,9 +12,10 @@ time_series_columns: ["energy_meter"] 
data_columns: ["dataid","energy_meter","timestamp"] metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sub_primaryspaceusage", "sqm", "sqft", "yearbuilt"] numeric_context_bins: 5 +numeric_cols: ["sqm", "sqft", "yearbuilt"] # Columns to bin as numeric context_vars: # for each desired context variable, add the name and number of categories - year: 2 + year: 5 month: 12 weekday: 7 site_id: 19 diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 396856d..8434637 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -11,7 +11,7 @@ n_steps: 1000 sampling_timesteps: 1000 sampling_batch_size: 4096 loss_type: l1 #l2 -beta_schedule: cosine #linear +beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 eta: 0.0 diff --git a/cents/config/trainer/acgan.yaml b/cents/config/trainer/acgan.yaml index 4795c3e..665551c 100644 --- a/cents/config/trainer/acgan.yaml +++ b/cents/config/trainer/acgan.yaml @@ -1,6 +1,6 @@ precision: "16-mixed" -accelerator: auto -devices: auto +accelerator: cpu +devices: cpu strategy: ddp_find_unused_parameters_true max_epochs: 1000 batch_size: 1024 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 4f3e701..f69e77b 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -1,19 +1,19 @@ precision: "16-mixed" -accelerator: auto -devices: auto +accelerator: cpu +devices: cpu strategy: ddp_find_unused_parameters_true -gradient_accumulate_every: 2 +gradient_accumulate_every: 4 log_every_n_steps: 1 -batch_size: 1024 +batch_size: 512 max_epochs: 5000 base_lr: 1e-4 eval_after_training: False checkpoint: - save_last: False - save_top_k: 0 + save_last: True # Save final model + save_top_k: 3 # Save top 3 best models every_n_train_steps: null - every_n_epochs: null + every_n_epochs: 500 # Save every 500 epochs lr_scheduler_params: 
factor: 0.5 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 666ded6..6b9fdaa 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,5 +1,5 @@ strategy: auto -accelerator: auto +accelerator: cpu devices: 1 log_every_n_steps: 1 hidden_dim: 512 diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index af5631e..8c78e9d 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -14,7 +14,8 @@ class CommercialDataset(TimeSeriesDataset): def __init__(self, cfg: DictConfig = None, - overrides: Optional[List[str]] = None): + overrides: Optional[List[str]] = None, + skip_heavy_processing: bool = False): """ Initializes the commercial energy dataset. @@ -46,10 +47,18 @@ def __init__(self, cfg: DictConfig = None, data=self.data, time_series_column_names=self.cfg.time_series_columns, seq_len=self.cfg.seq_len, - context_var_column_names=self.cfg.context_vars, + context_var_column_names=list(self.cfg.context_vars.keys()), normalize=self.cfg.normalize, scale=self.cfg.scale, + skip_heavy_processing=skip_heavy_processing, ) + + # Force recomputation of context variables to match actual encoded data + print("=== FORCING CONTEXT VAR RECOMPUTATION ===") + context_vars = self._get_context_var_dict(self.data) + print(f"Computed context_vars: {context_vars}") + self.cfg.context_vars = context_vars + print("=========================================") def _load_data(self): """ @@ -61,7 +70,7 @@ def _load_data(self): metapath = os.path.join(base_path, "metadata.csv") if not os.path.exists(metapath): raise FileNotFoundError(f"Metadata file not found at {metapath}") - metadata = pd.read_csv(metapath)[self.cfg.metadata_columns] + metadata = pd.read_csv(metapath, usecols=self.cfg.metadata_columns) data_path = os.path.join(base_path, "electricity_cleaned.csv") if not os.path.exists(data_path): @@ -84,6 +93,18 @@ def _load_data(self): self.data = data self.metadata = 
metadata + + # Debug: Check raw data before preprocessing + print("=== RAW DATA DEBUG ===") + print(f"Data shape: {data.shape}") + print(f"Metadata shape: {metadata.shape}") + print(f"Context vars in metadata: {[col for col in self.cfg.context_vars.keys() if col in metadata.columns]}") + for col in self.cfg.context_vars.keys(): + if col in metadata.columns: + unique_vals = metadata[col].nunique() + print(f"{col}: {unique_vals} unique values, dtype: {metadata[col].dtype}") + print("{col}: {self.cfg.context_vars[col]}, config unique") + print("======================") def _preprocess_data(self, data): @@ -111,18 +132,36 @@ def _preprocess_data(self, data): # grouped = grouped[grouped["energy_meter"].apply(lambda x: not np.isnan(x).any())] grouped['year'] = grouped['datetime'].dt.year - grouped["month"] = grouped["datetime"].dt.month + grouped["month"] = grouped["datetime"].dt.month_name() grouped["weekday"] = grouped["datetime"].dt.day_name() grouped["date_day"] = grouped["datetime"].dt.day if grouped["energy_meter"].apply(lambda x: np.isnan(x).any()).any(): raise ValueError("NaN values remain in grouped energy_meter sequences after filtering.") - merged = pd.merge(grouped, self.metadata, how="left", left_on="dataid", right_on="building_id").drop(columns=["building_id"]) merged.sort_values(by=["dataid", "datetime"], inplace=True) # merged = self._handle_missing_data(merged) + # Drop rows with NaN values in context variables + context_cols = [col for col in self.cfg.context_vars.keys() if col in merged.columns] + nan_before = merged.shape[0] + merged = merged.dropna(subset=context_cols) + nan_after = merged.shape[0] + if nan_before != nan_after: + print(f"Dropped {nan_before - nan_after} rows with NaN values in context variables") + + # Debug: Check data after merging + print("=== AFTER MERGING DEBUG ===") + print(f"Merged shape: {merged.shape}") + for col in self.cfg.context_vars.keys(): + if col in merged.columns: + unique_vals = merged[col].nunique() + 
print(f"{col}: {unique_vals} unique values, dtype: {merged[col].dtype}") + if unique_vals < 20: # Only print if not too many + print(f" Values: {merged[col].unique()}") + print("===========================") + return merged def _handle_missing_data(self, merged): diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index e25d627..c61dcab 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -33,6 +33,7 @@ def __init__( self, cfg: Optional[DictConfig] = None, overrides: Optional[List[str]] = None, + skip_heavy_processing: bool = False, ): """ Initialize and preprocess the PecanStreet dataset. @@ -84,6 +85,7 @@ def __init__( seq_len=self.cfg.seq_len, normalize=self.cfg.normalize, scale=self.cfg.scale, + skip_heavy_processing=skip_heavy_processing, ) def _load_data(self) -> None: diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 6ce08fd..e7a113c 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -50,6 +50,7 @@ def __init__( normalize: bool = True, scale: bool = True, overrides: Dict[str, Any] = {}, + skip_heavy_processing: bool = False, ): # Initialize basic attributes self.time_series_column_names = ( @@ -94,15 +95,38 @@ def __init__( self.data, self.context_var_codes = self._encode_context_vars(self.data) self._save_context_var_codes() + print("normalizing") if self.normalize: self._init_normalizer() self.data = self._normalizer.transform(self.data) + print("finished normalizing") self.data = self.merge_timeseries_columns(self.data) + print("merged time series columns") self.data = self.data.reset_index() - self.data = self.get_frequency_based_rarity() - self.data = self.get_clustering_based_rarity() - self.data = self.get_combined_rarity() + print("reset index") + + # Check if we should skip heavy processing for DDP + is_ddp_subprocess = self._is_ddp_subprocess() + if skip_heavy_processing or is_ddp_subprocess: + print("skipped 
rarity computation for DDP compatibility") + self._rarity_computed = False + else: + # Only compute if not in DDP subprocess or if cache doesn't exist + cache_path = self._get_rarity_cache_path() + if self._load_rarity_cache(cache_path): + print("loaded cached rarity features") + self._rarity_computed = True + else: + print("computing rarity features...") + self.data = self.get_frequency_based_rarity() + print("computed frequency based rarity") + self.data = self.get_clustering_based_rarity() + print("computed clustering based rarity") + self.data = self.get_combined_rarity() + print("computed combined rarity") + self._save_rarity_cache(cache_path) + self._rarity_computed = True @abstractmethod def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: @@ -161,7 +185,7 @@ def __getstate__(self): return state def get_train_dataloader( - self, batch_size: int, shuffle: bool = True, num_workers: int = 9 + self, batch_size: int, shuffle: bool = True, num_workers: int = 9, persistent_workers: bool = False ) -> DataLoader: """ Create a PyTorch DataLoader for training. @@ -175,7 +199,7 @@ def get_train_dataloader( DataLoader: Configured data loader. """ return DataLoader( - self, batch_size=batch_size, shuffle=shuffle, num_workers=9 + self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, persistent_workers=persistent_workers ) def split_timeseries(self, df: pd.DataFrame) -> pd.DataFrame: @@ -267,11 +291,33 @@ def _encode_context_vars( Returns: Tuple of encoded DataFrame and mapping codes. 
""" - return encode_context_variables( + # Debug: Check data before encoding + print("=== BEFORE ENCODING DEBUG ===") + for col in self.context_vars: + if col in data.columns: + unique_vals = data[col].nunique() + print(f"{col}: {unique_vals} unique values, dtype: {data[col].dtype}") + if pd.api.types.is_numeric_dtype(data[col]): + print(f" Range: {data[col].min()} to {data[col].max()}") + print("=============================") + + encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, bins=self.numeric_context_bins, + numeric_cols=getattr(self.cfg, 'numeric_cols', None), ) + + # Debug: Check data after encoding + print("=== AFTER ENCODING DEBUG ===") + for col in self.context_vars: + if col in encoded_data.columns: + unique_vals = encoded_data[col].nunique() + print(f"{col}: {unique_vals} unique values, dtype: {encoded_data[col].dtype}") + print(f" Range: {encoded_data[col].min()} to {encoded_data[col].max()}") + print("============================") + + return encoded_data, mapping def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: """ @@ -284,13 +330,16 @@ def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: dict: {var_name: num_categories} """ context_dict = {} + numeric_cols = getattr(self.cfg, 'numeric_cols', []) for var in self.context_vars: - if pd.api.types.is_numeric_dtype(data[var]): + if var in numeric_cols: + print(f"{var}: {self.numeric_context_bins}, config unique") binned = pd.cut( data[var], bins=self.numeric_context_bins, include_lowest=True ) context_dict[var] = binned.nunique() else: + print(f"{var}: {self.context_vars[var]}, config unique") context_dict[var] = data[var].astype("category").nunique() return context_dict @@ -417,6 +466,77 @@ def get_combined_rarity(self) -> pd.DataFrame: ) return self.data + def _ensure_rarity_computed(self): + """ + Compute rarity features lazily to avoid DDP issues. 
+ """ + if not self._rarity_computed: + cache_path = self._get_rarity_cache_path() + if self._load_rarity_cache(cache_path): + print("loaded cached rarity features") + self._rarity_computed = True + else: + print("computing rarity features (deferred for DDP compatibility)") + self.data = self.get_frequency_based_rarity() + print("computed frequency based rarity") + self.data = self.get_clustering_based_rarity() + print("computed clustering based rarity") + self.data = self.get_combined_rarity() + print("computed combined rarity") + self._save_rarity_cache(cache_path) + self._rarity_computed = True + + def _get_rarity_cache_path(self) -> str: + """Get cache file path for rarity features.""" + import hashlib + # Create a hash based on dataset characteristics for cache key + cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{hash(str(sorted(self.context_vars)))}" + cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] + cache_dir = os.path.join(ROOT_DIR, "cache", "rarity") + os.makedirs(cache_dir, exist_ok=True) + return os.path.join(cache_dir, f"rarity_{cache_hash}.pkl") + + def _save_rarity_cache(self, cache_path: str) -> None: + """Save rarity features to cache.""" + import pickle + rarity_data = { + 'is_frequency_rare': self.data.get('is_frequency_rare'), + 'is_pattern_rare': self.data.get('is_pattern_rare'), + 'is_rare': self.data.get('is_rare'), + 'cluster': self.data.get('cluster') + } + with open(cache_path, 'wb') as f: + pickle.dump(rarity_data, f) + print(f"Saved rarity cache to {cache_path}") + + def _load_rarity_cache(self, cache_path: str) -> bool: + """Load rarity features from cache.""" + import pickle + if not os.path.exists(cache_path): + return False + try: + with open(cache_path, 'rb') as f: + rarity_data = pickle.load(f) + for col, data in rarity_data.items(): + if data is not None: + self.data[col] = data + return True + except Exception as e: + print(f"Failed to load rarity cache: {e}") + return False + + def 
_is_ddp_subprocess(self) -> bool: + """Detect if we're running in a DDP subprocess.""" + import os + # Check for DDP-related environment variables + ddp_indicators = [ + 'LOCAL_RANK' in os.environ, + 'WORLD_SIZE' in os.environ, + 'RANK' in os.environ, + os.environ.get('CUDA_VISIBLE_DEVICES', '').count(',') > 0, # Multiple GPUs + ] + return any(ddp_indicators) + def _init_normalizer(self) -> None: """ Initialize or load a cached Normalizer for this dataset. diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 96a9183..60503ba 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -108,7 +108,7 @@ def split_dataset(dataset: Dataset, val_split: float = 0.1) -> Tuple[Dataset, Da def encode_context_variables( - data: pd.DataFrame, columns_to_encode: List[str], bins: int + data: pd.DataFrame, columns_to_encode: List[str], bins: int, numeric_cols: List[str] = None ) -> Tuple[pd.DataFrame, Dict[str, Dict[int, Any]]]: """ Encodes specified columns in the DataFrame either by binning numeric columns @@ -154,8 +154,13 @@ def encode_context_variables( ] for col in columns_to_encode: - if pd.api.types.is_numeric_dtype(encoded_data[col]): + if numeric_cols and col in numeric_cols: # Numeric column: Perform binning + # Handle NaN values by filling with median before binning + if encoded_data[col].isna().any(): + print(f" Warning: {col} has {encoded_data[col].isna().sum()} NaN values, filling with median") + encoded_data[col] = encoded_data[col].fillna(encoded_data[col].median()) + binned = pd.cut(encoded_data[col], bins=bins, include_lowest=True) encoded_data[col] = binned.cat.codes # Assign integer codes starting from 0 bin_intervals = binned.cat.categories diff --git a/cents/models/context.py b/cents/models/context.py index 5fc908a..e916833 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -64,6 +64,17 @@ def forward( classification_logits (Dict[str, Tensor]): Logits per variable, each of shape (batch_size, num_categories). 
""" + # Debug: Check actual values being passed + print("=== CONTEXT MODULE FORWARD DEBUG ===") + for name, tensor in context_vars.items(): + print(f"{name}: shape={tensor.shape}, min={tensor.min().item()}, max={tensor.max().item()}") + if name in self.context_embeddings: + embedding_size = self.context_embeddings[name].num_embeddings + print(f" Embedding layer size: {embedding_size}") + if tensor.max().item() >= embedding_size: + print(f" ERROR: Index {tensor.max().item()} >= embedding size {embedding_size}") + print("====================================") + embeddings = [ layer(context_vars[name]) for name, layer in self.context_embeddings.items() ] diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index ea9efb4..27fd68f 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -157,6 +157,14 @@ def __init__( self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale + # Debug: Check what context_vars are being used + print("=== NORMALIZER INIT DEBUG ===") + print(f"Context vars from config: {dataset_cfg.context_vars}") + print(f"Context vars keys: {list(dataset_cfg.context_vars.keys())}") + for name, num_categories in dataset_cfg.context_vars.items(): + print(f" {name}: {num_categories} categories") + print("=============================") + self.context_module = ContextModule( dataset_cfg.context_vars, 256, @@ -355,6 +363,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: Returns: DataFrame with normalized series in same columns. """ + print("beginning transform") missing = [c for c in self.time_series_cols if c not in df.columns] if missing: @@ -366,6 +375,8 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: f"{self.time_series_cols}." 
) + print("nothing missing") + df_out = df.copy() self.eval() with torch.no_grad(): diff --git a/cents/trainer.py b/cents/trainer.py index 76abe0f..59a3c8d 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -83,8 +83,10 @@ def fit(self) -> "Trainer": train_loader = self.dataset.get_train_dataloader( batch_size=self.cfg.trainer.batch_size, shuffle=True, - num_workers=4, + num_workers=6, # Maximum for 7.5GB/10GB GPU usage + persistent_workers=True, ) + print(f"got train loader with {train_loader.num_workers} workers") self.pl_trainer.fit(self.model, train_loader, None) return self @@ -201,7 +203,7 @@ def _instantiate_trainer(self) -> pl.Trainer: f"_dim{self.cfg.dataset.time_series_dims}" ), save_last=tc.checkpoint.save_last, - save_on_train_epoch_end=True, + save_on_train_epoch_end=True, ### Perhaps excessive ) ) callbacks.append(EvalAfterTraining(self.cfg, self.dataset)) diff --git a/scripts/train.py b/scripts/train.py index efbbf32..662cc7f 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,22 +1,35 @@ from datetime import datetime +import pandas as pd from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.trainer import Trainer +from pytorch_lightning.callbacks import EarlyStopping def main() -> None: - MODEL_NAME = "acgan" + MODEL_NAME = "diffusion_ts" CR_LOSS_WEIGHT = 0.1 TC_LOSS_WEIGHT = 0.1 - dataset = CommercialDataset(overrides=[]) + # Skip heavy processing for DDP compatibility + dataset = CommercialDataset( + skip_heavy_processing=True + ) trainer_overrides = [ "trainer.max_epochs=5000", - "trainer.strategy=ddp_find_unused_parameters_true", - "trainer.accelerator=cpu", + # "trainer.strategy=ddp_spawn", + # "trainer.devices=1,2,3", # Exclude GPU 0, use GPUs 1,2,3 + "trainer.devices=1", "trainer.eval_after_training=True", - "wandb.enabled=True", + # "train.accelerator=gpu", + "train.accelerator=cpu", + "trainer.early_stopping.patience=100", # Stop if no improvement for 100 
epochs + "trainer.early_stopping.monitor=train_loss", # Monitor training loss + "trainer.early_stopping.mode=min", # Stop when loss stops decreasing + "trainer.enable_checkpointing=True", # Explicitly enable checkpointing + "trainer.logger=False", # Disable logger to see checkpoint messages + "wandb.enabled=False", "wandb.project=cents", "wandb.entity=pmfeen-massachusetts-institute-of-technology", f"model.context_reconstruction_loss_weight={CR_LOSS_WEIGHT}", @@ -30,8 +43,15 @@ def main() -> None: overrides=trainer_overrides, ) + print("initialized trainer") + trainer.fit() + print("fit") if __name__ == "__main__": + import os + # Enable CUDA debugging for better error messages + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["TORCH_USE_CUDA_DSA"] = "1" main() From 5bdccffb01efa9808c4485c713ec593a243ea201 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 13 Oct 2025 13:46:33 -0400 Subject: [PATCH 05/50] Removed debugging print statements, moved back to GPU --- cents/config/trainer/acgan.yaml | 4 +-- cents/config/trainer/diffusion_ts.yaml | 4 +-- cents/config/trainer/normalizer.yaml | 2 +- cents/datasets/commercial.py | 28 --------------------- cents/datasets/timeseries_dataset.py | 35 -------------------------- cents/models/context.py | 13 +--------- cents/models/normalizer.py | 12 --------- cents/trainer.py | 1 - scripts/train.py | 12 +++------ 9 files changed, 10 insertions(+), 101 deletions(-) diff --git a/cents/config/trainer/acgan.yaml b/cents/config/trainer/acgan.yaml index 665551c..4795c3e 100644 --- a/cents/config/trainer/acgan.yaml +++ b/cents/config/trainer/acgan.yaml @@ -1,6 +1,6 @@ precision: "16-mixed" -accelerator: cpu -devices: cpu +accelerator: auto +devices: auto strategy: ddp_find_unused_parameters_true max_epochs: 1000 batch_size: 1024 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index f69e77b..7714e8f 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ 
b/cents/config/trainer/diffusion_ts.yaml @@ -1,6 +1,6 @@ precision: "16-mixed" -accelerator: cpu -devices: cpu +accelerator: auto +devices: auto strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 log_every_n_steps: 1 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 6b9fdaa..666ded6 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,5 +1,5 @@ strategy: auto -accelerator: cpu +accelerator: auto devices: 1 log_every_n_steps: 1 hidden_dim: 512 diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index 8c78e9d..f9d9e18 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -52,13 +52,8 @@ def __init__(self, cfg: DictConfig = None, scale=self.cfg.scale, skip_heavy_processing=skip_heavy_processing, ) - - # Force recomputation of context variables to match actual encoded data - print("=== FORCING CONTEXT VAR RECOMPUTATION ===") context_vars = self._get_context_var_dict(self.data) - print(f"Computed context_vars: {context_vars}") self.cfg.context_vars = context_vars - print("=========================================") def _load_data(self): """ @@ -93,18 +88,6 @@ def _load_data(self): self.data = data self.metadata = metadata - - # Debug: Check raw data before preprocessing - print("=== RAW DATA DEBUG ===") - print(f"Data shape: {data.shape}") - print(f"Metadata shape: {metadata.shape}") - print(f"Context vars in metadata: {[col for col in self.cfg.context_vars.keys() if col in metadata.columns]}") - for col in self.cfg.context_vars.keys(): - if col in metadata.columns: - unique_vals = metadata[col].nunique() - print(f"{col}: {unique_vals} unique values, dtype: {metadata[col].dtype}") - print("{col}: {self.cfg.context_vars[col]}, config unique") - print("======================") def _preprocess_data(self, data): @@ -151,17 +134,6 @@ def _preprocess_data(self, data): if nan_before != nan_after: print(f"Dropped 
{nan_before - nan_after} rows with NaN values in context variables") - # Debug: Check data after merging - print("=== AFTER MERGING DEBUG ===") - print(f"Merged shape: {merged.shape}") - for col in self.cfg.context_vars.keys(): - if col in merged.columns: - unique_vals = merged[col].nunique() - print(f"{col}: {unique_vals} unique values, dtype: {merged[col].dtype}") - if unique_vals < 20: # Only print if not too many - print(f" Values: {merged[col].unique()}") - print("===========================") - return merged def _handle_missing_data(self, merged): diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index e7a113c..69ce05a 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -95,16 +95,12 @@ def __init__( self.data, self.context_var_codes = self._encode_context_vars(self.data) self._save_context_var_codes() - print("normalizing") if self.normalize: self._init_normalizer() self.data = self._normalizer.transform(self.data) - print("finished normalizing") self.data = self.merge_timeseries_columns(self.data) - print("merged time series columns") self.data = self.data.reset_index() - print("reset index") # Check if we should skip heavy processing for DDP is_ddp_subprocess = self._is_ddp_subprocess() @@ -115,16 +111,11 @@ def __init__( # Only compute if not in DDP subprocess or if cache doesn't exist cache_path = self._get_rarity_cache_path() if self._load_rarity_cache(cache_path): - print("loaded cached rarity features") self._rarity_computed = True else: - print("computing rarity features...") self.data = self.get_frequency_based_rarity() - print("computed frequency based rarity") self.data = self.get_clustering_based_rarity() - print("computed clustering based rarity") self.data = self.get_combined_rarity() - print("computed combined rarity") self._save_rarity_cache(cache_path) self._rarity_computed = True @@ -291,16 +282,6 @@ def _encode_context_vars( Returns: Tuple of encoded 
DataFrame and mapping codes. """ - # Debug: Check data before encoding - print("=== BEFORE ENCODING DEBUG ===") - for col in self.context_vars: - if col in data.columns: - unique_vals = data[col].nunique() - print(f"{col}: {unique_vals} unique values, dtype: {data[col].dtype}") - if pd.api.types.is_numeric_dtype(data[col]): - print(f" Range: {data[col].min()} to {data[col].max()}") - print("=============================") - encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, @@ -308,15 +289,6 @@ def _encode_context_vars( numeric_cols=getattr(self.cfg, 'numeric_cols', None), ) - # Debug: Check data after encoding - print("=== AFTER ENCODING DEBUG ===") - for col in self.context_vars: - if col in encoded_data.columns: - unique_vals = encoded_data[col].nunique() - print(f"{col}: {unique_vals} unique values, dtype: {encoded_data[col].dtype}") - print(f" Range: {encoded_data[col].min()} to {encoded_data[col].max()}") - print("============================") - return encoded_data, mapping def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: @@ -333,13 +305,11 @@ def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: numeric_cols = getattr(self.cfg, 'numeric_cols', []) for var in self.context_vars: if var in numeric_cols: - print(f"{var}: {self.numeric_context_bins}, config unique") binned = pd.cut( data[var], bins=self.numeric_context_bins, include_lowest=True ) context_dict[var] = binned.nunique() else: - print(f"{var}: {self.context_vars[var]}, config unique") context_dict[var] = data[var].astype("category").nunique() return context_dict @@ -473,16 +443,11 @@ def _ensure_rarity_computed(self): if not self._rarity_computed: cache_path = self._get_rarity_cache_path() if self._load_rarity_cache(cache_path): - print("loaded cached rarity features") self._rarity_computed = True else: - print("computing rarity features (deferred for DDP compatibility)") self.data = 
self.get_frequency_based_rarity() - print("computed frequency based rarity") self.data = self.get_clustering_based_rarity() - print("computed clustering based rarity") self.data = self.get_combined_rarity() - print("computed combined rarity") self._save_rarity_cache(cache_path) self._rarity_computed = True diff --git a/cents/models/context.py b/cents/models/context.py index e916833..d96efe2 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -63,18 +63,7 @@ def forward( embedding (Tensor): Combined embedding of shape (batch_size, embedding_dim). classification_logits (Dict[str, Tensor]): Logits per variable, each of shape (batch_size, num_categories). - """ - # Debug: Check actual values being passed - print("=== CONTEXT MODULE FORWARD DEBUG ===") - for name, tensor in context_vars.items(): - print(f"{name}: shape={tensor.shape}, min={tensor.min().item()}, max={tensor.max().item()}") - if name in self.context_embeddings: - embedding_size = self.context_embeddings[name].num_embeddings - print(f" Embedding layer size: {embedding_size}") - if tensor.max().item() >= embedding_size: - print(f" ERROR: Index {tensor.max().item()} >= embedding size {embedding_size}") - print("====================================") - + """ embeddings = [ layer(context_vars[name]) for name, layer in self.context_embeddings.items() ] diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 27fd68f..ec0ed6a 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -156,14 +156,6 @@ def __init__( ] self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale - - # Debug: Check what context_vars are being used - print("=== NORMALIZER INIT DEBUG ===") - print(f"Context vars from config: {dataset_cfg.context_vars}") - print(f"Context vars keys: {list(dataset_cfg.context_vars.keys())}") - for name, num_categories in dataset_cfg.context_vars.items(): - print(f" {name}: {num_categories} categories") - 
print("=============================") self.context_module = ContextModule( dataset_cfg.context_vars, @@ -363,7 +355,6 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: Returns: DataFrame with normalized series in same columns. """ - print("beginning transform") missing = [c for c in self.time_series_cols if c not in df.columns] if missing: @@ -374,9 +365,6 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: "Normalizer.transform expects data in split format with columns " f"{self.time_series_cols}." ) - - print("nothing missing") - df_out = df.copy() self.eval() with torch.no_grad(): diff --git a/cents/trainer.py b/cents/trainer.py index 59a3c8d..17d2d0a 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -86,7 +86,6 @@ def fit(self) -> "Trainer": num_workers=6, # Maximum for 7.5GB/10GB GPU usage persistent_workers=True, ) - print(f"got train loader with {train_loader.num_workers} workers") self.pl_trainer.fit(self.model, train_loader, None) return self diff --git a/scripts/train.py b/scripts/train.py index 662cc7f..1ae8857 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -19,11 +19,11 @@ def main() -> None: trainer_overrides = [ "trainer.max_epochs=5000", # "trainer.strategy=ddp_spawn", - # "trainer.devices=1,2,3", # Exclude GPU 0, use GPUs 1,2,3 - "trainer.devices=1", + "trainer.devices=1,2,3", # Exclude GPU 0, use GPUs 1,2,3 + # "trainer.devices=1", "trainer.eval_after_training=True", - # "train.accelerator=gpu", - "train.accelerator=cpu", + "train.accelerator=gpu", + # "train.accelerator=cpu", "trainer.early_stopping.patience=100", # Stop if no improvement for 100 epochs "trainer.early_stopping.monitor=train_loss", # Monitor training loss "trainer.early_stopping.mode=min", # Stop when loss stops decreasing @@ -43,12 +43,8 @@ def main() -> None: overrides=trainer_overrides, ) - print("initialized trainer") - trainer.fit() - print("fit") - if __name__ == "__main__": import os # Enable CUDA debugging for better error messages From 
121fb56baa21b5e5917073be0b8120ec4d049dad Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 20 Oct 2025 12:52:28 -0400 Subject: [PATCH 06/50] Fixed merge error in CommercialDataset, Introduced logic for sampling during inference, minor training param changes --- cents/config/dataset/commercial.yaml | 13 +- cents/config/dataset/pecanstreet.yaml | 3 + cents/config/trainer/diffusion_ts.yaml | 4 +- cents/config/trainer/normalizer.yaml | 10 +- cents/datasets/commercial.py | 175 +++++++++++++++++++++---- cents/datasets/pecanstreet.py | 11 +- cents/datasets/timeseries_dataset.py | 15 ++- cents/datasets/utils.py | 11 +- 8 files changed, 195 insertions(+), 47 deletions(-) diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 52a558a..7cfca1b 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -7,21 +7,22 @@ threshold: 8 seq_len: 24 time_series_dims: 1 shuffle: True +skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) +max_samples: null # Limit dataset size (null = use all data) path: "./data/commercial/csv" -time_series_columns: ["energy_meter"] +time_series_columns: "energy_meter" data_columns: ["dataid","energy_meter","timestamp"] -metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sub_primaryspaceusage", "sqm", "sqft", "yearbuilt"] +metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearbuilt"] numeric_context_bins: 5 -numeric_cols: ["sqm", "sqft", "yearbuilt"] # Columns to bin as numeric - +numeric_cols: ["sqft", "yearbuilt"] # Columns to bin as numeric +reduce_cardinality: False context_vars: # for each desired context variable, add the name and number of categories year: 5 month: 12 weekday: 7 site_id: 19 primaryspaceusage: 17 - sub_primaryspaceusage: 105 - sqm: 5 + # sub_primaryspaceusage: 105 sqft: 5 yearbuilt: 5 diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 
ab513cf..a63acec 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -7,10 +7,13 @@ threshold: 8 seq_len: 96 time_series_dims: 1 shuffle: True +skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) +max_samples: null # Limit dataset size (null = use all data) path: "./data/pecanstreet/csv" time_series_columns: ["grid", "solar"] data_columns: ["dataid","local_15min","car1","grid","solar"] metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"] +numeric_cols: ["total_square_footage", "house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 7714e8f..184c781 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -5,7 +5,7 @@ strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 -max_epochs: 5000 +max_epochs: 200 base_lr: 1e-4 eval_after_training: False @@ -13,7 +13,7 @@ checkpoint: save_last: True # Save final model save_top_k: 3 # Save top 3 best models every_n_train_steps: null - every_n_epochs: 500 # Save every 500 epochs + every_n_epochs: 20 # Save every 500 epochs lr_scheduler_params: factor: 0.5 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 666ded6..1e27808 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,12 +1,14 @@ strategy: auto -accelerator: auto +accelerator: gpu devices: 1 +precision: 16-mixed log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 n_epochs: 2000 -batch_size: 4096 -lr: 3e-4 +batch_size: 8192 +lr: 1e-5 +gradient_clip_val: 1.0 save_cycle: 5000 eval_after_training: False @@ -14,4 +16,4 @@ checkpoint: save_last: False save_top_k: 0 every_n_train_steps: null - every_n_epochs: 
null + every_n_epochs: 500 diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index f9d9e18..08d2613 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -10,12 +10,13 @@ from cents.datasets.timeseries_dataset import TimeSeriesDataset warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +# These are warnings for an error that is accounted for in the code +warnings.filterwarnings("ignore", category=RuntimeWarning) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) class CommercialDataset(TimeSeriesDataset): def __init__(self, cfg: DictConfig = None, - overrides: Optional[List[str]] = None, - skip_heavy_processing: bool = False): + overrides: Optional[List[str]] = None): """ Initializes the commercial energy dataset. @@ -40,20 +41,27 @@ def __init__(self, cfg: DictConfig = None, self.geography = cfg.geography self._load_data() - self.time_series_column_names = cfg.time_series_columns + + # Apply max_samples limit if specified + # if hasattr(cfg, 'max_samples') and cfg.max_samples is not None: + # original_len = len(self.data) + # if original_len > cfg.max_samples: + # self.data = self.data.sample(n=cfg.max_samples, random_state=42).reset_index(drop=True) + # print(f"Limited dataset to {cfg.max_samples} samples (from {original_len} total)") + + self.time_series_column_names = self.cfg.time_series_columns self.time_series_dims = len(self.time_series_column_names) super().__init__( data=self.data, - time_series_column_names=self.cfg.time_series_columns, + time_series_column_names=self.time_series_column_names, seq_len=self.cfg.seq_len, context_var_column_names=list(self.cfg.context_vars.keys()), normalize=self.cfg.normalize, scale=self.cfg.scale, - skip_heavy_processing=skip_heavy_processing, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None) ) - context_vars = self._get_context_var_dict(self.data) - self.cfg.context_vars = 
context_vars def _load_data(self): """ @@ -87,8 +95,11 @@ def _load_data(self): metadata = metadata[metadata["site_id"] == self.geography] self.data = data + + print("Data head \n :" + str(self.data.head()) + "\n") self.metadata = metadata + print(f"Completed Loading Data: {metadata.sqft.notna().sum()} out of {metadata.sqft.size} not na") def _preprocess_data(self, data): """ @@ -103,40 +114,148 @@ def _preprocess_data(self, data): data = data.copy() data['datetime'] = pd.to_datetime(data['timestamp']) - data = data.dropna(subset=['energy_meter']) # any NaN makes the day shorter -> filtered by size later + # data = data.dropna(subset=['energy_meter']) # any NaN makes the day shorter -> filtered by size later data = data.sort_values(by=["dataid", "datetime"]) - # grouped = data.groupby(["dataid", "datetime", "year", "month", "date_day", "weekday"])["energy_meter"].apply(np.array).reset_index() - grouped = data.groupby(['dataid', pd.Grouper(key='datetime', freq='D')])['energy_meter'] - grouped = grouped.apply(np.asarray).reset_index(name='energy_meter') + # Extract date for grouping (groups all hourly measurements from same calendar day) + data['date'] = data['datetime'].dt.date + + grouped = data.groupby(['dataid', 'date'])['energy_meter'].apply(np.asarray).reset_index() + + # grouped = grouped[grouped["energy_meter"].apply(len) == self.cfg.seq_len].reset_index(drop=True) + grouped = grouped[grouped["energy_meter"].apply(lambda x: not np.isnan(x).any())] + grouped['date'] = pd.to_datetime(grouped['date']) - ## Just gonna remove any sequence with any missing values - grouped = grouped[grouped["energy_meter"].apply(len) == self.cfg.seq_len].reset_index(drop=True) - # grouped = grouped[grouped["energy_meter"].apply(lambda x: not np.isnan(x).any())] - grouped['year'] = grouped['datetime'].dt.year - grouped["month"] = grouped["datetime"].dt.month_name() - grouped["weekday"] = grouped["datetime"].dt.day_name() - grouped["date_day"] = grouped["datetime"].dt.day + 
grouped['year'] = grouped['date'].dt.year + grouped["month"] = grouped["date"].dt.month_name() + grouped["weekday"] = grouped["date"].dt.day_name() + grouped["date_day"] = grouped["date"].dt.day if grouped["energy_meter"].apply(lambda x: np.isnan(x).any()).any(): raise ValueError("NaN values remain in grouped energy_meter sequences after filtering.") - merged = pd.merge(grouped, self.metadata, how="left", left_on="dataid", right_on="building_id").drop(columns=["building_id"]) - merged.sort_values(by=["dataid", "datetime"], inplace=True) - # merged = self._handle_missing_data(merged) + merged.sort_values(by=["dataid", "date"], inplace=True) - # Drop rows with NaN values in context variables + merged = self._handle_missing_data(merged) + if hasattr(self.cfg, 'reduce_cardinality') and self.cfg.reduce_cardinality: + if 'sub_primaryspaceusage' in merged.columns: + print(f"Original sub_primaryspaceusage categories: {merged['sub_primaryspaceusage'].nunique()}") + merged = self._reduce_high_cardinality_features( + merged, + 'sub_primaryspaceusage', + min_samples=self.cfg.get('min_samples_per_category', 50), + max_categories=self.cfg.get('max_subcategories', 30) + ) + print(f"Reduced sub_primaryspaceusage categories: {merged['sub_primaryspaceusage'].nunique()}") + + # Check if any NaN remains context_cols = [col for col in self.cfg.context_vars.keys() if col in merged.columns] - nan_before = merged.shape[0] - merged = merged.dropna(subset=context_cols) - nan_after = merged.shape[0] - if nan_before != nan_after: - print(f"Dropped {nan_before - nan_after} rows with NaN values in context variables") + if merged[context_cols].isna().sum().sum() > 0: + print(f"Warning: {merged[context_cols].isna().sum().sum()} NaN values remain after handling") return merged def _handle_missing_data(self, merged): - raise NotImplementedError + """ + Fill NaNs using hierarchical and context-aware imputation. + + Strategy: + 1. 
Numeric: Group-based median (by site/building type), fallback to global median + 2. Categorical: Use hierarchical structure (e.g., sub_primaryspaceusage from primaryspaceusage) + 3. Last resort: Mode or 'unknown' category + + Args: + merged (pd.DataFrame): Merged sequence+metadata rows. + + Returns: + pd.DataFrame: Fully imputed DataFrame. + """ + df = merged.copy() + + # Handle numeric columns with group-based imputation + numeric_cols = self.cfg.get('numeric_cols', []) + for col in numeric_cols: + if col in df.columns and df[col].isna().any(): + # Try imputing based on similar buildings (same site_id and primaryspaceusage) + if 'site_id' in df.columns and 'primaryspaceusage' in df.columns: + for (site, usage), group in df.groupby(['site_id', 'primaryspaceusage']): + group_median = group[col].median() + if pd.notna(group_median): + mask = (df['site_id'] == site) & (df['primaryspaceusage'] == usage) & df[col].isna() + df.loc[mask, col] = group_median + + # Fallback to global median for remaining NaNs + if df[col].isna().any(): + global_median = df[col].median() + df[col] = df[col].fillna(global_median if pd.notna(global_median) else 0) + + # Handle hierarchical categorical: sub_primaryspaceusage from primaryspaceusage + if 'sub_primaryspaceusage' in df.columns and 'primaryspaceusage' in df.columns: + if df['sub_primaryspaceusage'].isna().any(): + # For each primaryspaceusage, find most common sub category + for primary in df['primaryspaceusage'].unique(): + mask = (df['primaryspaceusage'] == primary) + mode_sub = df.loc[mask, 'sub_primaryspaceusage'].mode() + if len(mode_sub) > 0: + df.loc[mask & df['sub_primaryspaceusage'].isna(), 'sub_primaryspaceusage'] = mode_sub[0] + + # Handle remaining categorical columns + categorical_cols = [col for col in self.cfg.context_vars.keys() + if col not in numeric_cols and col in df.columns] + for col in categorical_cols: + if col in df.columns and df[col].isna().any(): + mode_val = df[col].mode() + if len(mode_val) > 0: + 
df[col] = df[col].fillna(mode_val[0]) + else: + # Create 'unknown' category instead of dropping + df[col] = df[col].fillna('unknown') + + return df + + def _reduce_high_cardinality_features(self, df, col, min_samples=50, max_categories=30): + """ + Reduce high-cardinality categorical features by grouping rare categories. + + Strategy: Keep top N categories by frequency, group the rest as 'other_{parent}' + where parent is the primaryspaceusage. + + Args: + df (pd.DataFrame): Input DataFrame + col (str): Column name to reduce + min_samples (int): Minimum samples to keep category separate + max_categories (int): Maximum number of categories to keep + + Returns: + pd.DataFrame: DataFrame with reduced categories + """ + if col not in df.columns: + return df + + df = df.copy() + value_counts = df[col].value_counts() + + # Keep categories with enough samples + keep_categories = value_counts[value_counts >= min_samples].index.tolist() + + # If still too many, keep only top max_categories + if len(keep_categories) > max_categories: + keep_categories = value_counts.nlargest(max_categories).index.tolist() + + # For sub_primaryspaceusage, group by parent category + if col == 'sub_primaryspaceusage' and 'primaryspaceusage' in df.columns: + # Map rare subcategories to 'other_{primary}' + def map_category(row): + if row[col] in keep_categories: + return row[col] + else: + return f"other_{row['primaryspaceusage']}" + df[col] = df.apply(map_category, axis=1) + else: + # Generic grouping + df.loc[~df[col].isin(keep_categories), col] = 'other' + + return df \ No newline at end of file diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index c61dcab..2d195c8 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -33,7 +33,6 @@ def __init__( self, cfg: Optional[DictConfig] = None, overrides: Optional[List[str]] = None, - skip_heavy_processing: bool = False, ): """ Initialize and preprocess the PecanStreet dataset. 
@@ -75,6 +74,13 @@ def __init__( self._load_data() self._set_user_flags() + + # Apply max_samples limit if specified + if hasattr(cfg, 'max_samples') and cfg.max_samples is not None: + original_len = len(self.data) + if original_len > cfg.max_samples: + self.data = self.data.sample(n=cfg.max_samples, random_state=42).reset_index(drop=True) + print(f"Limited dataset to {cfg.max_samples} samples (from {original_len} total)") ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] @@ -85,7 +91,8 @@ def __init__( seq_len=self.cfg.seq_len, normalize=self.cfg.normalize, scale=self.cfg.scale, - skip_heavy_processing=skip_heavy_processing, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None) ) def _load_data(self) -> None: diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 69ce05a..9d5f15e 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -11,6 +11,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from sklearn.cluster import KMeans from torch.utils.data import DataLoader, Dataset +from omegaconf import ListConfig from cents.datasets.utils import encode_context_variables from cents.models.normalizer import Normalizer @@ -51,8 +52,14 @@ def __init__( scale: bool = True, overrides: Dict[str, Any] = {}, skip_heavy_processing: bool = False, + size: int = None, ): # Initialize basic attributes + # Handle OmegaConf ListConfig objects + if isinstance(time_series_column_names, ListConfig): + time_series_column_names = list(time_series_column_names) + if isinstance(context_var_column_names, ListConfig): + context_var_column_names = list(context_var_column_names) self.time_series_column_names = ( time_series_column_names if isinstance(time_series_column_names, list) @@ -91,6 +98,10 @@ def __init__( # Preprocess and optionally encode context self.data = self._preprocess_data(data) + + if size is not None: + self.data = 
self.data.sample(size) + print(f"Sampled {size} rows from dataset") if self.context_vars: self.data, self.context_var_codes = self._encode_context_vars(self.data) self._save_context_var_codes() @@ -108,6 +119,7 @@ def __init__( print("skipped rarity computation for DDP compatibility") self._rarity_computed = False else: + print("Computing rarity features...") # Only compute if not in DDP subprocess or if cache doesn't exist cache_path = self._get_rarity_cache_path() if self._load_rarity_cache(cache_path): @@ -118,7 +130,6 @@ def __init__( self.data = self.get_combined_rarity() self._save_rarity_cache(cache_path) self._rarity_computed = True - @abstractmethod def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: """ @@ -176,7 +187,7 @@ def __getstate__(self): return state def get_train_dataloader( - self, batch_size: int, shuffle: bool = True, num_workers: int = 9, persistent_workers: bool = False + self, batch_size: int, shuffle: bool = True, num_workers: int = 8, persistent_workers: bool = True ) -> DataLoader: """ Create a PyTorch DataLoader for training. 
diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 60503ba..682c4c3 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -157,11 +157,16 @@ def encode_context_variables( if numeric_cols and col in numeric_cols: # Numeric column: Perform binning # Handle NaN values by filling with median before binning - if encoded_data[col].isna().any(): + if encoded_data[col].isna().all(): + # Entire column is NaN - fill with 0 and create single bin + print(f" Warning: {col} is entirely NaN, filling with 0") + encoded_data[col] = 0 + elif encoded_data[col].isna().any(): print(f" Warning: {col} has {encoded_data[col].isna().sum()} NaN values, filling with median") - encoded_data[col] = encoded_data[col].fillna(encoded_data[col].median()) + median_val = encoded_data[col].median() + encoded_data[col] = encoded_data[col].fillna(median_val if pd.notna(median_val) else 0) - binned = pd.cut(encoded_data[col], bins=bins, include_lowest=True) + binned = pd.cut(encoded_data[col], bins=bins, include_lowest=True, duplicates='drop') encoded_data[col] = binned.cat.codes # Assign integer codes starting from 0 bin_intervals = binned.cat.categories # Create the mapping from integer code to bin interval From c18968daaf3fe16c9ce704ab2e4d7bf351bec2a5 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 27 Oct 2025 13:54:12 -0400 Subject: [PATCH 07/50] Eval pretrain fixes, normalization caching for subprocesses, EMA weight fixes --- cents/config/config.yaml | 2 - cents/config/dataset/commercial.yaml | 5 +- cents/config/evaluator/default.yaml | 8 ++- cents/config/model/diffusion_ts.yaml | 2 +- cents/config/trainer/diffusion_ts.yaml | 2 +- cents/data_generator.py | 21 ++++++- cents/datasets/commercial.py | 21 ------- cents/datasets/pecanstreet.py | 7 --- cents/datasets/timeseries_dataset.py | 59 +++++++++++++------ cents/eval/eval.py | 67 +++++++++++++--------- cents/models/diffusion_ts.py | 78 ++++++++++++++++++++++---- cents/models/normalizer.py | 8 ++- 
scripts/eval_pretrained.py | 56 ++++++++++++------ scripts/train.py | 11 ++-- 14 files changed, 229 insertions(+), 118 deletions(-) diff --git a/cents/config/config.yaml b/cents/config/config.yaml index e9a3532..f779f20 100644 --- a/cents/config/config.yaml +++ b/cents/config/config.yaml @@ -1,6 +1,4 @@ device: auto -job_name: ${model.name}_${dataset.name}_${dataset.user_group} -run_dir: outputs/${job_name}/${now:%Y-%m-%d_%H-%M-%S} model_ckpt: null wandb: diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 7cfca1b..034ede6 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -1,5 +1,6 @@ name: commercial geography: null +user_group: all normalize: True scale: True use_learned_normalizer: True @@ -17,11 +18,11 @@ numeric_context_bins: 5 numeric_cols: ["sqft", "yearbuilt"] # Columns to bin as numeric reduce_cardinality: False context_vars: # for each desired context variable, add the name and number of categories - year: 5 + year: 2 month: 12 weekday: 7 site_id: 19 - primaryspaceusage: 17 + primaryspaceusage: 16 # sub_primaryspaceusage: 105 sqft: 5 yearbuilt: 5 diff --git a/cents/config/evaluator/default.yaml b/cents/config/evaluator/default.yaml index 69ca339..2f3467f 100644 --- a/cents/config/evaluator/default.yaml +++ b/cents/config/evaluator/default.yaml @@ -1,7 +1,11 @@ -model_name: ${model.name} +model: + name: diffusion_ts # Set this to your model name +dataset: + name: commercial # Set this to your dataset name (e.g., "commercial") eval_pv_shift: False eval_metrics: True eval_context_sparse: True save_results: False eval_disentanglement: True -save_dir: ${run_dir}/eval +job_name: diffusion_ts_commercial +save_dir: outputs/diffusion_ts_commercial/eval diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 8434637..a40635d 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -24,4 +24,4 @@ 
reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 -use_ema_sampling: False +use_ema_sampling: True diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 184c781..551f1ce 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -1,7 +1,7 @@ precision: "16-mixed" accelerator: auto devices: auto -strategy: ddp_find_unused_parameters_true +strategy: ddp_find_unused_parameters_false gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 diff --git a/cents/data_generator.py b/cents/data_generator.py index e0d613b..b8042b5 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -46,9 +46,11 @@ class DataGenerator: def __init__( self, - model_name: str, + model_name: str = None, + model_type: str = None, device: str = None, cfg: DictConfig = None, + dataset = None, model: Optional[pl.LightningModule] = None, normalizer: Optional[Normalizer] = None, ): @@ -71,8 +73,22 @@ def __init__( self.model = None self.normalizer = None self.load_pretrained(model_name) + elif model_type is not None: + # Init without loading - user will call load_from_checkpoint separately + self.model_type = model_type + self.model = None + self.normalizer = None + if dataset is not None: + self.cfg = cfg or OmegaConf.create({}) + if not hasattr(self.cfg, 'dataset'): + self.cfg.dataset = dataset.cfg + if not hasattr(self.cfg, 'model'): + self.cfg.model = load_yaml(CONF_DIR / "model" / f"{model_type}.yaml") + self.set_dataset_spec( + self.cfg.dataset, self._read_ctx_codes(self.cfg.dataset.name) + ) else: - raise ValueError("Must provide either model_name or model instance.") + raise ValueError("Must provide either model_name, model_type, or model instance.") def _default_cfg(self) -> DictConfig: """ @@ -194,6 +210,7 @@ def load_from_checkpoint( ModelCls = get_model_cls(self.model_type) if ckpt_path.suffix == ".ckpt": + print(f"[Cents] Loading model from 
checkpoint: {ckpt_path}") self.model = ( ModelCls.load_from_checkpoint( cfg=self.cfg, diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index 08d2613..14a207b 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -42,13 +42,6 @@ def __init__(self, cfg: DictConfig = None, self._load_data() - # Apply max_samples limit if specified - # if hasattr(cfg, 'max_samples') and cfg.max_samples is not None: - # original_len = len(self.data) - # if original_len > cfg.max_samples: - # self.data = self.data.sample(n=cfg.max_samples, random_state=42).reset_index(drop=True) - # print(f"Limited dataset to {cfg.max_samples} samples (from {original_len} total)") - self.time_series_column_names = self.cfg.time_series_columns self.time_series_dims = len(self.time_series_column_names) @@ -95,12 +88,8 @@ def _load_data(self): metadata = metadata[metadata["site_id"] == self.geography] self.data = data - - print("Data head \n :" + str(self.data.head()) + "\n") self.metadata = metadata - print(f"Completed Loading Data: {metadata.sqft.notna().sum()} out of {metadata.sqft.size} not na") - def _preprocess_data(self, data): """ Creates sequences of seq_len and merges metadata. Removes any sequences with missing data. 
@@ -138,16 +127,6 @@ def _preprocess_data(self, data): merged.sort_values(by=["dataid", "date"], inplace=True) merged = self._handle_missing_data(merged) - if hasattr(self.cfg, 'reduce_cardinality') and self.cfg.reduce_cardinality: - if 'sub_primaryspaceusage' in merged.columns: - print(f"Original sub_primaryspaceusage categories: {merged['sub_primaryspaceusage'].nunique()}") - merged = self._reduce_high_cardinality_features( - merged, - 'sub_primaryspaceusage', - min_samples=self.cfg.get('min_samples_per_category', 50), - max_categories=self.cfg.get('max_subcategories', 30) - ) - print(f"Reduced sub_primaryspaceusage categories: {merged['sub_primaryspaceusage'].nunique()}") # Check if any NaN remains context_cols = [col for col in self.cfg.context_vars.keys() if col in merged.columns] diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index 2d195c8..54c4270 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -75,13 +75,6 @@ def __init__( self._load_data() self._set_user_flags() - # Apply max_samples limit if specified - if hasattr(cfg, 'max_samples') and cfg.max_samples is not None: - original_len = len(self.data) - if original_len > cfg.max_samples: - self.data = self.data.sample(n=cfg.max_samples, random_state=42).reset_index(drop=True) - print(f"Limited dataset to {cfg.max_samples} samples (from {original_len} total)") - ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] super().__init__( diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 9d5f15e..89dca05 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -80,8 +80,8 @@ def __init__( cfg = apply_overrides(cfg, dyn) cfg.time_series_columns = self.time_series_column_names self.numeric_context_bins = cfg.numeric_context_bins - context_vars = self._get_context_var_dict(data) - cfg.context_vars = context_vars + # context_vars = 
self._get_context_var_dict(data) + # cfg.context_vars = context_vars self.cfg = cfg self.numeric_context_bins = self.cfg.numeric_context_bins @@ -106,30 +106,46 @@ def __init__( self.data, self.context_var_codes = self._encode_context_vars(self.data) self._save_context_var_codes() + is_ddp_subprocess = self._is_ddp_subprocess() if self.normalize: self._init_normalizer() - self.data = self._normalizer.transform(self.data) - + cache_path = self._get_normalization_cache_path() + + if cache_path.exists() and is_ddp_subprocess: + print(f"[DDP Subprocess] Loading pre-normalized data from cache") + import pickle + with open(cache_path, 'rb') as f: + self.data = pickle.load(f) + else: + # Normalize (only main process or if cache doesn't exist) + if not is_ddp_subprocess: + print("[Main Process] Normalizing data...") + self.data = self._normalizer.transform(self.data) + + # Save to cache for subprocesses (only main process) + if not is_ddp_subprocess: + cache_path.parent.mkdir(parents=True, exist_ok=True) + import pickle + with open(cache_path, 'wb') as f: + pickle.dump(self.data, f) + print(f"[Main Process] Cached normalized data for subprocesses") self.data = self.merge_timeseries_columns(self.data) self.data = self.data.reset_index() # Check if we should skip heavy processing for DDP - is_ddp_subprocess = self._is_ddp_subprocess() - if skip_heavy_processing or is_ddp_subprocess: + if is_ddp_subprocess and skip_heavy_processing: print("skipped rarity computation for DDP compatibility") - self._rarity_computed = False - else: - print("Computing rarity features...") - # Only compute if not in DDP subprocess or if cache doesn't exist cache_path = self._get_rarity_cache_path() if self._load_rarity_cache(cache_path): self._rarity_computed = True - else: - self.data = self.get_frequency_based_rarity() - self.data = self.get_clustering_based_rarity() - self.data = self.get_combined_rarity() - self._save_rarity_cache(cache_path) - self._rarity_computed = True + else: + 
print("Computing rarity features...") + self.data = self.get_frequency_based_rarity() + self.data = self.get_clustering_based_rarity() + self.data = self.get_combined_rarity() + rarity_cache_path = self._get_rarity_cache_path() + self._save_rarity_cache(rarity_cache_path) + self._rarity_computed = True @abstractmethod def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: """ @@ -471,6 +487,17 @@ def _get_rarity_cache_path(self) -> str: cache_dir = os.path.join(ROOT_DIR, "cache", "rarity") os.makedirs(cache_dir, exist_ok=True) return os.path.join(cache_dir, f"rarity_{cache_hash}.pkl") + + def _get_normalization_cache_path(self): + """Get cache file path for normalized data.""" + import hashlib + from pathlib import Path + # Create hash based on dataset + normalizer characteristics + cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}" + cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] + cache_dir = Path(ROOT_DIR) / "cache" / "normalized_data" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / f"normalized_{cache_hash}.pkl" def _save_rarity_cache(self, cache_path: str) -> None: """Save rarity features to cache.""" diff --git a/cents/eval/eval.py b/cents/eval/eval.py index d3c1a72..c6f336d 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -106,11 +106,13 @@ def evaluate_model( if data_generator is not None: if data_generator.model is not None: model = data_generator.model + print(f"[Cents] Using pre-trained model from DataGenerator") if data_generator.normalizer is not None: dataset._normalizer = data_generator.normalizer - print("[CENTS] Using pre-trained normalizer from DataGenerator") - elif not model: - model = self.get_trained_model(dataset) + print("[Cents] Using pre-trained normalizer from DataGenerator") + else: + if not model: + model = self.get_trained_model(dataset) model.to(self.device) model.eval() @@ -224,29 +226,42 @@ def compute_quality_metrics( rare_syn_data = 
syn_data[mask] rare_real_df = real_data_frame[mask].reset_index(drop=True) - dtw_mean_r, dtw_std_r = dynamic_time_warping_dist( - rare_real_data, rare_syn_data - ) - rare_metrics["DTW"] = {"mean": dtw_mean_r, "std": dtw_std_r} - logger.info(f"[Cents] DTW completed") - - mmd_mean_r, mmd_std_r = calculate_mmd(rare_real_data, rare_syn_data) - rare_metrics["MMD"] = {"mean": mmd_mean_r, "std": mmd_std_r} - logger.info(f"[Cents] MMD completed") - - fid_score_r = Context_FID(rare_real_data, rare_syn_data) - rare_metrics["Context_FID"] = fid_score_r - logger.info(f"[Cents] Context-FID completed") - - discr_score_r, _, _ = discriminative_score_metrics( - rare_real_data, rare_syn_data - ) - rare_metrics["Disc_Score"] = discr_score_r - logger.info(f"[Cents] Discr Score completed") - - pred_score_r = predictive_score_metrics(rare_real_data, rare_syn_data) - rare_metrics["Pred_Score"] = pred_score_r - logger.info(f"[Cents] Pred Score completed") + # Check if rare subset has any valid samples + if len(rare_real_data) == 0: + logger.warning("[Cents] Rare subset is empty - skipping rare subset metrics") + rare_metrics = { + "DTW": {"mean": float('nan'), "std": float('nan')}, + "MMD": {"mean": float('nan'), "std": float('nan')}, + "Context_FID": float('nan'), + "Disc_Score": float('nan'), + "Pred_Score": float('nan') + } + else: + logger.info(f"[Cents] Rare subset contains {len(rare_real_data)} samples") + + dtw_mean_r, dtw_std_r = dynamic_time_warping_dist( + rare_real_data, rare_syn_data + ) + rare_metrics["DTW"] = {"mean": dtw_mean_r, "std": dtw_std_r} + logger.info(f"[Cents] DTW completed") + + mmd_mean_r, mmd_std_r = calculate_mmd(rare_real_data, rare_syn_data) + rare_metrics["MMD"] = {"mean": mmd_mean_r, "std": mmd_std_r} + logger.info(f"[Cents] MMD completed") + + fid_score_r = Context_FID(rare_real_data, rare_syn_data) + rare_metrics["Context_FID"] = fid_score_r + logger.info(f"[Cents] Context-FID completed") + + discr_score_r, _, _ = discriminative_score_metrics( + 
rare_real_data, rare_syn_data + ) + rare_metrics["Disc_Score"] = discr_score_r + logger.info(f"[Cents] Discr Score completed") + + pred_score_r = predictive_score_metrics(rare_real_data, rare_syn_data) + rare_metrics["Pred_Score"] = pred_score_r + logger.info(f"[Cents] Pred Score completed") logger.info("[Cents] Done computing Rare-Subset Metrics.") metrics["rare_subset"] = rare_metrics diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index fadf2e8..ddd7e35 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -75,7 +75,7 @@ def __init__(self, cfg: DictConfig): ) # EMA helper will be initialized on train start - self._ema_helper: Optional[EMA] = None + self._ema: Optional[EMA] = None # set up beta schedule if cfg.model.beta_schedule == "linear": @@ -300,8 +300,42 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: """ Apply EMA update after each batch end. """ - if self._ema_helper: - self._ema_helper.update() + if hasattr(self, '_ema') and self._ema: + self._ema.update() + + def on_load_checkpoint(self, checkpoint: dict) -> None: + """ + Restore EMA weights from checkpoint after loading. 
+ """ + super().on_load_checkpoint(checkpoint) + + # Check if EMA weights exist in checkpoint + state_dict = checkpoint.get('state_dict', {}) + ema_keys = [key for key in state_dict.keys() if key.startswith('_ema.')] + + if ema_keys: + if not hasattr(self, '_ema') or self._ema is None: + self._ema = EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + ) + + # Load EMA weights into the EMA helper + ema_state_dict = {} + for key, value in state_dict.items(): + if key.startswith('_ema.ema_model.'): + # Map '_ema.ema_model.*' -> 'ema_model.*' (remove the _ema prefix) + ema_key = key.replace('_ema.ema_model.', 'ema_model.') + ema_state_dict[ema_key] = value + + if ema_state_dict: + print(f"Loading {len(ema_state_dict)} EMA weights from checkpoint") + self._ema.ema_model.load_state_dict(ema_state_dict, strict=False) + else: + raise ValueError("No EMA model weights found in checkpoint") + else: + raise ValueError("No EMA keys found in checkpoint") @torch.no_grad() def model_predictions( @@ -424,12 +458,25 @@ def generate(self, context_vars: dict) -> torch.Tensor: shape = (current_bs, self.seq_len, self.time_series_dims) with torch.no_grad(): - if getattr(self.cfg.model, "use_ema_sampling", False) and hasattr( - self, "_ema_helper" - ): - samples = self._ema_helper.ema_model._generate( - shape, batch_context_vars - ) + if getattr(self.cfg.model, "use_ema_sampling", False): + self._ensure_ema_helper() + if hasattr(self, "_ema") and self._ema: + original_model = self.model + self.model = self._ema.ema_model + try: + if self.fast_sampling: + samples = self.fast_sample(shape, batch_context_vars) + else: + samples = self.sample(shape, batch_context_vars) + finally: + # Restore original model + self.model = original_model + else: + samples = ( + self.fast_sample(shape, batch_context_vars) + if self.fast_sampling + else self.sample(shape, batch_context_vars) + ) else: samples = ( self.fast_sample(shape, batch_context_vars) @@ 
-440,7 +487,18 @@ def generate(self, context_vars: dict) -> torch.Tensor: generated_samples.append(samples) return torch.cat(generated_samples, dim=0) - + + def _ensure_ema_helper(self) -> None: + """ + Ensure EMA helper is initialized if needed for inference. + """ + if not hasattr(self, '_ema') or self._ema is None: + print("Initializing EMA helper for inference...") + self._ema = EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + ) class EMA(nn.Module): """ diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index ec0ed6a..9cf3d64 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel @@ -234,7 +235,8 @@ def train_dataloader(self): ds, batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, - num_workers=0, + num_workers=23, + persistent_workers=True ) def _compute_group_stats(self) -> dict: @@ -368,7 +370,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() with torch.no_grad(): - for i, row in df_out.iterrows(): + for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Normalizing"): ctx = { v: torch.tensor(row[v], dtype=torch.long).unsqueeze(0) for v in self.context_vars @@ -410,7 +412,7 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() with torch.no_grad(): - for i, row in df_out.iterrows(): + for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): ctx = { v: torch.tensor(row[v], dtype=torch.long).unsqueeze(0) for v in self.context_vars diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 12606b6..e49354d 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,37 
+1,39 @@ import logging from datetime import datetime -import wandb +# import wandb from omegaconf import OmegaConf from cents.data_generator import DataGenerator from cents.datasets.pecanstreet import PecanStreetDataset +from cents.datasets.commercial import CommercialDataset from cents.eval.eval import Evaluator from cents.utils.config_loader import load_yaml +from pathlib import Path +import torch - -MODEL_KEY = "Watts_1_2D" +MODEL_KEY = "acgan" DATASET_OVERRIDES = [ - "user_group=pv_users", - "time_series_dims=2", + "max_samples=10000", + "skip_heavy_processing=True" +] + +PECAN_OVERRIDES = [ + "time_series_dims=1", + "user_group=all" ] +HOME = Path.home() def main() -> None: + model_ckpt = HOME / f"Cents/cents/outputs/{MODEL_KEY}_pecanstreet_all/2025-10-27_10-09-04/pecanstreet_acgan_dim1.ckpt" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) + print("Loading dataset...") + dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES) - if wandb.run is None: - wandb.init( - project="cents", - name=f"{MODEL_KEY}-eval-only-run_{datetime.now().strftime('%Y%m%d-%H%M%S')}", - entity="pmfeen-massachusetts-institute-of-technology", - ) - - # Dataset with simple overrides (no Hydra) - dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES) - + normalizer_ckpt = HOME / ".cache/cents/checkpoints/pecanstreet/normalizer/pecanstreet_normalizer_dim1.ckpt" # Build a minimal cfg for evaluator and generator eval_cfg = load_yaml("cents/config/evaluator/default.yaml") top_cfg = load_yaml("cents/config/config.yaml") @@ -39,15 +41,33 @@ def main() -> None: cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) cfg.device = top_cfg.get("device", "auto") - cfg.model = OmegaConf.create({"name": MODEL_KEY}) + cfg.model = OmegaConf.create(OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{MODEL_KEY}.yaml"), resolve=True)) cfg.dataset = OmegaConf.create( OmegaConf.to_container(dataset.cfg, resolve=True) 
) + # Enable EMA sampling to use the EMA weights from checkpoint + cfg.model.use_ema_sampling = True + cfg.eval_pv_shift = eval_cfg.get("eval_pv_shift", False) + cfg.eval_metrics = eval_cfg.get("eval_metrics", True) + cfg.eval_context_sparse = eval_cfg.get("eval_context_sparse", True) + cfg.save_results = eval_cfg.get("save_results", False) + cfg.eval_disentanglement = eval_cfg.get("eval_disentanglement", True) + cfg.job_name = eval_cfg.get("job_name", "default_job") + cfg.save_results = True + cfg.save_dir = HOME / f"Cents/cents/outputs/{MODEL_KEY}_pecanstreet_all/2025-10-27_10-09-04/eval" + print("Dataset spec set. Setting up DataGenerator...") # Use the fixed checkpoint with DataGenerator - gen = DataGenerator(MODEL_KEY) - gen.set_dataset_spec(cfg.dataset, dataset.get_context_var_codes()) + gen = DataGenerator(model_type = MODEL_KEY, dataset=dataset) + print("Loading checkpoint... EMA sampling enabled - will use EMA weights for generation") + gen.load_from_checkpoint(model_ckpt, normalizer_ckpt) + + gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) + cfg.dataset = gen.model.cfg.dataset + + print("Checkpoint loaded") + print("Evaluating model...") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) print(results) diff --git a/scripts/train.py b/scripts/train.py index 1ae8857..f44aca5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -12,14 +12,11 @@ def main() -> None: CR_LOSS_WEIGHT = 0.1 TC_LOSS_WEIGHT = 0.1 # Skip heavy processing for DDP compatibility - dataset = CommercialDataset( - skip_heavy_processing=True - ) - + dataset = CommercialDataset(overrides=["skip_heavy_processing=True"]) trainer_overrides = [ "trainer.max_epochs=5000", # "trainer.strategy=ddp_spawn", - "trainer.devices=1,2,3", # Exclude GPU 0, use GPUs 1,2,3 + "trainer.devices=auto", # Exclude GPU 0, use GPUs 1,2,3 # "trainer.devices=1", "trainer.eval_after_training=True", "train.accelerator=gpu", @@ -48,6 +45,6 @@ def main() -> None: if 
__name__ == "__main__": import os # Enable CUDA debugging for better error messages - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["TORCH_USE_CUDA_DSA"] = "1" + # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + # os.environ["TORCH_USE_CUDA_DSA"] = "1" main() From 4f63eb42bc626c880d9e0efa7729204e4fcd6fc8 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 27 Oct 2025 22:33:43 -0400 Subject: [PATCH 08/50] Modularized ContextModule, StatsHead for future work --- cents/config/dataset/commercial.yaml | 3 +++ cents/config/dataset/default.yaml | 1 + cents/config/dataset/pecanstreet.yaml | 2 ++ cents/config/model/acgan.yaml | 1 + cents/config/model/diffusion_ts.yaml | 3 ++- cents/models/acgan.py | 9 +++++---- cents/models/base.py | 9 +++++++-- cents/models/context.py | 12 +++++++++++- cents/models/normalizer.py | 26 +++++++++++++++++++------- scripts/train.py | 4 +++- 10 files changed, 54 insertions(+), 16 deletions(-) diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 034ede6..d3c5fb5 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -17,6 +17,9 @@ metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearb numeric_context_bins: 5 numeric_cols: ["sqft", "yearbuilt"] # Columns to bin as numeric reduce_cardinality: False +stats_head_type: mlp +context_module_type: mlp + context_vars: # for each desired context variable, add the name and number of categories year: 2 month: 12 diff --git a/cents/config/dataset/default.yaml b/cents/config/dataset/default.yaml index 13c28be..c91efd4 100644 --- a/cents/config/dataset/default.yaml +++ b/cents/config/dataset/default.yaml @@ -11,3 +11,4 @@ user_group: null numeric_context_bins: 5 context_vars: {} +stats_head_type: mlp diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index a63acec..7de3611 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ 
b/cents/config/dataset/pecanstreet.yaml @@ -16,6 +16,8 @@ metadata_columns: ["dataid","building_type","solar","car1","city","state","total numeric_cols: ["total_square_footage", "house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 +stats_head_type: mlp +context_module_type: mlp context_vars: # for each desired context variable, add the name and number of categories month: 12 diff --git a/cents/config/model/acgan.yaml b/cents/config/model/acgan.yaml index b72bd60..fd0bc98 100644 --- a/cents/config/model/acgan.yaml +++ b/cents/config/model/acgan.yaml @@ -2,6 +2,7 @@ _target_: generator.gan.acgan.ACGAN name: acgan noise_dim: 256 cond_emb_dim: 16 +context_module_type: mlp include_auxiliary_losses: True context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index a40635d..02ccf0f 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -4,6 +4,7 @@ context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 noise_dim: 256 cond_emb_dim: 16 +context_module_type: mlp n_layer_enc: 4 n_layer_dec: 5 d_model: 128 @@ -24,4 +25,4 @@ reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 -use_ema_sampling: True +use_ema_sampling: True \ No newline at end of file diff --git a/cents/models/acgan.py b/cents/models/acgan.py index 2f7f84d..9a3b9e7 100644 --- a/cents/models/acgan.py +++ b/cents/models/acgan.py @@ -18,7 +18,8 @@ from omegaconf import DictConfig from cents.models.base import GenerativeModel -from cents.models.context import ContextModule +from cents.models.context import MLPContextModule # Import to trigger registration +from cents.models.context_registry import get_context_module_cls from cents.models.model_utils import total_correlation from cents.models.registry import register_model @@ -43,7 +44,7 @@ def __init__( embedding_dim: int, final_window_length: int, 
time_series_dims: int, - context_module: ContextModule, + context_module_type: str, context_vars: Optional[dict] = None, base_channels: int = 256, ): @@ -56,7 +57,7 @@ def __init__( self.base_channels = base_channels self.context_vars = context_vars - self.context_module = context_module + self.context_module = get_context_module_cls(context_module_type)(context_vars, embedding_dim) in_dim = noise_dim + (embedding_dim if context_vars else 0) self.fc = nn.Linear(in_dim, self.final_window_length * base_channels) @@ -199,7 +200,7 @@ def __init__(self, cfg: DictConfig): embedding_dim=cfg.model.cond_emb_dim, final_window_length=cfg.dataset.seq_len, time_series_dims=cfg.dataset.time_series_dims, - context_module=self.context_module, + context_module_type=cfg.model.context_module_type, context_vars=cfg.dataset.context_vars, ) self.discriminator = Discriminator( diff --git a/cents/models/base.py b/cents/models/base.py index 29a6682..16dcb35 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -5,7 +5,8 @@ import torch from omegaconf import DictConfig -from cents.models.context import ContextModule +from cents.models.context import MLPContextModule # Import to trigger registration +from cents.models.context_registry import get_context_module_cls class BaseModel(pl.LightningModule, ABC): @@ -38,7 +39,11 @@ def __init__(self, cfg: DictConfig = None): if hasattr(cfg.dataset, "context_vars") and cfg.dataset.context_vars: emb_dim = getattr(cfg.model, "cond_emb_dim", 256) - self.context_module = ContextModule(cfg.dataset.context_vars, emb_dim) + context_module_type = getattr(cfg.model, "context_module_type", "default") + + # Use registry to get the context module class + ContextModuleCls = get_context_module_cls(context_module_type) + self.context_module = ContextModuleCls(cfg.dataset.context_vars, emb_dim) else: self.context_module = None diff --git a/cents/models/context.py b/cents/models/context.py index d96efe2..f5eb704 100644 --- a/cents/models/context.py +++ 
b/cents/models/context.py @@ -1,8 +1,18 @@ import torch import torch.nn as nn +from abc import abstractmethod +from .context_registry import register_context_module +class BaseContextModule(nn.Module): + """ + Base class for context modules. Subclasses must implement the forward method. + """ + @abstractmethod + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + pass -class ContextModule(nn.Module): +@register_context_module("default", "mlp") +class MLPContextModule(BaseContextModule): """ Integrates multiple context variables into a single embedding and provides auxiliary classification logits for each variable. diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 9cf3d64..11e5deb 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -10,11 +10,14 @@ from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel -from cents.models.context import ContextModule +from cents.models.context import MLPContextModule # Import to trigger registration +from cents.models.context_registry import get_context_module_cls +from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model -class _StatsHead(nn.Module): +@register_stats_head("default", "mlp") +class MLPStatsHead(nn.Module): """ Head module predicting summary statistics (mean, std, and optionally min/max z-scores) from context embedding. """ @@ -92,6 +95,7 @@ def __init__( hidden_dim: int = 512, time_series_dims: int = 2, do_scale: bool = True, + stats_head_type: str = "mlp", ): """ Args: @@ -99,11 +103,14 @@ def __init__( hidden_dim: Hidden dimension size for the stats head. time_series_dims: Number of time series dimensions. do_scale: Whether to include scaling predictions. + stats_head_type: Type of stats head to use (from registry). 
""" super().__init__() self.cond_module = cond_module self.embedding_dim = cond_module.embedding_dim - self.stats_head = _StatsHead( + # Use registry to get the stats head class + StatsHeadCls = get_stats_head_cls(stats_head_type) + self.stats_head = StatsHeadCls( embedding_dim=self.embedding_dim, hidden_dim=hidden_dim, time_series_dims=time_series_dims, @@ -157,17 +164,22 @@ def __init__( ] self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale + + context_module_type = getattr(self.dataset_cfg, "context_module_type", "default") - self.context_module = ContextModule( - dataset_cfg.context_vars, - 256, - ) + # Use registry to get the context module class + ContextModuleCls = get_context_module_cls(context_module_type) + self.context_module = ContextModuleCls(self.dataset_cfg.context_vars, 256) + # Get stats head type from config + stats_head_type = getattr(self.dataset_cfg, "stats_head_type", "default") + self.normalizer_model = _NormalizerModule( cond_module=self.context_module, hidden_dim=512, time_series_dims=self.time_series_dims, do_scale=self.do_scale, + stats_head_type=stats_head_type, ) # Will be populated in setup() diff --git a/scripts/train.py b/scripts/train.py index f44aca5..2218f36 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,6 +5,8 @@ from cents.datasets.commercial import CommercialDataset from cents.trainer import Trainer from pytorch_lightning.callbacks import EarlyStopping +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) def main() -> None: @@ -12,7 +14,7 @@ def main() -> None: CR_LOSS_WEIGHT = 0.1 TC_LOSS_WEIGHT = 0.1 # Skip heavy processing for DDP compatibility - dataset = CommercialDataset(overrides=["skip_heavy_processing=True"]) + dataset = PecanStreetDataset(overrides=["skip_heavy_processing=True", "time_series_dims=1", "user_group=all"]) trainer_overrides = [ "trainer.max_epochs=5000", # "trainer.strategy=ddp_spawn", From 05d99504b107a171293accfd68e4ca5923d01907 
Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 29 Oct 2025 14:06:31 -0400 Subject: [PATCH 09/50] Changed train script to use argparser --- scripts/train.py | 60 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 2218f36..0aaade6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -6,31 +6,37 @@ from cents.trainer import Trainer from pytorch_lightning.callbacks import EarlyStopping import warnings +import argparse warnings.simplefilter(action='ignore', category=FutureWarning) -def main() -> None: - MODEL_NAME = "diffusion_ts" - CR_LOSS_WEIGHT = 0.1 - TC_LOSS_WEIGHT = 0.1 +def main(args) -> None: + MODEL_NAME = args.model_name + CR_LOSS_WEIGHT = args.cr_loss_weight + TC_LOSS_WEIGHT = args.tc_loss_weight # Skip heavy processing for DDP compatibility - dataset = PecanStreetDataset(overrides=["skip_heavy_processing=True", "time_series_dims=1", "user_group=all"]) + + if args.dataset == "pecanstreet": + dataset = PecanStreetDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) + elif args.dataset == "commercial": + dataset = CommercialDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}", "time_series_dims=1", "user_group=all"]) + else: + raise ValueError(f"Dataset {args.dataset} not supported") + trainer_overrides = [ - "trainer.max_epochs=5000", - # "trainer.strategy=ddp_spawn", - "trainer.devices=auto", # Exclude GPU 0, use GPUs 1,2,3 - # "trainer.devices=1", - "trainer.eval_after_training=True", - "train.accelerator=gpu", - # "train.accelerator=cpu", + f"trainer.max_epochs={args.epochs}", + f"trainer.strategy={args.ddp_strategy}", + f"trainer.devices={args.devices}", + f"trainer.eval_after_training={args.eval_after_training}", + f"train.accelerator={args.accelerator}", "trainer.early_stopping.patience=100", # Stop if no improvement for 100 epochs "trainer.early_stopping.monitor=train_loss", # Monitor training 
loss "trainer.early_stopping.mode=min", # Stop when loss stops decreasing - "trainer.enable_checkpointing=True", # Explicitly enable checkpointing + f"trainer.enable_checkpointing={args.enable_checkpointing}", # Explicitly enable checkpointing "trainer.logger=False", # Disable logger to see checkpoint messages - "wandb.enabled=False", - "wandb.project=cents", - "wandb.entity=pmfeen-massachusetts-institute-of-technology", + f"wandb.enabled={args.wandb_enabled}", + f"wandb.project={args.wandb_project}", + f"wandb.entity={args.wandb_entity}", f"model.context_reconstruction_loss_weight={CR_LOSS_WEIGHT}", f"model.tc_loss_weight={TC_LOSS_WEIGHT}", f"wandb.name=training_dai_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_L{CR_LOSS_WEIGHT}_TC_{TC_LOSS_WEIGHT}_dim2", @@ -45,8 +51,22 @@ def main() -> None: trainer.fit() if __name__ == "__main__": - import os - # Enable CUDA debugging for better error messages - # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - # os.environ["TORCH_USE_CUDA_DSA"] = "1" + parser = argparse.ArgumentParser() + parser.add_argument("--devices", type=int, default="auto") + parser.add_argument("--accelerator", type=str, default="gpu") + parser.add_argument("--model_name", type=str, default="diffusion_ts") + parser.add_argument("--cr_loss_weight", type=float, default=0.1) + parser.add_argument("--tc_loss_weight", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="pecanstreet") + parser.add_argument("--epochs", type=int, default=5000) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--wandb-enabled", type=bool, default=False) + parser.add_argument("--wandb-project", type=str, default="cents") + parser.add_argument("--wandb-entity", type=str, default=None) + parser.add_argument("eval_after_training", type=bool, default=True) + parser.add_argument("skip_heavy_processing", type=bool, default=True) + parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false") 
+ parser.add_argument("enable_checkpointing", type=bool, default=True) + + args = parser.parse_args() main() From c9568e748b93fe24e84f14ed12e66c4712f85a37 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Thu, 30 Oct 2025 14:01:28 -0400 Subject: [PATCH 10/50] minor fixes --- scripts/train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 0aaade6..6530140 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -63,10 +63,10 @@ def main(args) -> None: parser.add_argument("--wandb-enabled", type=bool, default=False) parser.add_argument("--wandb-project", type=str, default="cents") parser.add_argument("--wandb-entity", type=str, default=None) - parser.add_argument("eval_after_training", type=bool, default=True) - parser.add_argument("skip_heavy_processing", type=bool, default=True) + parser.add_argument("--eval_after_training", type=bool, default=True) + parser.add_argument("--skip_heavy_processing", type=bool, default=True) parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false") - parser.add_argument("enable_checkpointing", type=bool, default=True) + parser.add_argument("--enable_checkpointing", type=bool, default=True) args = parser.parse_args() - main() + main(args) From 8dc315605a7f88570858847e0cb602487097976d Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 3 Nov 2025 16:16:24 -0500 Subject: [PATCH 11/50] Added new context embedder --- cents/config/dataset/pecanstreet.yaml | 2 +- cents/config/model/acgan.yaml | 2 +- cents/config/model/diffusion_ts.yaml | 2 +- cents/models/context.py | 57 +++++++++++++++++++++++++++ cents/models/normalizer.py | 2 +- 5 files changed, 61 insertions(+), 4 deletions(-) diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 7de3611..57041f4 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -17,7 +17,7 @@ numeric_cols: 
["total_square_footage", "house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 stats_head_type: mlp -context_module_type: mlp +context_module_type: sep_mlp context_vars: # for each desired context variable, add the name and number of categories month: 12 diff --git a/cents/config/model/acgan.yaml b/cents/config/model/acgan.yaml index fd0bc98..6adcc3c 100644 --- a/cents/config/model/acgan.yaml +++ b/cents/config/model/acgan.yaml @@ -2,7 +2,7 @@ _target_: generator.gan.acgan.ACGAN name: acgan noise_dim: 256 cond_emb_dim: 16 -context_module_type: mlp +context_module_type: sep_mlp include_auxiliary_losses: True context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 02ccf0f..2d5bea7 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -4,7 +4,7 @@ context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 noise_dim: 256 cond_emb_dim: 16 -context_module_type: mlp +context_module_type: sep_mlp n_layer_enc: 4 n_layer_dec: 5 d_model: 128 diff --git a/cents/models/context.py b/cents/models/context.py index f5eb704..25b811e 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -87,3 +87,60 @@ def forward( } return embedding, classification_logits +@register_context_module("default", "sep_mlp") +class SepMLPContextModule(BaseContextModule): + def __init__(self, context_vars: dict[str, int], embedding_dim: int, init_depth: int = 1, mixing_depth: int = 1) -> None: + super().__init__() + + self.embedding_dim = embedding_dim + + self.context_embeddings = nn.ModuleDict( + { + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in context_vars.items() + } + ) + + self.init_mlps = nn.ModuleDict({ + name: nn.Sequential(*[ + layer + for _ in range(init_depth) + for layer in (nn.Linear(embedding_dim, 128), nn.ReLU(), nn.Linear(128, embedding_dim)) + ]) + for name in 
context_vars.keys() + }) + + + total_dim = embedding_dim * len(context_vars) + + self.mixing_mlp = nn.Sequential( + nn.Linear(total_dim, 128), + nn.ReLU(), + nn.Linear(128, embedding_dim)) + + self.classification_heads = nn.ModuleDict( + { + var_name: nn.Linear(embedding_dim, num_categories) + for var_name, num_categories in context_vars.items() + } + ) + + + def forward(self, context_vars): + encodings = { + name : layer(context_vars[name]) for name, layer in self.context_embeddings.items() + } + + embeddings = [ + layer(encodings[name]) for name, layer in self.init_mlps.items() + ] + + context_matrix = torch.cat(embeddings, dim=1) + embedding = self.mixing_mlp(context_matrix) + + classification_logits = { + var_name: head(embedding) + for var_name, head in self.classification_heads.items() + } + + return embedding, classification_logits \ No newline at end of file diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 11e5deb..765a861 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -10,7 +10,7 @@ from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel -from cents.models.context import MLPContextModule # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration from cents.models.context_registry import get_context_module_cls from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model From 9eb35cab1010e3b1618691756759137b1a87f6a6 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 5 Nov 2025 10:27:24 -0500 Subject: [PATCH 12/50] Added registries that were untracked --- cents/models/context_registry.py | 55 +++++++++++++++++++++++++++++ cents/models/stats_head_registry.py | 55 +++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 cents/models/context_registry.py create mode 100644 
cents/models/stats_head_registry.py diff --git a/cents/models/context_registry.py b/cents/models/context_registry.py new file mode 100644 index 0000000..8222e92 --- /dev/null +++ b/cents/models/context_registry.py @@ -0,0 +1,55 @@ +from typing import Dict + +_CONTEXT_MODULE_REGISTRY = {} + + +def register_context_module(*names): + """ + Decorator: registers a context module class under one or more names. + + Args: + *names: One or more names to register the class under. + + Example: + @register_context_module("default", "mlp") + class MLPContextModule(BaseContextModule): + pass + """ + + def decorator(cls): + for name in names: + _CONTEXT_MODULE_REGISTRY[name] = cls + return cls + + return decorator + + +def get_context_module_cls(key: str) -> type: + """ + Fetch the context module class for `key`. Raises if not found. + + Args: + key: The name of the context module to retrieve. + + Returns: + The context module class. + + Raises: + ValueError: If the key is not found in the registry. + """ + try: + return _CONTEXT_MODULE_REGISTRY[key] + except KeyError: + raise ValueError( + f"Unknown context module '{key}'. Available: {list(_CONTEXT_MODULE_REGISTRY.keys())}" + ) + + +def get_available_context_modules() -> list[str]: + """ + Get a list of all available context module names. + + Returns: + List of available context module names. + """ + return list(_CONTEXT_MODULE_REGISTRY.keys()) diff --git a/cents/models/stats_head_registry.py b/cents/models/stats_head_registry.py new file mode 100644 index 0000000..19cbf02 --- /dev/null +++ b/cents/models/stats_head_registry.py @@ -0,0 +1,55 @@ +from typing import Dict + +_STATS_HEAD_REGISTRY = {} + + +def register_stats_head(*names): + """ + Decorator: registers a stats head class under one or more names. + + Args: + *names: One or more names to register the class under. 
+ + Example: + @register_stats_head("default", "mlp") + class MLPStatsHead(nn.Module): + pass + """ + + def decorator(cls): + for name in names: + _STATS_HEAD_REGISTRY[name] = cls + return cls + + return decorator + + +def get_stats_head_cls(key: str) -> type: + """ + Fetch the stats head class for `key`. Raises if not found. + + Args: + key: The name of the stats head to retrieve. + + Returns: + The stats head class. + + Raises: + ValueError: If the key is not found in the registry. + """ + try: + return _STATS_HEAD_REGISTRY[key] + except KeyError: + raise ValueError( + f"Unknown stats head '{key}'. Available: {list(_STATS_HEAD_REGISTRY.keys())}" + ) + + +def get_available_stats_heads() -> list[str]: + """ + Get a list of all available stats head names. + + Returns: + List of available stats head names. + """ + return list(_STATS_HEAD_REGISTRY.keys()) From 7ffc6dfdd611c3c6444c8eedb7a1fcea6ce9ac22 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 12 Nov 2025 21:24:13 -0500 Subject: [PATCH 13/50] stable normalized training --- cents/config/dataset/commercial.yaml | 10 +- cents/config/dataset/default.yaml | 3 +- cents/config/dataset/pecanstreet.yaml | 6 +- cents/config/trainer/normalizer.yaml | 2 +- cents/data_generator.py | 35 ++-- cents/datasets/timeseries_dataset.py | 108 ++++++++-- cents/datasets/utils.py | 10 +- cents/models/acgan.py | 29 ++- cents/models/base.py | 10 +- cents/models/context.py | 144 +++++++++++-- cents/models/diffusion_ts.py | 94 ++++++++- cents/models/normalizer.py | 284 +++++++++++++++++++++++--- cents/trainer.py | 25 ++- cents/utils/utils.py | 34 ++- scripts/eval_pretrained.py | 13 +- scripts/train.py | 10 +- 16 files changed, 712 insertions(+), 105 deletions(-) diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index d3c5fb5..86a8722 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -18,7 +18,7 @@ numeric_context_bins: 5 numeric_cols: ["sqft", 
"yearbuilt"] # Columns to bin as numeric reduce_cardinality: False stats_head_type: mlp -context_module_type: mlp +context_module_type: sep_mlp context_vars: # for each desired context variable, add the name and number of categories year: 2 @@ -26,7 +26,11 @@ context_vars: # for each desired context variable, add the name and number of ca weekday: 7 site_id: 19 primaryspaceusage: 16 + # sqft: 5 + # yearbuilt: 5 # sub_primaryspaceusage: 105 - sqft: 5 - yearbuilt: 5 + +continuous_context_vars: # for each desired continuous context variable, add the name and number of bins +- sqft +- yearbuilt diff --git a/cents/config/dataset/default.yaml b/cents/config/dataset/default.yaml index c91efd4..3317216 100644 --- a/cents/config/dataset/default.yaml +++ b/cents/config/dataset/default.yaml @@ -10,5 +10,6 @@ seq_len: 8 user_group: null numeric_context_bins: 5 -context_vars: {} +context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous) +continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned) stats_head_type: mlp diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 57041f4..d71666d 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -27,5 +27,7 @@ context_vars: # for each desired context variable, add the name and number of ca car1: 2 city: 7 state: 3 - total_square_footage: 5 - house_construction_year: 5 + +continuous_context_vars: +- total_square_footage +- house_construction_year diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 1e27808..2250311 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -5,7 +5,7 @@ precision: 16-mixed log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 2000 +n_epochs: 1000 batch_size: 8192 lr: 1e-5 gradient_clip_val: 1.0 diff --git 
a/cents/data_generator.py b/cents/data_generator.py index b8042b5..31242a8 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -121,13 +121,15 @@ def set_dataset_spec( self.cfg.dataset = dataset_cfg self.ctx_code_book = ctx_codes - def set_context(self, auto_fill_missing: bool = False, **context_vars: int): + def set_context(self, auto_fill_missing: bool = False, **context_vars: Union[int, float]): """ Define a context vector for subsequent generation calls. Args: auto_fill_missing: If True, randomly sample missing context variables. **context_vars: Named codes for each context variable. + For categorical variables: integer codes (int). + For continuous variables: float values (float). Raises: RuntimeError: If dataset spec has not been set. @@ -139,25 +141,34 @@ def set_context(self, auto_fill_missing: bool = False, **context_vars: int): ) required = self.cfg.dataset.context_vars + continuous_vars = getattr(self.cfg.dataset, "continuous_context_vars", None) or [] if auto_fill_missing: for var, n in required.items(): - context_vars.setdefault(var, random.randrange(n)) + if var in continuous_vars: + # For continuous variables, sample from a reasonable range + # This is a simple default - users should provide actual values + context_vars[var] = random.uniform(0.0, 1.0) + else: + context_vars.setdefault(var, random.randrange(n)) else: missing = set(required) - set(context_vars) if missing: raise ValueError(f"Missing context vars: {missing}") + self._ctx_buff = {} for var, code in context_vars.items(): - max_cat = self.cfg.dataset.context_vars[var] - if not (0 <= code < max_cat): - raise ValueError( - f"Context '{var}' must be in [0, {max_cat}); got {code}." 
- ) - - self._ctx_buff = { - var: torch.tensor(code, device=self.device) - for var, code in context_vars.items() - } + if var in continuous_vars: + # Continuous variables: use float tensor (no validation needed) + self._ctx_buff[var] = torch.tensor(code, dtype=torch.float32, device=self.device) + else: + # Categorical variables: validate and use long tensor + if var in required: + max_cat = required[var] + if not (0 <= code < max_cat): + raise ValueError( + f"Context '{var}' must be in [0, {max_cat}); got {code}." + ) + self._ctx_buff[var] = torch.tensor(code, dtype=torch.long, device=self.device) @torch.no_grad() def generate(self, n: int = 128) -> "pd.DataFrame": diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 89dca05..fca4fee 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -12,6 +12,7 @@ from sklearn.cluster import KMeans from torch.utils.data import DataLoader, Dataset from omegaconf import ListConfig +import pickle from cents.datasets.utils import encode_context_variables from cents.models.normalizer import Normalizer @@ -90,6 +91,23 @@ def __init__( if not hasattr(self, "name"): self.name = "custom" + # Add continuous variables to context_vars if specified + continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] + # Convert to plain Python list if it's a ListConfig from OmegaConf + if continuous_vars: + if isinstance(continuous_vars, ListConfig): + continuous_vars = [str(v) for v in continuous_vars] + elif isinstance(continuous_vars, list): + continuous_vars = [str(v) for v in continuous_vars] + else: + continuous_vars = [str(continuous_vars)] + else: + continuous_vars = [] + + # Ensure continuous variables are included in self.context_vars + if continuous_vars: + self.context_vars = list(self.context_vars) + [v for v in continuous_vars if v not in self.context_vars] + self.normalize = normalize self.scale = scale @@ -99,6 +117,10 @@ def 
__init__( # Preprocess and optionally encode context self.data = self._preprocess_data(data) + continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] + if continuous_vars: + self._normalize_continuous_vars() + if size is not None: self.data = self.data.sample(size) print(f"Sampled {size} rows from dataset") @@ -111,9 +133,8 @@ def __init__( self._init_normalizer() cache_path = self._get_normalization_cache_path() - if cache_path.exists() and is_ddp_subprocess: - print(f"[DDP Subprocess] Loading pre-normalized data from cache") - import pickle + if cache_path.exists(): + print(f"[{'DDP Subprocess' if is_ddp_subprocess else 'Main Process'}] Loading pre-normalized data from cache") with open(cache_path, 'rb') as f: self.data = pickle.load(f) else: @@ -125,7 +146,6 @@ def __init__( # Save to cache for subprocesses (only main process) if not is_ddp_subprocess: cache_path.parent.mkdir(parents=True, exist_ok=True) - import pickle with open(cache_path, 'wb') as f: pickle.dump(self.data, f) print(f"[Main Process] Cached normalized data for subprocesses") @@ -180,13 +200,27 @@ def __getitem__(self, idx: int): Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - timeseries: Tensor of shape (seq_len, dims). - context_vars: Dict of context variable tensors. + Categorical variables are long tensors, continuous variables are float tensors. 
""" sample = self.data.iloc[idx] timeseries = torch.tensor(sample["timeseries"], dtype=torch.float32) - context_vars_dict = { - var: torch.tensor(sample[var], dtype=torch.long) - for var in self.context_vars - } + + continuous_vars = getattr(self.cfg, 'continuous_context_vars', None) or [] + context_vars_dict = {} + for var in self.context_vars: + if var in continuous_vars: + # Continuous variables: keep as float + val = sample[var] + # Check for NaN/Inf in the data itself + if isinstance(val, (float, int)) and (not isinstance(val, bool) and (np.isnan(val) or np.isinf(val))): + raise ValueError( + f"NaN/Inf detected in continuous variable '{var}' in dataset at index {idx}. " + f"Value: {val}. This should not happen if normalization was done correctly." + ) + context_vars_dict[var] = torch.tensor(val, dtype=torch.float32) + else: + # Categorical variables: use long + context_vars_dict[var] = torch.tensor(sample[var], dtype=torch.long) return timeseries, context_vars_dict def __getstate__(self): @@ -212,6 +246,7 @@ def get_train_dataloader( batch_size (int): Batch size. shuffle (bool): Whether to shuffle the data. num_workers (int): Number of worker processes. + persistent_workers (bool): Whether to keep workers alive between epochs. Returns: DataLoader: Configured data loader. @@ -302,6 +337,7 @@ def _encode_context_vars( ) -> Tuple[pd.DataFrame, Dict[str, Any]]: """ Encode and bin numeric or categorical context variables. + Continuous variables are kept as-is. Args: data (pd.DataFrame): Input DataFrame. @@ -309,15 +345,51 @@ def _encode_context_vars( Returns: Tuple of encoded DataFrame and mapping codes. 
""" + continuous_vars = getattr(self.cfg, 'continuous_context_vars', None) encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, bins=self.numeric_context_bins, numeric_cols=getattr(self.cfg, 'numeric_cols', None), + continuous_vars=continuous_vars, ) return encoded_data, mapping + def _normalize_continuous_vars(self): + """ + Normalize continuous context variables in the dataset using z-score normalization. + This is done once during dataset initialization, so models receive pre-normalized values. + """ + continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] + if not continuous_vars: + return + + # Store stats for potential inverse transform if needed + self.continuous_var_stats = {} + + for var_name in continuous_vars: + if var_name in self.data.columns: + values = self.data[var_name] + # Compute mean and std + mean_val = float(values.mean()) + std_val = float(values.std()) + + # Avoid zero std + if std_val < 1e-8: + std_val = 1.0 + print(f"[Dataset] Warning: {var_name} has zero std, using std=1.0") + + # Store stats for reference + self.continuous_var_stats[var_name] = {'mean': mean_val, 'std': std_val} + + # Normalize the values in-place: (x - mean) / std + self.data[var_name] = (values - mean_val) / std_val + + print(f"[Dataset] Normalized {var_name}: mean={mean_val:.4f}, std={std_val:.4f}") + else: + print(f"[Dataset] Warning: Continuous variable {var_name} not found in data columns") + def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: """ Infer number of categories for each context variable. 
@@ -482,7 +554,8 @@ def _get_rarity_cache_path(self) -> str: """Get cache file path for rarity features.""" import hashlib # Create a hash based on dataset characteristics for cache key - cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{hash(str(sorted(self.context_vars)))}" + context_module_type = getattr(self.cfg, "context_module_type", None) + cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{str(sorted(self.context_vars))}_{context_module_type or ''}" cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] cache_dir = os.path.join(ROOT_DIR, "cache", "rarity") os.makedirs(cache_dir, exist_ok=True) @@ -493,7 +566,9 @@ def _get_normalization_cache_path(self): import hashlib from pathlib import Path # Create hash based on dataset + normalizer characteristics - cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}" + context_module_type = getattr(self.cfg, "context_module_type", None) + stats_head_type = getattr(self.cfg, "stats_head_type", None) + cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}_{context_module_type or ''}_{stats_head_type or ''}" cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] cache_dir = Path(ROOT_DIR) / "cache" / "normalized_data" cache_dir.mkdir(parents=True, exist_ok=True) @@ -551,8 +626,17 @@ def _init_normalizer(self) -> None: Path.home() / ".cache" / "cents" / "checkpoints" / self.name / "normalizer" ) normalizer_dir.mkdir(parents=True, exist_ok=True) + + # Get context_module_type and stats_head_type from config + context_module_type = getattr(self.cfg, "context_module_type", None) + stats_head_type = getattr(self.cfg, "stats_head_type", None) + cache_path = normalizer_dir / _ckpt_name( - self.name, "normalizer", self.time_series_dims + self.name, + "normalizer", + self.time_series_dims, + context_module_type=context_module_type, + stats_head_type=stats_head_type ) ncfg = get_normalizer_training_config() @@ -585,7 +669,7 @@ def 
_init_normalizer(self) -> None: devices=ncfg.devices, strategy=ncfg.strategy, log_every_n_steps=ncfg.log_every_n_steps, - logger=False, + logger=True, ) trainer.fit(self._normalizer) torch.save(self._normalizer.state_dict(), cache_path) diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 682c4c3..244a52e 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -108,17 +108,20 @@ def split_dataset(dataset: Dataset, val_split: float = 0.1) -> Tuple[Dataset, Da def encode_context_variables( - data: pd.DataFrame, columns_to_encode: List[str], bins: int, numeric_cols: List[str] = None + data: pd.DataFrame, columns_to_encode: List[str], bins: int, numeric_cols: List[str] = None, continuous_vars: List[str] = None ) -> Tuple[pd.DataFrame, Dict[str, Dict[int, Any]]]: """ Encodes specified columns in the DataFrame either by binning numeric columns or by converting categorical/string columns to integer codes. For 'weekday' and 'month' columns, encoding follows chronological order. + Continuous variables are skipped and kept as-is. Args: data (pd.DataFrame): The input DataFrame containing the data. columns_to_encode (List[str]): List of column names to encode. bins (int): Number of bins for numeric columns. + numeric_cols (List[str], optional): Columns to treat as numeric for binning. + continuous_vars (List[str], optional): Columns to keep as continuous (skip encoding). 
Returns: Tuple[pd.DataFrame, Dict[str, Dict[int, Any]]]: @@ -127,6 +130,7 @@ def encode_context_variables( """ encoded_data = data.copy() mapping: Dict[str, Dict[int, Any]] = {} + continuous_vars = continuous_vars or [] # Define the chronological order for weekdays and months weekdays_order = [ @@ -154,6 +158,10 @@ def encode_context_variables( ] for col in columns_to_encode: + # Skip continuous variables - they should remain as float values + if col in continuous_vars: + continue + if numeric_cols and col in numeric_cols: # Numeric column: Perform binning # Handle NaN values by filling with median before binning diff --git a/cents/models/acgan.py b/cents/models/acgan.py index 9a3b9e7..c71963b 100644 --- a/cents/models/acgan.py +++ b/cents/models/acgan.py @@ -14,11 +14,12 @@ import pytorch_lightning as pl import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from omegaconf import DictConfig from cents.models.base import GenerativeModel -from cents.models.context import MLPContextModule # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration from cents.models.context_registry import get_context_module_cls from cents.models.model_utils import total_correlation from cents.models.registry import register_model @@ -47,6 +48,7 @@ def __init__( context_module_type: str, context_vars: Optional[dict] = None, base_channels: int = 256, + continuous_vars: Optional[list] = None, ): super().__init__() self.noise_dim = noise_dim @@ -57,7 +59,7 @@ def __init__( self.base_channels = base_channels self.context_vars = context_vars - self.context_module = get_context_module_cls(context_module_type)(context_vars, embedding_dim) + self.context_module = get_context_module_cls(context_module_type)(context_vars, embedding_dim, continuous_vars=continuous_vars) in_dim = noise_dim + (embedding_dim if context_vars else 0) self.fc = nn.Linear(in_dim, 
self.final_window_length * base_channels) @@ -195,6 +197,7 @@ def __init__(self, cfg: DictConfig): # self.context_module = ContextModule( # cfg.dataset.context_vars, cfg.model.cond_emb_dim # ) + continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] self.generator = Generator( noise_dim=cfg.model.noise_dim, embedding_dim=cfg.model.cond_emb_dim, @@ -202,14 +205,24 @@ def __init__(self, cfg: DictConfig): time_series_dims=cfg.dataset.time_series_dims, context_module_type=cfg.model.context_module_type, context_vars=cfg.dataset.context_vars, + continuous_vars=continuous_vars, ) + # Filter out continuous variables from context_vars for discriminator (only categorical needed) + categorical_context_vars = {k: v for k, v in cfg.dataset.context_vars.items() if k not in continuous_vars} self.discriminator = Discriminator( window_length=cfg.dataset.seq_len, time_series_dims=cfg.dataset.time_series_dims, - context_var_n_categories=cfg.dataset.context_vars, + context_var_n_categories=categorical_context_vars, ) self.adv_loss = nn.BCEWithLogitsLoss() self.aux_loss = nn.CrossEntropyLoss() + + # Get continuous variables from config to distinguish them in loss computation + self.continuous_context_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] + if isinstance(self.continuous_context_vars, (list, tuple)): + self.continuous_context_vars = set(self.continuous_context_vars) + else: + self.continuous_context_vars = set([self.continuous_context_vars]) if self.continuous_context_vars else set() def forward(self, noise: torch.Tensor, context_vars: dict): """ @@ -261,7 +274,15 @@ def training_step(self, batch: Any, batch_idx: int) -> None: if self.cfg.model.include_auxiliary_losses else 0.0 ) - g_ctx = sum(self.aux_loss(logits_ctx[v], ctx[v]) for v in logits_ctx) + # Context reconstruction loss: MSE for continuous, CSE for categorical + g_ctx = 0.0 + for v in logits_ctx: + if v in self.continuous_context_vars: + # Continuous variable: use MSE loss + 
g_ctx += F.mse_loss(logits_ctx[v], ctx[v].float()) + else: + # Categorical variable: use Cross-Entropy loss + g_ctx += self.aux_loss(logits_ctx[v], ctx[v].long()) h, _ = self.context_module(ctx) tc_term = ( diff --git a/cents/models/base.py b/cents/models/base.py index 16dcb35..2b4fa63 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -5,7 +5,7 @@ import torch from omegaconf import DictConfig -from cents.models.context import MLPContextModule # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration from cents.models.context_registry import get_context_module_cls @@ -41,9 +41,15 @@ def __init__(self, cfg: DictConfig = None): emb_dim = getattr(cfg.model, "cond_emb_dim", 256) context_module_type = getattr(cfg.model, "context_module_type", "default") + # Get continuous variables from config if specified + continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) # Use registry to get the context module class ContextModuleCls = get_context_module_cls(context_module_type) - self.context_module = ContextModuleCls(cfg.dataset.context_vars, emb_dim) + self.context_module = ContextModuleCls( + cfg.dataset.context_vars, + emb_dim, + continuous_vars=continuous_vars + ) else: self.context_module = None diff --git a/cents/models/context.py b/cents/models/context.py index 25b811e..64ccfc5 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from abc import abstractmethod +from typing import Optional from .context_registry import register_context_module class BaseContextModule(nn.Module): @@ -89,15 +90,43 @@ def forward( return embedding, classification_logits @register_context_module("default", "sep_mlp") class SepMLPContextModule(BaseContextModule): - def __init__(self, context_vars: dict[str, int], embedding_dim: int, init_depth: int = 1, mixing_depth: int = 1) -> None: + def __init__( + self, + 
context_vars: dict[str, int], + embedding_dim: int, + init_depth: int = 1, + mixing_depth: int = 1, + continuous_vars: Optional[list[str]] = None, + continuous_var_stats: Optional[dict[str, dict[str, float]]] = None # Deprecated, kept for backward compatibility + ) -> None: + """ + Initialize SepMLPContextModule. + + Args: + context_vars: Mapping of variable names to category counts. + embedding_dim: Size of embedding vectors. + init_depth: Depth of initial MLPs. + mixing_depth: Depth of mixing MLP. + continuous_vars: List of continuous variable names. + """ super().__init__() self.embedding_dim = embedding_dim + self.continuous_vars = continuous_vars or [] + self.categorical_vars = {k: v for k, v in context_vars.items() if k not in self.continuous_vars} self.context_embeddings = nn.ModuleDict( { name: nn.Embedding(num_categories, embedding_dim) - for name, num_categories in context_vars.items() + for name, num_categories in self.categorical_vars.items() + } + ) + + # For continuous variables, use a simple linear projection + self.continuous_projections = nn.ModuleDict( + { + name: nn.Linear(1, embedding_dim) + for name in self.continuous_vars } ) @@ -105,13 +134,22 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int, init_depth: name: nn.Sequential(*[ layer for _ in range(init_depth) - for layer in (nn.Linear(embedding_dim, 128), nn.ReLU(), nn.Linear(128, embedding_dim)) + for layer in (nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), nn.Linear(embedding_dim, embedding_dim)) ]) - for name in context_vars.keys() + for name in self.categorical_vars.keys() }) + # Also create init MLPs for continuous variables + self.continuous_init_mlps = nn.ModuleDict({ + name: nn.Sequential(*[ + layer + for _ in range(init_depth) + for layer in (nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), nn.Linear(embedding_dim, embedding_dim)) + ]) + for name in self.continuous_vars + }) - total_dim = embedding_dim * len(context_vars) + total_dim = embedding_dim * 
(len(self.categorical_vars) + len(self.continuous_vars)) self.mixing_mlp = nn.Sequential( nn.Linear(total_dim, 128), @@ -121,26 +159,104 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int, init_depth: self.classification_heads = nn.ModuleDict( { var_name: nn.Linear(embedding_dim, num_categories) - for var_name, num_categories in context_vars.items() + for var_name, num_categories in self.categorical_vars.items() + } + ) + + # Regression heads for continuous variables (output single value for MSE loss) + self.regression_heads = nn.ModuleDict( + { + var_name: nn.Linear(embedding_dim, 1) + for var_name in self.continuous_vars } ) - def forward(self, context_vars): - encodings = { - name : layer(context_vars[name]) for name, layer in self.context_embeddings.items() - } + #print(self.continuous_vars, "CONT VARS") + #print(context_vars, "VARS") + encodings = {} + + # Process categorical variables (only those present in context_vars) + for name, layer in self.context_embeddings.items(): + if name in context_vars: + encodings[name] = layer(context_vars[name]) + + #print(encodings, "ENCODINGS") + # Process continuous variables (only those present in context_vars) + for name, layer in self.continuous_projections.items(): + if name in context_vars: + # Reshape to (batch_size, 1) for linear layer + # Ensure proper shape and gradient flow + continuous_val = context_vars[name] + # Handle different input shapes + if continuous_val.dim() == 0: + # Scalar: add batch dimension + continuous_val = continuous_val.unsqueeze(0) + elif continuous_val.dim() == 1: + # 1D tensor: add feature dimension + continuous_val = continuous_val.unsqueeze(-1) + # Ensure float type while preserving gradients + if not continuous_val.is_floating_point(): + continuous_val = continuous_val.float() - embeddings = [ - layer(encodings[name]) for name, layer in self.init_mlps.items() - ] + if continuous_val.dim() == 1: + continuous_val = continuous_val.unsqueeze(-1) + encodings[name] = 
layer(continuous_val) + + embeddings = [] + # Apply init MLPs to categorical variables + for name, layer in self.init_mlps.items(): + embeddings.append(layer(encodings[name])) + + # Apply init MLPs to continuous variables + for name, layer in self.continuous_init_mlps.items(): + if name in encodings: + embedding_output = layer(encodings[name]) + # Check for NaN in embedding output + if torch.isnan(embedding_output).any(): + raise ValueError( + f"NaN detected in embedding output for continuous variable '{name}' " + f"after init MLP. This may indicate numerical instability in the MLP layers." + ) + embeddings.append(embedding_output) + + if not embeddings: + raise ValueError("No context variables found in context_vars dict") context_matrix = torch.cat(embeddings, dim=1) + + # Check for NaN before mixing MLP + if torch.isnan(context_matrix).any(): + raise ValueError( + f"NaN detected in context_matrix before mixing MLP. " + f"This suggests one of the context variable embeddings contains NaN." + ) + embedding = self.mixing_mlp(context_matrix) + + # Check for NaN after mixing MLP + if torch.isnan(embedding).any(): + raise ValueError( + f"NaN detected in final embedding after mixing MLP. 
" + f"Context matrix stats: mean={context_matrix.mean():.4f}, " + f"std={context_matrix.std():.4f}, " + f"min={context_matrix.min():.4f}, max={context_matrix.max():.4f}" + ) + #print(embedding, "post mixing") classification_logits = { var_name: head(embedding) for var_name, head in self.classification_heads.items() } + + # Regression outputs for continuous variables + regression_outputs = { + var_name: head(embedding).squeeze(-1) # Remove last dim to get (batch_size,) + for var_name, head in self.regression_heads.items() + } + + # Combine both into a single dict for backward compatibility + # The training step will need to distinguish between them + all_outputs = {**classification_logits, **regression_outputs} - return embedding, classification_logits \ No newline at end of file + return embedding, all_outputs \ No newline at end of file diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index ddd7e35..f6a1b22 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -132,11 +132,18 @@ def __init__(self, cfg: DictConfig): if self.loss_type == "l1": self.recon_loss_fn = F.l1_loss elif self.loss_type == "l2": - self.recon_loss_fn = F.mse_loss + self.recon_loss_fn = F.mse_loss # MSE for continuous RVs else: raise ValueError("Invalid loss type") self.auxiliary_loss = nn.CrossEntropyLoss() + + # Get continuous variables from config to distinguish them in loss computation + self.continuous_context_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] + if isinstance(self.continuous_context_vars, (list, tuple)): + self.continuous_context_vars = set(self.continuous_context_vars) + else: + self.continuous_context_vars = set([self.continuous_context_vars]) if self.continuous_context_vars else set() def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor @@ -215,17 +222,66 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di b = x.shape[0] t = 
torch.randint(0, self.num_timesteps, (b,), device=self.device) embedding, cond_classification_logits = self.context_module(context_vars) + + # Check embedding for NaN/Inf and extreme values + if embedding.isnan().any() or embedding.isinf().any(): + raise ValueError( + f"NaN/Inf detected in embedding from context module. " + f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}" + ) + + # Clamp extreme values to prevent numerical instability in transformer + # Don't fully normalize as that would change the learned embedding scale + # Just clip extreme outliers that could cause issues in attention/Fourier operations + # embedding_clamped = torch.clamp(embedding, min=-50.0, max=50.0) + + # # Log if clamping occurred (for debugging) + # if (embedding != embedding_clamped).any(): + # n_clamped = (embedding != embedding_clamped).sum().item() + # print(f"[Warning] Clamped {n_clamped} embedding values. " + # f"Original range: [{embedding.min():.4f}, {embedding.max():.4f}], " + # f"Clamped range: [{embedding_clamped.min():.4f}, {embedding_clamped.max():.4f}]") + + # embedding_normalized = embedding_clamped + noise = torch.randn_like(x) x_noisy = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - c = torch.cat( - [x_noisy, embedding.unsqueeze(1).repeat(1, self.seq_len, 1)], - dim=-1, - ) + + if x_noisy.isnan().any(): + raise ValueError("NaN detected in x_noisy") + + # Use normalized embedding for concatenation + embedding_expanded = embedding.unsqueeze(1).repeat(1, self.seq_len, 1) + c = torch.cat([x_noisy, embedding_expanded], dim=-1) + + if c.isnan().any() or c.isinf().any(): + raise ValueError( + f"NaN/Inf detected in concatenated input 'c'. " + f"x_noisy stats: mean={x_noisy.mean():.4f}, std={x_noisy.std():.4f}, " + f"min={x_noisy.min():.4f}, max={x_noisy.max():.4f}. 
" + f"embedding stats: mean={embedding.mean():.4f}, " + f"std={embedding.std():.4f}, min={embedding.min():.4f}, " + f"max={embedding.max():.4f}" + ) + + if t.isnan().any(): + raise ValueError("NaN detected in timestep 't'") + trend, season = self.model(c, t, padding_masks=None) + if trend.isnan().any(): + print("trend") + + if season.isnan().any(): + print("season") x_recon = self.fc(trend + season) + if x_recon.isnan().any(): + print("X RECON") + if x.isnan().any(): + print("x") + # print("REC LOSS", x_recon, x) rec_loss = self.recon_loss_fn(x_recon, x) return rec_loss, cond_classification_logits @@ -242,11 +298,28 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """ ts_batch, cond_batch = batch rec_loss, cond_class_logits = self(ts_batch, cond_batch) + + # Check for NaN in reconstruction loss early + if torch.isnan(rec_loss) or torch.isinf(rec_loss): + # print(rec_loss) + # print(ts_batch, cond_batch) + raise ValueError( + f"NaN/Inf detected in rec_loss at batch {batch_idx}. 
" + ) + cond_loss = 0.0 - for var_name, logits in cond_class_logits.items(): + for var_name, outputs in cond_class_logits.items(): labels = cond_batch[var_name] - cond_loss += self.auxiliary_loss(logits, labels) + + if var_name in self.continuous_context_vars: + # Continuous variable: use MSE loss + # outputs is shape (batch_size,), labels is shape (batch_size,) + cond_loss += F.mse_loss(outputs, labels.float()) + else: + # Categorical variable: use Cross-Entropy loss + # outputs (logits) is shape (batch_size, num_categories), labels is shape (batch_size,) + cond_loss += self.auxiliary_loss(outputs, labels) h, _ = self.context_module(cond_batch) tc_term = ( @@ -258,6 +331,13 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: total_loss = ( rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term ) + + # Check for NaN in total loss + if torch.isnan(total_loss) or torch.isinf(total_loss): + raise ValueError( + f"NaN/Inf detected in total_loss at batch {batch_idx}. 
" + f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" + ) self.log_dict( { "train_loss": total_loss, diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 765a861..49dd5bb 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -7,6 +7,8 @@ import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from tqdm import tqdm +from omegaconf import ListConfig + from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel @@ -52,6 +54,39 @@ def __init__( in_dim = hidden_dim layers.append(nn.Linear(in_dim, out_dim)) self.net = nn.Sequential(*layers) + + # Initialize the output layer properly + # For log_sigma, initialize to small negative values so exp(log_sigma) starts around 1 + # This helps with training stability + self._initialize_output_layer() + + def _initialize_output_layer(self): + """Initialize the output layer to reasonable starting values.""" + # Get the last linear layer + output_layer = self.net[-1] + with torch.no_grad(): + # Initialize all weights with small values + # nn.init.xavier_uniform_(output_layer.weight, gain=1.0) + + # Initialize all biases to zero first + # nn.init.zeros_(output_layer.bias) + + # For log_sigma outputs (indices 1, 3, 5, ...), initialize bias to small negative + # This makes exp(log_sigma) start around 0.1-1.0 + if self.do_scale: + # Pattern: mu, log_sigma, z_min, z_max for each dimension + for dim_idx in range(self.time_series_dims): + # log_sigma is at index 1 + 4*dim_idx + log_sigma_idx = 1 + 4 * dim_idx + # Initialize to 3.0: exp(3.0) ≈ 20, closer to typical sigma ~27 + output_layer.bias[log_sigma_idx].fill_(3.0) + else: + # Pattern: mu, log_sigma for each dimension + for dim_idx in range(self.time_series_dims): + # log_sigma is at index 1 + 2*dim_idx + log_sigma_idx = 1 + 2 * dim_idx + # Initialize to 3.0: exp(3.0) ≈ 20, closer to typical sigma ~27 + 
output_layer.bias[log_sigma_idx].fill_(3.0) def forward(self, z: torch.Tensor): """ @@ -65,6 +100,7 @@ def forward(self, z: torch.Tensor): pred_sigma: Predicted standard deviations, shape (batch_size, time_series_dims). pred_z_min: Predicted min z-scores, or None if do_scale=False. pred_z_max: Predicted max z-scores, or None if do_scale=False. + pred_log_sigma_unclamped: Unclamped log_sigma for loss computation. """ out = self.net(z) batch_size = out.size(0) @@ -80,8 +116,16 @@ def forward(self, z: torch.Tensor): pred_log_sigma = out[:, 1, :] pred_z_min = None pred_z_max = None - pred_sigma = torch.exp(pred_log_sigma) - return pred_mu, pred_sigma, pred_z_min, pred_z_max + + # Store unclamped version for loss computation BEFORE clamping + # This must be done before any operations that might break the computation graph + pred_log_sigma_unclamped = pred_log_sigma + + # Clamp log_sigma to prevent exp() from producing infinity + # exp(88) ≈ 1.6e38 (near float32 max), so clamp to reasonable range + pred_log_sigma_clamped = torch.clamp(pred_log_sigma, min=-10.0, max=10.0) + pred_sigma = torch.exp(pred_log_sigma_clamped) + return pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped class _NormalizerModule(nn.Module): @@ -125,8 +169,16 @@ def forward(self, cat_vars_dict: dict): cat_vars_dict: Mapping of context variable names to label tensors. Returns: - Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max). + Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped). 
""" + # Ensure all tensors in the dict are on the same device and properly connected + # This helps with DataLoader multiprocessing issues + device = next(self.cond_module.parameters()).device + cat_vars_dict = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in cat_vars_dict.items() + } + embedding, _ = self.cond_module(cat_vars_dict) return self.stats_head(embedding) @@ -158,7 +210,24 @@ def __init__( self.normalizer_training_cfg = normalizer_training_cfg self.dataset = dataset - self.context_vars = list(dataset_cfg.context_vars.keys()) + # Get continuous variables from config if specified + continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + # Convert to plain Python list if it's a ListConfig from OmegaConf + if continuous_vars: + if isinstance(continuous_vars, ListConfig): + continuous_vars = [str(v) for v in continuous_vars] # Ensure strings + elif isinstance(continuous_vars, list): + continuous_vars = [str(v) for v in continuous_vars] # Ensure strings + else: + continuous_vars = [str(continuous_vars)] + else: + continuous_vars = [] + + # Include both categorical and continuous variables in context_vars + # Ensure all are plain Python strings + categorical_vars = [str(k) for k in dataset_cfg.context_vars.keys()] + self.context_vars = categorical_vars + continuous_vars + self.time_series_cols = dataset_cfg.time_series_columns[ : dataset_cfg.time_series_dims ] @@ -169,27 +238,89 @@ def __init__( # Use registry to get the context module class ContextModuleCls = get_context_module_cls(context_module_type) - self.context_module = ContextModuleCls(self.dataset_cfg.context_vars, 256) + # Create context module - it will be stored in normalizer_model.cond_module + context_module = ContextModuleCls( + self.dataset_cfg.context_vars, + 256, + continuous_vars=continuous_vars + ) # Get stats head type from config stats_head_type = getattr(self.dataset_cfg, "stats_head_type", "default") 
self.normalizer_model = _NormalizerModule( - cond_module=self.context_module, + cond_module=context_module, hidden_dim=512, time_series_dims=self.time_series_dims, do_scale=self.do_scale, stats_head_type=stats_head_type, ) + self.context_module = self.normalizer_model.cond_module # Will be populated in setup() self.group_stats = {} + self._verify_parameters() + + def _verify_parameters(self): + """ + Verify that all parameters including context module are registered. + This helps debug parameter counting issues. + """ + all_param_names = [name for name, _ in self.named_parameters()] + context_param_names = [name for name in all_param_names if 'cond_module' in name or 'context_module' in name] + stats_head_param_names = [name for name in all_param_names if 'stats_head' in name] + + if not context_param_names: + raise RuntimeError( + "Context module parameters not found! " + "Expected parameters with 'cond_module' in name. " + f"Found parameter names: {all_param_names[:10]}..." + ) + + print(f"[Normalizer] Found {len(context_param_names)} context module parameters") + print(f"[Normalizer] Found {len(stats_head_param_names)} stats head parameters") + print(f"[Normalizer] Total trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}") def setup(self, stage: Optional[str] = None): """ Lightning hook: compute group statistics before training. 
""" self.group_stats = self._compute_group_stats() + + # Log initial predictions to check if model is in the right ballpark + if stage == "fit" or stage is None: + self._log_initial_predictions() + + def _log_initial_predictions(self): + """Log initial model predictions to diagnose initialization issues.""" + self.eval() + with torch.no_grad(): + # Get a sample batch + dataloader = self.train_dataloader() + batch = next(iter(dataloader)) + cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch + + # Move to device + device = next(self.parameters()).device + cat_vars_dict = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in cat_vars_dict.items() + } + mu_t = mu_t.to(device) + sigma_t = sigma_t.to(device) + + pred_mu, pred_sigma, pred_z_min, pred_z_max, _ = self(cat_vars_dict) + + print(f"\n[Initial Predictions]") + print(f" Target mu: mean={mu_t.mean().item():.4f}, std={mu_t.std().item():.4f}, range=[{mu_t.min().item():.4f}, {mu_t.max().item():.4f}]") + print(f" Predicted mu: mean={pred_mu.mean().item():.4f}, std={pred_mu.std().item():.4f}, range=[{pred_mu.min().item():.4f}, {pred_mu.max().item():.4f}]") + print(f" Target sigma: mean={sigma_t.mean().item():.4f}, std={sigma_t.std().item():.4f}, range=[{sigma_t.min().item():.4f}, {sigma_t.max().item():.4f}]") + print(f" Predicted sigma: mean={pred_sigma.mean().item():.4f}, std={pred_sigma.std().item():.4f}, range=[{pred_sigma.min().item():.4f}, {pred_sigma.max().item():.4f}]") + print(f" Initial loss_mu: {F.mse_loss(pred_mu, mu_t).item():.6f}") + print(f" Initial loss_sigma: {F.mse_loss(pred_sigma, sigma_t).item():.6f}") + print() + + self.train() def forward(self, cat_vars_dict: dict): """ @@ -199,7 +330,7 @@ def forward(self, cat_vars_dict: dict): cat_vars_dict: Mapping of context variable names to label tensors. Returns: - Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max). + Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped). 
""" return self.normalizer_model(cat_vars_dict) @@ -215,18 +346,55 @@ def training_step(self, batch, batch_idx: int): loss tensor. """ cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - pred_mu, pred_sigma, pred_z_min, pred_z_max = self(cat_vars_dict) + pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped = self(cat_vars_dict) + # Use standard MSE loss for mu loss_mu = F.mse_loss(pred_mu, mu_t) - loss_sigma = F.mse_loss(pred_sigma, sigma_t) + + # Use log-space loss for sigma - this is more numerically stable + # and handles scale differences better + target_log_sigma = torch.log(sigma_t + 1e-8) # Add small epsilon to avoid log(0) + loss_sigma = F.mse_loss(pred_log_sigma_unclamped, target_log_sigma) + total_loss = loss_mu + loss_sigma if self.do_scale: - total_loss += F.mse_loss(pred_z_min, zmin_t) + F.mse_loss( - pred_z_max, zmax_t + if torch.isnan(pred_z_min).any() or torch.isnan(pred_z_max).any(): + raise ValueError( + f"NaN detected in scale predictions at batch {batch_idx}" + ) + loss_zmin = F.mse_loss(pred_z_min, zmin_t) + loss_zmax = F.mse_loss(pred_z_max, zmax_t) + total_loss += loss_zmin + loss_zmax + else: + loss_zmin = torch.tensor(0.0, device=total_loss.device) + loss_zmax = torch.tensor(0.0, device=total_loss.device) + + # Check for NaN in loss + if torch.isnan(total_loss) or torch.isinf(total_loss): + raise ValueError( + f"NaN/Inf loss detected at batch {batch_idx}. 
" + f"loss_mu: {loss_mu.item():.6f}, loss_sigma: {loss_sigma.item():.6f}" ) - + + # Log individual components to understand what's happening self.log("train_loss", total_loss, prog_bar=True) + self.log("loss_mu", loss_mu, on_step=True, on_epoch=True, prog_bar=False) + self.log("loss_sigma", loss_sigma, on_step=True, on_epoch=True, prog_bar=False) + if self.do_scale: + self.log("loss_zmin", loss_zmin, on_step=True, on_epoch=True, prog_bar=False) + self.log("loss_zmax", loss_zmax, on_step=True, on_epoch=True, prog_bar=False) + + # Log prediction statistics to monitor if model is learning + if batch_idx % 100 == 0: # Log every 100 batches to avoid spam + with torch.no_grad(): + self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=False) + self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=False) + self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=False) + self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=False) + self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=False) + self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=False) + return total_loss def configure_optimizers(self): @@ -236,7 +404,47 @@ def configure_optimizers(self): Returns: Adam optimizer instance. """ - return torch.optim.Adam(self.parameters(), lr=self.normalizer_training_cfg.lr) + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.normalizer_training_cfg.lr, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0 + ) + + return optimizer + + def on_train_batch_end(self, outputs, batch, batch_idx): + """ + Monitor gradients after each training step to diagnose training issues. 
+ """ + if batch_idx % 100 == 0: # Check every 100 batches + total_norm = 0.0 + param_count = 0 + zero_grad_count = 0 + + for name, param in self.named_parameters(): + if param.grad is not None: + param_norm = param.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + param_count += 1 + if param_norm.item() < 1e-8: + zero_grad_count += 1 + else: + # Parameter has no gradient - this might indicate a problem + if 'cond_module' in name or 'stats_head' in name: + # Only warn about important parameters + pass + + total_norm = total_norm ** (1. / 2) + + self.log("grad_norm", total_norm, on_step=True, on_epoch=False) + self.log("params_with_grad", param_count, on_step=True, on_epoch=False) + + if total_norm < 1e-6: + print(f"[Warning] Very small gradient norm at batch {batch_idx}: {total_norm:.2e}") + if zero_grad_count > 0: + print(f"[Warning] {zero_grad_count} parameters have near-zero gradients at batch {batch_idx}") def train_dataloader(self): """ @@ -247,8 +455,10 @@ def train_dataloader(self): ds, batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, - num_workers=23, - persistent_workers=True + num_workers=8, # Reduce workers to avoid synchronization issues + persistent_workers=True, + pin_memory=torch.cuda.is_available(), # Helps with GPU transfer + prefetch_factor=2, # Reduce prefetch to avoid memory issues ) def _compute_group_stats(self) -> dict: @@ -317,15 +527,18 @@ def _create_training_dataset(self) -> Dataset: ) in self.group_stats.items() ] + continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + class _TrainSet(Dataset): """ Adapter Dataset to wrap group_stats tuples for DataLoader. 
""" - def __init__(self, samples, context_vars, do_scale): + def __init__(self, samples, context_vars, do_scale, continuous_vars): self.samples = samples self.context_vars = context_vars self.do_scale = do_scale + self.continuous_vars = continuous_vars def __len__(self) -> int: return len(self.samples) @@ -345,17 +558,20 @@ def __getitem__(self, idx: int): zmax_t: True max z-score tensor or None. """ ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] - cat_vars_dict = { - var_name: torch.tensor(ctx_tuple[i], dtype=torch.long) - for i, var_name in enumerate(self.context_vars) - } + cat_vars_dict = {} + + for i, var_name in enumerate(self.context_vars): + if var_name in self.continuous_vars: + cat_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.float32) + else: + cat_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.long) mu_t = torch.from_numpy(mu_arr).float() sigma_t = torch.from_numpy(sigma_arr).float() zmin_t = torch.from_numpy(zmin_arr).float() if self.do_scale else None zmax_t = torch.from_numpy(zmax_arr).float() if self.do_scale else None return cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t - return _TrainSet(data_tuples, self.context_vars, self.do_scale) + return _TrainSet(data_tuples, self.context_vars, self.do_scale, continuous_vars) def transform(self, df: pd.DataFrame) -> pd.DataFrame: """ @@ -381,13 +597,16 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: ) df_out = df.copy() self.eval() + continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Normalizing"): - ctx = { - v: torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - for v in self.context_vars - } - mu, sigma, zmin, zmax = self(ctx) + ctx = {} + for v in self.context_vars: + if v in continuous_vars: + ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + else: + ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + 
mu, sigma, zmin, zmax, _ = self(ctx) mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): @@ -423,13 +642,16 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() + continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): - ctx = { - v: torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - for v in self.context_vars - } - mu, sigma, zmin, zmax = self(ctx) + ctx = {} + for v in self.context_vars: + if v in continuous_vars: + ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + else: + ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + mu, sigma, zmin, zmax, _ = self(ctx) mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): diff --git a/cents/trainer.py b/cents/trainer.py index 17d2d0a..e032099 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -194,13 +194,30 @@ def _instantiate_trainer(self) -> pl.Trainer: """ tc = self.cfg.trainer callbacks = [] + # Build filename with optional context_module_type + filename_parts = [ + self.cfg.dataset.name, + self.model_type, + f"dim{self.cfg.dataset.time_series_dims}" + ] + + # Add context_module_type if available (from model or dataset config) + context_module_type = getattr( + self.cfg.model, "context_module_type", + getattr(self.cfg.dataset, "context_module_type", None) + ) + if context_module_type: + filename_parts.append(f"ctx{context_module_type}") + + # Add stats_head_type if available (typically in dataset config for normalizer) + stats_head_type = getattr(self.cfg.dataset, "stats_head_type", None) + if stats_head_type: + filename_parts.append(f"stats{stats_head_type}") + callbacks.append( ModelCheckpoint( dirpath=self.cfg.run_dir, - filename=( - f"{self.cfg.dataset.name}_{self.model_type}" - 
f"_dim{self.cfg.dataset.time_series_dims}" - ), + filename="_".join(filename_parts), save_last=tc.checkpoint.save_last, save_on_train_epoch_end=True, ### Perhaps excessive ) diff --git a/cents/utils/utils.py b/cents/utils/utils.py index 2beb7ff..70088dc 100644 --- a/cents/utils/utils.py +++ b/cents/utils/utils.py @@ -7,8 +7,38 @@ ROOT_DIR = Path(__file__).parent.parent -def _ckpt_name(dataset: str, model: str, dims: int, *, ext: str = "ckpt") -> str: - return f"{dataset}_{model}_dim{dims}.{ext}" +def _ckpt_name( + dataset: str, + model: str, + dims: int, + *, + ext: str = "ckpt", + context_module_type: str = None, + stats_head_type: str = None +) -> str: + """ + Generate checkpoint filename with optional context_module_type and stats_head_type. + + Args: + dataset: Dataset name + model: Model name + dims: Number of dimensions + ext: File extension (default: "ckpt") + context_module_type: Optional context module type (e.g., "mlp", "sep_mlp") + stats_head_type: Optional stats head type (e.g., "mlp") + + Returns: + Formatted checkpoint filename + """ + parts = [dataset, model, f"dim{dims}"] + + if context_module_type: + parts.append(f"ctx{context_module_type}") + + if stats_head_type: + parts.append(f"stats{stats_head_type}") + + return "_".join(parts) + f".{ext}" def parse_dims_from_name(model_name: str) -> str: diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index e49354d..3d3a1f6 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,5 +1,6 @@ import logging from datetime import datetime +from typing import override # import wandb from omegaconf import OmegaConf @@ -12,7 +13,7 @@ from pathlib import Path import torch -MODEL_KEY = "acgan" +MODEL_KEY = "diffusion_ts" DATASET_OVERRIDES = [ "max_samples=10000", "skip_heavy_processing=True" @@ -26,14 +27,16 @@ HOME = Path.home() def main() -> None: - model_ckpt = HOME / f"Cents/cents/outputs/{MODEL_KEY}_pecanstreet_all/2025-10-27_10-09-04/pecanstreet_acgan_dim1.ckpt" + + 
model_ckpt = "cents/outputs/diffusion_ts_commercial_all/2025-11-07_15-09-33/commercial_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) print("Loading dataset...") - dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES) + # dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES) + dataset = CommercialDataset(overrides = DATASET_OVERRIDES) - normalizer_ckpt = HOME / ".cache/cents/checkpoints/pecanstreet/normalizer/pecanstreet_normalizer_dim1.ckpt" + normalizer_ckpt = HOME / ".cache/cents/checkpoints/commercial/normalizer/commercial_normalizer_dim1_ctxsep_mlp_statsmlp.ckpt" # Build a minimal cfg for evaluator and generator eval_cfg = load_yaml("cents/config/evaluator/default.yaml") top_cfg = load_yaml("cents/config/config.yaml") @@ -54,7 +57,7 @@ def main() -> None: cfg.eval_disentanglement = eval_cfg.get("eval_disentanglement", True) cfg.job_name = eval_cfg.get("job_name", "default_job") cfg.save_results = True - cfg.save_dir = HOME / f"Cents/cents/outputs/{MODEL_KEY}_pecanstreet_all/2025-10-27_10-09-04/eval" + cfg.save_dir = HOME / f"cents/outputs/diffusion_ts_commercial_all/2025-11-07_15-09-33/eval" print("Dataset spec set. 
Setting up DataGenerator...") # Use the fixed checkpoint with DataGenerator diff --git a/scripts/train.py b/scripts/train.py index 6530140..ff0ec75 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -17,16 +17,18 @@ def main(args) -> None: # Skip heavy processing for DDP compatibility if args.dataset == "pecanstreet": - dataset = PecanStreetDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) + dataset = PecanStreetDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}, time_series_dims=1, user_group=all"]) elif args.dataset == "commercial": - dataset = CommercialDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}", "time_series_dims=1", "user_group=all"]) + dataset = CommercialDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) else: raise ValueError(f"Dataset {args.dataset} not supported") + print("Initialized Dataset") + trainer_overrides = [ f"trainer.max_epochs={args.epochs}", f"trainer.strategy={args.ddp_strategy}", - f"trainer.devices={args.devices}", + f"trainer.devices=0,1", f"trainer.eval_after_training={args.eval_after_training}", f"train.accelerator={args.accelerator}", "trainer.early_stopping.patience=100", # Stop if no improvement for 100 epochs @@ -52,7 +54,7 @@ def main(args) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--devices", type=int, default="auto") + parser.add_argument("--devices", type=str, default="auto") parser.add_argument("--accelerator", type=str, default="gpu") parser.add_argument("--model_name", type=str, default="diffusion_ts") parser.add_argument("--cr_loss_weight", type=float, default=0.1) From ea2ba4474af8fbe38e66cb3056191fbc3cddc566 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Thu, 13 Nov 2025 13:32:23 -0500 Subject: [PATCH 14/50] Stabilized training with cont. 
cvs --- cents/datasets/timeseries_dataset.py | 9 +++ cents/models/diffusion_ts.py | 98 +++++++++++++--------------- 2 files changed, 54 insertions(+), 53 deletions(-) diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index fca4fee..627f42f 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -251,6 +251,15 @@ def get_train_dataloader( Returns: DataLoader: Configured data loader. """ + continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] + + # for col in continuous_vars: + # if col in self.data.columns: + # print(self.data[col].mean()) + self._normalize_continuous_vars() + # for col in continuous_vars: + # if col in self.data.columns: + # print(self.data[col].mean()) return DataLoader( self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, persistent_workers=persistent_workers ) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index f6a1b22..70d67f9 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -140,10 +140,6 @@ def __init__(self, cfg: DictConfig): # Get continuous variables from config to distinguish them in loss computation self.continuous_context_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] - if isinstance(self.continuous_context_vars, (list, tuple)): - self.continuous_context_vars = set(self.continuous_context_vars) - else: - self.continuous_context_vars = set([self.continuous_context_vars]) if self.continuous_context_vars else set() def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor @@ -224,11 +220,11 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di embedding, cond_classification_logits = self.context_module(context_vars) # Check embedding for NaN/Inf and extreme values - if embedding.isnan().any() or embedding.isinf().any(): - raise ValueError( - f"NaN/Inf detected in embedding from context 
module. " - f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}" - ) + # if embedding.isnan().any() or embedding.isinf().any(): + # raise ValueError( + # f"NaN/Inf detected in embedding from context module. " + # f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}" + # ) # Clamp extreme values to prevent numerical instability in transformer # Don't fully normalize as that would change the learned embedding scale @@ -250,37 +246,37 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - if x_noisy.isnan().any(): - raise ValueError("NaN detected in x_noisy") + # if x_noisy.isnan().any(): + # raise ValueError("NaN detected in x_noisy") # Use normalized embedding for concatenation embedding_expanded = embedding.unsqueeze(1).repeat(1, self.seq_len, 1) c = torch.cat([x_noisy, embedding_expanded], dim=-1) - if c.isnan().any() or c.isinf().any(): - raise ValueError( - f"NaN/Inf detected in concatenated input 'c'. " - f"x_noisy stats: mean={x_noisy.mean():.4f}, std={x_noisy.std():.4f}, " - f"min={x_noisy.min():.4f}, max={x_noisy.max():.4f}. " - f"embedding stats: mean={embedding.mean():.4f}, " - f"std={embedding.std():.4f}, min={embedding.min():.4f}, " - f"max={embedding.max():.4f}" - ) + # if c.isnan().any() or c.isinf().any(): + # raise ValueError( + # f"NaN/Inf detected in concatenated input 'c'. " + # f"x_noisy stats: mean={x_noisy.mean():.4f}, std={x_noisy.std():.4f}, " + # f"min={x_noisy.min():.4f}, max={x_noisy.max():.4f}. 
" + # f"embedding stats: mean={embedding.mean():.4f}, " + # f"std={embedding.std():.4f}, min={embedding.min():.4f}, " + # f"max={embedding.max():.4f}" + # ) - if t.isnan().any(): - raise ValueError("NaN detected in timestep 't'") + # if t.isnan().any(): + # raise ValueError("NaN detected in timestep 't'") trend, season = self.model(c, t, padding_masks=None) - if trend.isnan().any(): - print("trend") + # if trend.isnan().any(): + # print("trend") - if season.isnan().any(): - print("season") + # if season.isnan().any(): + # print("season") x_recon = self.fc(trend + season) - if x_recon.isnan().any(): - print("X RECON") - if x.isnan().any(): - print("x") + # if x_recon.isnan().any(): + # print("X RECON") + # if x.isnan().any(): + # print("x") # print("REC LOSS", x_recon, x) rec_loss = self.recon_loss_fn(x_recon, x) return rec_loss, cond_classification_logits @@ -299,27 +295,23 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: ts_batch, cond_batch = batch rec_loss, cond_class_logits = self(ts_batch, cond_batch) - # Check for NaN in reconstruction loss early - if torch.isnan(rec_loss) or torch.isinf(rec_loss): - # print(rec_loss) - # print(ts_batch, cond_batch) - raise ValueError( - f"NaN/Inf detected in rec_loss at batch {batch_idx}. 
" - ) - cond_loss = 0.0 + for var_name, outputs in cond_class_logits.items(): labels = cond_batch[var_name] if var_name in self.continuous_context_vars: - # Continuous variable: use MSE loss - # outputs is shape (batch_size,), labels is shape (batch_size,) - cond_loss += F.mse_loss(outputs, labels.float()) + loss = F.mse_loss(outputs, labels.float()) else: - # Categorical variable: use Cross-Entropy loss - # outputs (logits) is shape (batch_size, num_categories), labels is shape (batch_size,) - cond_loss += self.auxiliary_loss(outputs, labels) + loss = self.auxiliary_loss(outputs, labels) + + cond_loss += loss.mean() + + # if var_name in self.continuous_context_vars: + # print(var_name) + # print(loss) + # print(outputs.mean(), labels.mean()) h, _ = self.context_module(cond_batch) tc_term = ( @@ -332,17 +324,17 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term ) - # Check for NaN in total loss - if torch.isnan(total_loss) or torch.isinf(total_loss): - raise ValueError( - f"NaN/Inf detected in total_loss at batch {batch_idx}. " - f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" - ) + # # Check for NaN in total loss + # if torch.isnan(total_loss) or torch.isinf(total_loss): + # raise ValueError( + # f"NaN/Inf detected in total_loss at batch {batch_idx}. 
" + # f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" + # ) self.log_dict( { - "train_loss": total_loss, - "rec_loss": rec_loss, - "cond_loss": cond_loss, + "train_loss": total_loss.item(), + "rec_loss": rec_loss.item(), + "cond_loss": cond_loss.item(), "tc_loss": tc_term, }, prog_bar=True, From f0f964752e3a5471b1ff3d15491438b6da6d4999 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 19 Nov 2025 12:35:49 -0500 Subject: [PATCH 15/50] Began adding airquality dataset --- cents/config/dataset/airquality.yaml | 32 +++++++ cents/datasets/airquality.py | 132 +++++++++++++++++++++++++++ scripts/eval_pretrained.py | 11 ++- scripts/train.py | 2 +- 4 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 cents/config/dataset/airquality.yaml create mode 100644 cents/datasets/airquality.py diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml new file mode 100644 index 0000000..92839c5 --- /dev/null +++ b/cents/config/dataset/airquality.yaml @@ -0,0 +1,32 @@ +name: airquality +geography: null +normalize: True +scale: True +use_learned_normalizer: True +threshold: 8 +seq_len: 24 +shuffle: True +skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) +max_samples: null # Limit dataset size (null = use all data) +path: "./data/airquality" +target_time_series_columns: "PM2.5" +context_time_series_columns: ["TEMP", "DEWP", "PRES", "RAIN", "WSPM", "wd"] +data_columns: ["id", "PM2.5", "timestamp"] +metadata_columns: ["station"] +reduce_cardinality: False +stats_head_type: mlp +context_module_type: sep_mlp + +context_vars: # for each desired context variable, add the name and number of categories + year: 2 + month: 12 + weekday: 7 + +time_series_context_vars: # for each desired continuous context variable, add the name and number of bins +- TEMP +- DEWP +- PRES +- RAIN +- WSPM +- wd + diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py new 
file mode 100644 index 0000000..57286f5 --- /dev/null +++ b/cents/datasets/airquality.py @@ -0,0 +1,132 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +# These are warnings for an error that is accounted for in the code +warnings.filterwarnings("ignore", category=RuntimeWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class AirQualityDataset(TimeSeriesDataset): + def __init__(self, cfg: DictConfig = None, + overrides: Optional[List[str]] = None): + """ + Initializes the AirQuality Dataset. Available at: + https://doi.org/10.24432/C5RK5G. + Hourly Air Quality at Multiple Sites in China, in many measures. + Accompanying location and weather information. 
+ """ + + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "airquality.yaml")) + + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + self.target_time_series_columns = cfg.target_time_series_columns + self.context_time_series_columns = cfg.context_time_series_columns + self.geography = cfg.geography + self.time_series_dims = len(self.target_time_series_columns) + + self.city_names = ["Aotizhongxin", "Changping", "Dingling", "Dongsi", "Guanyuan", + "Gucheng", "Huairou", "Nongzhanguan", "Shunyi", "Tiantan", "Wanliu", "Wanshouxigong"] + + self._load_data() + + super().__init__( + data=self.data, + time_series_column_names=self.time_series_columns, + context_var_column_names=list(self.cfg.context_vars.keys()), + seq_len=self.cfg.seq_len, + normalize=self.cfg.normalize, + scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None) + ) + + def _load_data(self): + """ + Loads in metadata and data for commercial energy dataset. 
+ Categorial CVs: Year, Month, Day, Location + Context Time Series: Temperature, Pressure, Dewpoint, Precipitation, Wind + """ + + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + + meta_path = os.path.join(path, "metadata.csv") + if not os.path.exists(meta_path): + raise FileNotFoundError(f"Metadata file not found at {meta_path}") + + + if not self.geography: self.geography = self.city_names + + self.geography = [self.geography] if isinstance(self.geography, str) else self.geography + dfs = [] + for name in self.geography: + fname = f"PRSA_Data_{name}_20130301-20170228.csv" + data_path = os.path.join(path, fname) + + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found at {data_path}") + + dfs.append(pd.read_csv(data_path)) + + self.data = pd.concat(dfs, axis=0)[self.cfg.data_columns] + + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + data = data.copy() + data['timestamp'] = pd.to_datetime(data[["year", "month", "day", "hour"]]) + data['weekday'] = data['timestamp'].dt.day_name() + + ts_cols = self.context_time_series_columns + self.target_time_series_columns + + data = data.sort_values(['location', 'year', 'month', 'day', 'hour']) + + + grouped = ( + data.groupby(["station", "year", "month", "day"], as_index=False) + .agg({**{c: list for c in ts_cols}, + "weekday": 'first'}) + ) + + grouped = grouped[grouped["PM2.5"].apply(len) == self.cfg.seq_len].reset_index( + drop=True + ) + + grouped = self._handle_missing_data(grouped) + + return grouped + + + def _hande_missing_data(self, data): + mask = data[self.context_time_series_columns].applymap(is_all_nan).any(axis=1) + data = data[~mask] + + for col in self.context_time_series_columns: + data[col] = data[col].apply(fill_with_row_mean) + + return data + + + +def is_all_nan(lst): + arr = np.array(lst, dtype=float) + return np.isnan(arr).all() + + +def fill_with_row_mean(lst): + s = 
pd.Series(lst, dtype=float) + m = s.mean(skipna=True) + return s.fillna(m).tolist() diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 3d3a1f6..bdaefb2 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -12,6 +12,7 @@ from cents.utils.config_loader import load_yaml from pathlib import Path import torch +import os MODEL_KEY = "diffusion_ts" DATASET_OVERRIDES = [ @@ -28,7 +29,7 @@ def main() -> None: - model_ckpt = "cents/outputs/diffusion_ts_commercial_all/2025-11-07_15-09-33/commercial_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" + model_ckpt = "cents/outputs/diffusion_ts_commercial_all/2025-11-13_19-50-40/commercial_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) @@ -57,7 +58,13 @@ def main() -> None: cfg.eval_disentanglement = eval_cfg.get("eval_disentanglement", True) cfg.job_name = eval_cfg.get("job_name", "default_job") cfg.save_results = True - cfg.save_dir = HOME / f"cents/outputs/diffusion_ts_commercial_all/2025-11-07_15-09-33/eval" + cfg.save_dir = HOME / f"cents/outputs/diffusion_ts_commercial_all/2025-11-13_19-50-40/eval" + + + if not os.path.exists(cfg.save_dir): + os.makedirs(cfg.save_dir) + print("Creating Evaluation Directory") + print("Dataset spec set. 
Setting up DataGenerator...") # Use the fixed checkpoint with DataGenerator diff --git a/scripts/train.py b/scripts/train.py index ff0ec75..9634a6c 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -28,7 +28,7 @@ def main(args) -> None: trainer_overrides = [ f"trainer.max_epochs={args.epochs}", f"trainer.strategy={args.ddp_strategy}", - f"trainer.devices=0,1", + f"trainer.devices={args.devices}", f"trainer.eval_after_training={args.eval_after_training}", f"train.accelerator={args.accelerator}", "trainer.early_stopping.patience=100", # Stop if no improvement for 100 epochs From 8ece83a76ca3a6dc1b71ee2bdc1adf44e5033291 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 7 Jan 2026 10:23:41 -0500 Subject: [PATCH 16/50] Modified configs for easier customization of context variables and models --- cents/config/dataset/commercial.yaml | 23 ++++++---------- cents/config/dataset/pecanstreet.yaml | 25 +++++++---------- cents/config/model/acgan.yaml | 1 - cents/config/model/diffusion_ts.yaml | 1 - cents/datasets/pecanstreet.py | 2 -- cents/datasets/timeseries_dataset.py | 27 ++++++++++++++----- cents/models/acgan.py | 7 ++++- cents/models/base.py | 8 +++--- cents/models/context.py | 9 +++---- cents/models/diffusion_ts.py | 2 +- cents/models/normalizer.py | 39 +++++++++++++-------------- cents/trainer.py | 13 +++++---- cents/utils/utils.py | 39 +++++++++++++++++++++++++++ scripts/eval_pretrained.py | 13 +++++++-- scripts/train.py | 11 ++++++++ 15 files changed, 140 insertions(+), 80 deletions(-) diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 86a8722..62fdb38 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -17,20 +17,13 @@ metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearb numeric_context_bins: 5 numeric_cols: ["sqft", "yearbuilt"] # Columns to bin as numeric reduce_cardinality: False -stats_head_type: mlp -context_module_type: sep_mlp 
-context_vars: # for each desired context variable, add the name and number of categories - year: 2 - month: 12 - weekday: 7 - site_id: 19 - primaryspaceusage: 16 - # sqft: 5 - # yearbuilt: 5 - # sub_primaryspaceusage: 105 - -continuous_context_vars: # for each desired continuous context variable, add the name and number of bins -- sqft -- yearbuilt +context_vars: + year: ["categorical", 2] + month: ["categorical", 12] + weekday: ["categorical", 7] + site_id: ["categorical", 19] + primaryspaceusage: ["categorical", 16] + sqft: ["continuous", None] + yearbuilt: ["continuous", None] \ No newline at end of file diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index d71666d..97ecd41 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -13,21 +13,16 @@ path: "./data/pecanstreet/csv" time_series_columns: ["grid", "solar"] data_columns: ["dataid","local_15min","car1","grid","solar"] metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"] -numeric_cols: ["total_square_footage", "house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 -stats_head_type: mlp -context_module_type: sep_mlp -context_vars: # for each desired context variable, add the name and number of categories - month: 12 - weekday: 7 - building_type: 3 - has_solar: 2 # note that the metadata csv file column name is 'solar', which is renamed to avoid conflicts with the 'solar' column in the data csv. 
- car1: 2 - city: 7 - state: 3 - -continuous_context_vars: -- total_square_footage -- house_construction_year +context_vars: + month: ["categorical", 12] + weekday: ["categorical", 7] + building_type: ["categorical", 3] + has_solar: ["categorical", 2] + car1: ["categorical", 2] + city: ["categorical", 7] + state: ["categorical", 3] + total_square_footage: ["continuous", None] + house_construction_year: ["continuous", None] \ No newline at end of file diff --git a/cents/config/model/acgan.yaml b/cents/config/model/acgan.yaml index 6adcc3c..b72bd60 100644 --- a/cents/config/model/acgan.yaml +++ b/cents/config/model/acgan.yaml @@ -2,7 +2,6 @@ _target_: generator.gan.acgan.ACGAN name: acgan noise_dim: 256 cond_emb_dim: 16 -context_module_type: sep_mlp include_auxiliary_losses: True context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 2d5bea7..7cb76c1 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -4,7 +4,6 @@ context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 noise_dim: 256 cond_emb_dim: 16 -context_module_type: sep_mlp n_layer_enc: 4 n_layer_dec: 5 d_model: 128 diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index 54c4270..170cfef 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -62,8 +62,6 @@ def __init__( self.threshold = (-1 * int(cfg.threshold), int(cfg.threshold)) self.time_series_dims = cfg.time_series_dims - self.cfg.time_series_columns = ["grid", "solar"] - self.include_generation = self.time_series_dims > 1 if self.time_series_dims > 1 and self.cfg.user_group in {"non_pv_users", "all"}: diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 627f42f..f01f0a8 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -17,7 +17,7 @@ from cents.datasets.utils import 
encode_context_variables from cents.models.normalizer import Normalizer from cents.utils.config_loader import load_yaml, apply_overrides -from cents.utils.utils import _ckpt_name, get_normalizer_training_config +from cents.utils.utils import _ckpt_name, get_normalizer_training_config, get_context_config ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -54,6 +54,7 @@ def __init__( overrides: Dict[str, Any] = {}, skip_heavy_processing: bool = False, size: int = None, + categorical_time_series: Dict[str, int] = None, ): # Initialize basic attributes # Handle OmegaConf ListConfig objects @@ -110,6 +111,9 @@ def __init__( self.normalize = normalize self.scale = scale + + # Store categorical time series info + self.categorical_time_series = categorical_time_series or {} if self.scale: assert self.normalize, "Normalization must be enabled if scaling is enabled" @@ -317,6 +321,12 @@ def merge_timeseries_columns(self, df: pd.DataFrame) -> pd.DataFrame: raise ValueError("Incorrect array shape.") else: raise ValueError("Array must have 2 dims.") + + # For categorical time series, ensure they remain as integers + for col in self.time_series_column_names: + if col in self.categorical_time_series: + df[col] = df[col].apply(lambda x: x.astype(np.int32)) + df["timeseries"] = df.apply( lambda r: np.hstack([r[c] for c in self.time_series_column_names]), axis=1 ) @@ -563,7 +573,8 @@ def _get_rarity_cache_path(self) -> str: """Get cache file path for rarity features.""" import hashlib # Create a hash based on dataset characteristics for cache key - context_module_type = getattr(self.cfg, "context_module_type", None) + context_cfg = get_context_config() + context_module_type = context_cfg.static_context.type cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{str(sorted(self.context_vars))}_{context_module_type or ''}" cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] cache_dir = os.path.join(ROOT_DIR, "cache", "rarity") @@ -575,8 +586,9 @@ 
def _get_normalization_cache_path(self): import hashlib from pathlib import Path # Create hash based on dataset + normalizer characteristics - context_module_type = getattr(self.cfg, "context_module_type", None) - stats_head_type = getattr(self.cfg, "stats_head_type", None) + context_cfg = get_context_config() + context_module_type = context_cfg.dynamic_context.type + stats_head_type = context_cfg.normalizer.stats_head_type cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}_{context_module_type or ''}_{stats_head_type or ''}" cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] cache_dir = Path(ROOT_DIR) / "cache" / "normalized_data" @@ -636,9 +648,10 @@ def _init_normalizer(self) -> None: ) normalizer_dir.mkdir(parents=True, exist_ok=True) - # Get context_module_type and stats_head_type from config - context_module_type = getattr(self.cfg, "context_module_type", None) - stats_head_type = getattr(self.cfg, "stats_head_type", None) + # Get context_module_type and stats_head_type from context config + context_cfg = get_context_config() + context_module_type = context_cfg.dynamic_context.type + stats_head_type = context_cfg.normalizer.stats_head_type cache_path = normalizer_dir / _ckpt_name( self.name, diff --git a/cents/models/acgan.py b/cents/models/acgan.py index c71963b..ebc94c7 100644 --- a/cents/models/acgan.py +++ b/cents/models/acgan.py @@ -198,12 +198,17 @@ def __init__(self, cfg: DictConfig): # cfg.dataset.context_vars, cfg.model.cond_emb_dim # ) continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] + # Get context module type from context config + from cents.utils.utils import get_context_config + context_cfg = get_context_config() + context_module_type = context_cfg.static_context.type + self.generator = Generator( noise_dim=cfg.model.noise_dim, embedding_dim=cfg.model.cond_emb_dim, final_window_length=cfg.dataset.seq_len, time_series_dims=cfg.dataset.time_series_dims, - 
context_module_type=cfg.model.context_module_type, + context_module_type=context_module_type, context_vars=cfg.dataset.context_vars, continuous_vars=continuous_vars, ) diff --git a/cents/models/base.py b/cents/models/base.py index 2b4fa63..f140741 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -7,6 +7,7 @@ from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration from cents.models.context_registry import get_context_module_cls +from cents.utils.utils import get_context_config class BaseModel(pl.LightningModule, ABC): @@ -39,16 +40,17 @@ def __init__(self, cfg: DictConfig = None): if hasattr(cfg.dataset, "context_vars") and cfg.dataset.context_vars: emb_dim = getattr(cfg.model, "cond_emb_dim", 256) - context_module_type = getattr(cfg.model, "context_module_type", "default") + # Get context module type from context config + context_cfg = get_context_config() + context_module_type = context_cfg.static_context.type # Get continuous variables from config if specified - continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) + # continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) # Use registry to get the context module class ContextModuleCls = get_context_module_cls(context_module_type) self.context_module = ContextModuleCls( cfg.dataset.context_vars, emb_dim, - continuous_vars=continuous_vars ) else: self.context_module = None diff --git a/cents/models/context.py b/cents/models/context.py index 64ccfc5..7189f31 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -96,8 +96,6 @@ def __init__( embedding_dim: int, init_depth: int = 1, mixing_depth: int = 1, - continuous_vars: Optional[list[str]] = None, - continuous_var_stats: Optional[dict[str, dict[str, float]]] = None # Deprecated, kept for backward compatibility ) -> None: """ Initialize SepMLPContextModule. 
@@ -112,9 +110,10 @@ def __init__( super().__init__() self.embedding_dim = embedding_dim - self.continuous_vars = continuous_vars or [] - self.categorical_vars = {k: v for k, v in context_vars.items() if k not in self.continuous_vars} - + self.continuous_vars = [k for k, v in context_vars.items() if v[0] == "continuous"] + self.categorical_vars = {k: v[1] for k, v in context_vars.items() if k not in self.continuous_vars} + print(self.continuous_vars, "CONT VARS") + print(self.categorical_vars, "CAT VARS") self.context_embeddings = nn.ModuleDict( { name: nn.Embedding(num_categories, embedding_dim) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 70d67f9..22b40a6 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -139,7 +139,7 @@ def __init__(self, cfg: DictConfig): self.auxiliary_loss = nn.CrossEntropyLoss() # Get continuous variables from config to distinguish them in loss computation - self.continuous_context_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] + self.continuous_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 49dd5bb..c54be5b 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -16,6 +16,7 @@ from cents.models.context_registry import get_context_module_cls from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model +from cents.utils.utils import get_context_config @register_stats_head("default", "mlp") @@ -211,21 +212,8 @@ def __init__( self.dataset = dataset # Get continuous variables from config if specified - continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] - # Convert to plain Python list if it's a ListConfig from OmegaConf - if 
continuous_vars: - if isinstance(continuous_vars, ListConfig): - continuous_vars = [str(v) for v in continuous_vars] # Ensure strings - elif isinstance(continuous_vars, list): - continuous_vars = [str(v) for v in continuous_vars] # Ensure strings - else: - continuous_vars = [str(continuous_vars)] - else: - continuous_vars = [] - - # Include both categorical and continuous variables in context_vars - # Ensure all are plain Python strings - categorical_vars = [str(k) for k in dataset_cfg.context_vars.keys()] + continuous_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "continuous"] + categorical_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "categorical"] self.context_vars = categorical_vars + continuous_vars self.time_series_cols = dataset_cfg.time_series_columns[ @@ -234,7 +222,10 @@ def __init__( self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale - context_module_type = getattr(self.dataset_cfg, "context_module_type", "default") + # Get context config + context_cfg = get_context_config() + context_module_type = context_cfg.dynamic_context.type + stats_head_type = context_cfg.normalizer.stats_head_type # Use registry to get the context module class ContextModuleCls = get_context_module_cls(context_module_type) @@ -242,11 +233,7 @@ def __init__( context_module = ContextModuleCls( self.dataset_cfg.context_vars, 256, - continuous_vars=continuous_vars ) - - # Get stats head type from config - stats_head_type = getattr(self.dataset_cfg, "stats_head_type", "default") self.normalizer_model = _NormalizerModule( cond_module=context_module, @@ -598,6 +585,10 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + + # Get categorical time series from dataset if available + categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) + with torch.no_grad(): for i, row 
in tqdm(df_out.iterrows(), total=len(df_out), desc="Normalizing"): ctx = {} @@ -611,6 +602,14 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: for d, col in enumerate(self.time_series_cols): arr = np.asarray(row[col], dtype=np.float32) + + # Skip normalization for categorical time series + if col in categorical_ts: + # Keep as integers, just ensure proper dtype + df_out.at[i, col] = arr.astype(np.int32) + continue + + # Normalize numeric time series z = (arr - mu[d]) / (sigma[d] + 1e-8) if self.do_scale: zmin_, zmax_ = zmin[0, d].item(), zmax[0, d].item() diff --git a/cents/trainer.py b/cents/trainer.py index e032099..75764d8 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -201,16 +201,15 @@ def _instantiate_trainer(self) -> pl.Trainer: f"dim{self.cfg.dataset.time_series_dims}" ] - # Add context_module_type if available (from model or dataset config) - context_module_type = getattr( - self.cfg.model, "context_module_type", - getattr(self.cfg.dataset, "context_module_type", None) - ) + # Add context_module_type from context config + from cents.utils.utils import get_context_config + context_cfg = get_context_config() + context_module_type = context_cfg.static_context.type if context_module_type: filename_parts.append(f"ctx{context_module_type}") - # Add stats_head_type if available (typically in dataset config for normalizer) - stats_head_type = getattr(self.cfg.dataset, "stats_head_type", None) + # Add stats_head_type from context config + stats_head_type = context_cfg.normalizer.stats_head_type if stats_head_type: filename_parts.append(f"stats{stats_head_type}") diff --git a/cents/utils/utils.py b/cents/utils/utils.py index 70088dc..d26eede 100644 --- a/cents/utils/utils.py +++ b/cents/utils/utils.py @@ -66,6 +66,45 @@ def get_normalizer_training_config(): ) return OmegaConf.load(config_path) +_context_config_path = None + + +def set_context_config_path(path: str): + """ + Set a custom path for the context configuration file. 
+ This path will be used by get_context_config() instead of the default. + + Args: + path: Path to the context config YAML file. If None, resets to default. + """ + global _context_config_path + _context_config_path = path + +def get_context_config(path: str = None): + """ + Load the context configuration from config/context/default.yaml or a custom path. + + Args: + path: Optional path to a custom context config file. If None, uses the path + set by set_context_config_path() or defaults to config/context/default.yaml. + + Returns: + OmegaConf config with static_context, normalizer, and dynamic_context sections. + """ + if path is not None: + config_path = path + elif _context_config_path is not None: + print(f"Using custom context config path: {_context_config_path}") + config_path = _context_config_path + else: + config_path = os.path.join( + ROOT_DIR, + "config", + "context", + "default.yaml", + ) + return OmegaConf.load(config_path) + def get_default_trainer_config(): config_path = os.path.join( diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index bdaefb2..bcfba38 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -10,9 +10,11 @@ from cents.datasets.commercial import CommercialDataset from cents.eval.eval import Evaluator from cents.utils.config_loader import load_yaml +from cents.utils.utils import set_context_config_path from pathlib import Path import torch import os +import argparse MODEL_KEY = "diffusion_ts" DATASET_OVERRIDES = [ @@ -27,7 +29,10 @@ HOME = Path.home() -def main() -> None: +def main(args) -> None: + # Set custom context config path if provided + if args.context_config_path: + set_context_config_path(args.context_config_path) model_ckpt = "cents/outputs/diffusion_ts_commercial_all/2025-11-13_19-50-40/commercial_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" logging.basicConfig( @@ -83,4 +88,8 @@ def main() -> None: if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + 
parser.add_argument("--context-config-path", type=str, default=None, + help="Path to custom context config YAML file (optional)") + args = parser.parse_args() + main(args) diff --git a/scripts/train.py b/scripts/train.py index 9634a6c..eb52646 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -3,7 +3,9 @@ from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset +from cents.datasets.airquality import AirQualityDataset from cents.trainer import Trainer +from cents.utils.utils import set_context_config_path from pytorch_lightning.callbacks import EarlyStopping import warnings import argparse @@ -14,12 +16,19 @@ def main(args) -> None: MODEL_NAME = args.model_name CR_LOSS_WEIGHT = args.cr_loss_weight TC_LOSS_WEIGHT = args.tc_loss_weight + + # Set custom context config path if provided + if args.context_config_path: + set_context_config_path(args.context_config_path) + # Skip heavy processing for DDP compatibility if args.dataset == "pecanstreet": dataset = PecanStreetDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}, time_series_dims=1, user_group=all"]) elif args.dataset == "commercial": dataset = CommercialDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) + elif args.dataset == "airquality": + dataset = AirQualityDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) else: raise ValueError(f"Dataset {args.dataset} not supported") @@ -69,6 +78,8 @@ def main(args) -> None: parser.add_argument("--skip_heavy_processing", type=bool, default=True) parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false") parser.add_argument("--enable_checkpointing", type=bool, default=True) + parser.add_argument("--context-config-path", type=str, default=None, + help="Path to custom context config YAML file (optional)") args = parser.parse_args() main(args) From e90410a855894084dc4ca47df69835f8608a3b43 Mon Sep 17 00:00:00 2001 
From: Pieter Feenstra Date: Tue, 20 Jan 2026 11:33:08 -0500 Subject: [PATCH 17/50] Dynamic Context Added; CNN Support, Airquality Dataset for Testing --- cents/config/dataset/airquality.yaml | 34 ++-- cents/config/dataset/commercial.yaml | 4 +- cents/config/dataset/pecanstreet.yaml | 4 +- cents/config/trainer/diffusion_ts.yaml | 2 +- cents/datasets/airquality.py | 77 +++++--- cents/datasets/timeseries_dataset.py | 52 ++--- cents/datasets/utils.py | 25 ++- cents/models/base.py | 1 - cents/models/context.py | 176 ++++++++++++++++- cents/models/context_registry.py | 18 +- cents/models/diffusion_ts.py | 3 +- cents/models/normalizer.py | 259 ++++++++++++++++++++----- cents/utils/config_loader.py | 2 + cents/utils/utils.py | 37 +++- scripts/train.py | 8 +- 15 files changed, 555 insertions(+), 147 deletions(-) diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml index 92839c5..af299be 100644 --- a/cents/config/dataset/airquality.yaml +++ b/cents/config/dataset/airquality.yaml @@ -5,28 +5,24 @@ scale: True use_learned_normalizer: True threshold: 8 seq_len: 24 +time_series_dims: 1 shuffle: True skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) max_samples: null # Limit dataset size (null = use all data) path: "./data/airquality" -target_time_series_columns: "PM2.5" -context_time_series_columns: ["TEMP", "DEWP", "PRES", "RAIN", "WSPM", "wd"] -data_columns: ["id", "PM2.5", "timestamp"] -metadata_columns: ["station"] +numeric_context_bins: 5 +time_series_columns: "PM2.5" +data_columns: ["No", "PM2.5", "year", "month", "day", "hour", "TEMP", "DEWP", "PRES", "RAIN", "WSPM", "wd", "station"] reduce_cardinality: False -stats_head_type: mlp -context_module_type: sep_mlp - -context_vars: # for each desired context variable, add the name and number of categories - year: 2 - month: 12 - weekday: 7 - -time_series_context_vars: # for each desired continuous context variable, add the name and number of bins -- TEMP -- DEWP 
-- PRES -- RAIN -- WSPM -- wd +context_vars: + year: ["categorical", 5] + month: ["categorical", 12] + weekday: ["categorical", 7] + TEMP: ["time_series", null] + DEWP: ["time_series", null] + PRES: ["time_series", null] + RAIN: ["time_series", null] + WSPM: ["time_series", null] + wd: ["time_series", 17] + station: ["categorical", 12] \ No newline at end of file diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 62fdb38..a60d82b 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -25,5 +25,5 @@ context_vars: weekday: ["categorical", 7] site_id: ["categorical", 19] primaryspaceusage: ["categorical", 16] - sqft: ["continuous", None] - yearbuilt: ["continuous", None] \ No newline at end of file + sqft: ["continuous", null] + yearbuilt: ["continuous", null] \ No newline at end of file diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 97ecd41..0e63035 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -24,5 +24,5 @@ context_vars: car1: ["categorical", 2] city: ["categorical", 7] state: ["categorical", 3] - total_square_footage: ["continuous", None] - house_construction_year: ["continuous", None] \ No newline at end of file + total_square_footage: ["continuous", null] + house_construction_year: ["continuous", null] \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 551f1ce..07b1b4e 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -5,7 +5,7 @@ strategy: ddp_find_unused_parameters_false gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 -max_epochs: 200 +max_epochs: 2000 base_lr: 1e-4 eval_after_training: False diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py index 57286f5..2553f1b 100644 --- a/cents/datasets/airquality.py +++ 
b/cents/datasets/airquality.py @@ -1,3 +1,4 @@ +from ast import Str import os import warnings from typing import Any, Dict, List, Optional @@ -33,25 +34,33 @@ def __init__(self, cfg: DictConfig = None, self.cfg = cfg self.name = cfg.name self.normalize = cfg.normalize - self.target_time_series_columns = cfg.target_time_series_columns - self.context_time_series_columns = cfg.context_time_series_columns + if isinstance(cfg.time_series_columns, str): + cfg.time_series_columns = [cfg.time_series_columns] + self.target_time_series_columns = cfg.time_series_columns self.geography = cfg.geography - self.time_series_dims = len(self.target_time_series_columns) + self.time_series_dims = cfg.time_series_dims self.city_names = ["Aotizhongxin", "Changping", "Dingling", "Dongsi", "Guanyuan", "Gucheng", "Huairou", "Nongzhanguan", "Shunyi", "Tiantan", "Wanliu", "Wanshouxigong"] - + self.context_time_series_columns = {k:v[1] for k,v in self.cfg.context_vars.items() if v[0] == "time_series"} + self.context_series_names = list(self.context_time_series_columns.keys()) + + self.categorical_time_series = { + k: v[1] for k, v in self.cfg.context_vars.items() + if v[0] == "time_series" and v[1] is not None + } self._load_data() super().__init__( data=self.data, - time_series_column_names=self.time_series_columns, + time_series_column_names=self.target_time_series_columns, context_var_column_names=list(self.cfg.context_vars.keys()), seq_len=self.cfg.seq_len, normalize=self.cfg.normalize, scale=self.cfg.scale, skip_heavy_processing=cfg.get('skip_heavy_processing', False), - size=cfg.get('max_samples', None) + size=cfg.get('max_samples', None), + categorical_time_series=self.categorical_time_series, ) def _load_data(self): @@ -64,11 +73,6 @@ def _load_data(self): module_dir = os.path.dirname(os.path.abspath(__file__)) path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) - meta_path = os.path.join(path, "metadata.csv") - if not os.path.exists(meta_path): - raise 
FileNotFoundError(f"Metadata file not found at {meta_path}") - - if not self.geography: self.geography = self.city_names self.geography = [self.geography] if isinstance(self.geography, str) else self.geography @@ -80,51 +84,74 @@ def _load_data(self): if not os.path.exists(data_path): raise FileNotFoundError(f"Data file not found at {data_path}") - dfs.append(pd.read_csv(data_path)) + dfs.append(pd.read_csv(data_path)[self.cfg.data_columns]) - self.data = pd.concat(dfs, axis=0)[self.cfg.data_columns] - + self.data = pd.concat(dfs, axis=0) def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: data = data.copy() data['timestamp'] = pd.to_datetime(data[["year", "month", "day", "hour"]]) data['weekday'] = data['timestamp'].dt.day_name() - ts_cols = self.context_time_series_columns + self.target_time_series_columns + ts_cols = self.context_series_names + self.target_time_series_columns + + data = data.sort_values(['station', 'year', 'month', 'day', 'hour']) - data = data.sort_values(['location', 'year', 'month', 'day', 'hour']) + # Map month integer to month name string as quickly as possible + months = ["January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December"] + data['month'] = data['month'].map(lambda x: months[x-1]) + group_keys = ["station", "year", "month", "day", "weekday"] grouped = ( - data.groupby(["station", "year", "month", "day"], as_index=False) - .agg({**{c: list for c in ts_cols}, - "weekday": 'first'}) + data.groupby(group_keys, as_index=False, sort=False) + .agg({c: list for c in ts_cols}) ) + # Convert lists -> numpy arrays (fast + deterministic) + for c in ts_cols: + grouped[c] = grouped[c].map(np.asarray) + grouped = grouped[grouped["PM2.5"].apply(len) == self.cfg.seq_len].reset_index( drop=True ) grouped = self._handle_missing_data(grouped) + + # Convert all lists in time series columns into tuples to make them hashable + for c in ts_cols: + grouped[c] = 
grouped[c].map(tuple) return grouped - def _hande_missing_data(self, data): - mask = data[self.context_time_series_columns].applymap(is_all_nan).any(axis=1) + def _handle_missing_data(self, data): + # Only handle missing data for numeric time series + numeric_series = [c for c in self.context_series_names if c not in self.categorical_time_series] + + mask = data[numeric_series].applymap(is_all_nan).any(axis=1) if numeric_series else pd.Series([False] * len(data)) data = data[~mask] - for col in self.context_time_series_columns: + for col in numeric_series: data[col] = data[col].apply(fill_with_row_mean) + data[list(self.categorical_time_series.keys())] + + mask = data[list(self.categorical_time_series.keys())].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + data = data.loc[data["PM2.5"].apply(lambda x: not np.isnan(x).any())] + return data -def is_all_nan(lst): - arr = np.array(lst, dtype=float) - return np.isnan(arr).all() +def is_all_nan(arr): + return pd.isna(arr).all() +def is_any_nan(arr): + return pd.isna(arr).any() def fill_with_row_mean(lst): s = pd.Series(lst, dtype=float) diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index f01f0a8..e23005e 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -67,7 +67,7 @@ def __init__( if isinstance(time_series_column_names, list) else [time_series_column_names] ) - self.time_series_dims = len(self.time_series_column_names) + self.time_series_dims = self.cfg.time_series_dims self.context_vars = context_var_column_names or [] self.seq_len = seq_len @@ -93,21 +93,7 @@ def __init__( self.name = "custom" # Add continuous variables to context_vars if specified - continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] - # Convert to plain Python list if it's a ListConfig from OmegaConf - if continuous_vars: - if isinstance(continuous_vars, ListConfig): - continuous_vars = [str(v) for v in continuous_vars] - elif 
isinstance(continuous_vars, list): - continuous_vars = [str(v) for v in continuous_vars] - else: - continuous_vars = [str(continuous_vars)] - else: - continuous_vars = [] - - # Ensure continuous variables are included in self.context_vars - if continuous_vars: - self.context_vars = list(self.context_vars) + [v for v in continuous_vars if v not in self.context_vars] + self.continuous_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "continuous"] self.normalize = normalize self.scale = scale @@ -121,8 +107,7 @@ def __init__( # Preprocess and optionally encode context self.data = self._preprocess_data(data) - continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] - if continuous_vars: + if self.continuous_vars: self._normalize_continuous_vars() if size is not None: @@ -132,11 +117,11 @@ def __init__( self.data, self.context_var_codes = self._encode_context_vars(self.data) self._save_context_var_codes() + is_ddp_subprocess = self._is_ddp_subprocess() if self.normalize: self._init_normalizer() cache_path = self._get_normalization_cache_path() - if cache_path.exists(): print(f"[{'DDP Subprocess' if is_ddp_subprocess else 'Main Process'}] Loading pre-normalized data from cache") with open(cache_path, 'rb') as f: @@ -146,7 +131,6 @@ def __init__( if not is_ddp_subprocess: print("[Main Process] Normalizing data...") self.data = self._normalizer.transform(self.data) - # Save to cache for subprocesses (only main process) if not is_ddp_subprocess: cache_path.parent.mkdir(parents=True, exist_ok=True) @@ -155,6 +139,7 @@ def __init__( print(f"[Main Process] Cached normalized data for subprocesses") self.data = self.merge_timeseries_columns(self.data) self.data = self.data.reset_index() + # Check if we should skip heavy processing for DDP if is_ddp_subprocess and skip_heavy_processing: @@ -265,7 +250,7 @@ def get_train_dataloader( # if col in self.data.columns: # print(self.data[col].mean()) return DataLoader( - self, batch_size=batch_size, 
shuffle=shuffle, num_workers=num_workers, persistent_workers=persistent_workers + self, batch_size=batch_size, shuffle=shuffle, num_workers=8, persistent_workers=persistent_workers ) def split_timeseries(self, df: pd.DataFrame) -> pd.DataFrame: @@ -364,13 +349,17 @@ def _encode_context_vars( Returns: Tuple of encoded DataFrame and mapping codes. """ - continuous_vars = getattr(self.cfg, 'continuous_context_vars', None) + continuous_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "continuous"] + time_series_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "time_series"] + numeric_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "categorical" and v[1] == None] encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, bins=self.numeric_context_bins, - numeric_cols=getattr(self.cfg, 'numeric_cols', None), + numeric_cols=numeric_cols, continuous_vars=continuous_vars, + time_series_cols=time_series_cols, + categorical_time_series=self.categorical_time_series, ) return encoded_data, mapping @@ -380,14 +369,13 @@ def _normalize_continuous_vars(self): Normalize continuous context variables in the dataset using z-score normalization. This is done once during dataset initialization, so models receive pre-normalized values. 
""" - continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] - if not continuous_vars: + if not self.continuous_vars: return # Store stats for potential inverse transform if needed self.continuous_var_stats = {} - for var_name in continuous_vars: + for var_name in self.continuous_vars: if var_name in self.data.columns: values = self.data[var_name] # Compute mean and std @@ -650,15 +638,17 @@ def _init_normalizer(self) -> None: # Get context_module_type and stats_head_type from context config context_cfg = get_context_config() - context_module_type = context_cfg.dynamic_context.type - stats_head_type = context_cfg.normalizer.stats_head_type - + + self.dynamic_module_type = context_cfg.dynamic_context.type + self.static_module_type = context_cfg.static_context.type + self.stats_head_type = context_cfg.normalizer.stats_head_type cache_path = normalizer_dir / _ckpt_name( self.name, "normalizer", self.time_series_dims, - context_module_type=context_module_type, - stats_head_type=stats_head_type + static_module_type=self.static_module_type, + stats_head_type=self.stats_head_type, + dynamic_module_type=self.dynamic_module_type, ) ncfg = get_normalizer_training_config() diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 244a52e..3b12810 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -108,7 +108,10 @@ def split_dataset(dataset: Dataset, val_split: float = 0.1) -> Tuple[Dataset, Da def encode_context_variables( - data: pd.DataFrame, columns_to_encode: List[str], bins: int, numeric_cols: List[str] = None, continuous_vars: List[str] = None + data: pd.DataFrame, columns_to_encode: List[str], bins: int, + numeric_cols: List[str] = None, continuous_vars: List[str] = None, + time_series_cols: List[str] = None, + categorical_time_series: Dict[str, int] = None, ) -> Tuple[pd.DataFrame, Dict[str, Dict[int, Any]]]: """ Encodes specified columns in the DataFrame either by binning numeric columns @@ -159,10 +162,11 @@ def 
encode_context_variables( for col in columns_to_encode: # Skip continuous variables - they should remain as float values - if col in continuous_vars: + if col in categorical_time_series: + encoded_data[col], mapping[col] = encode_list_column(encoded_data[col]) + elif col in time_series_cols or col in continuous_vars: continue - - if numeric_cols and col in numeric_cols: + elif numeric_cols and col in numeric_cols: # Numeric column: Perform binning # Handle NaN values by filling with median before binning if encoded_data[col].isna().all(): @@ -251,3 +255,16 @@ def convert_generated_data_to_df( records.append(record) return pd.DataFrame.from_records(records) + +def encode_list_column(series: pd.Series): + count = 0 + for x in series: + for t in x: + if pd.isna(t): + count += 1 + vocab = sorted({t for x in series for t in x}) + tok2id = {t: i for i, t in enumerate(vocab)} + encoded = series.apply(lambda x: [tok2id[t] for t in x]) + encoded = encoded.map(tuple) + mapping = dict(enumerate(vocab)) # id -> token + return encoded, mapping \ No newline at end of file diff --git a/cents/models/base.py b/cents/models/base.py index f140741..73b1f83 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -43,7 +43,6 @@ def __init__(self, cfg: DictConfig = None): # Get context module type from context config context_cfg = get_context_config() context_module_type = context_cfg.static_context.type - # Get continuous variables from config if specified # continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) # Use registry to get the context module class diff --git a/cents/models/context.py b/cents/models/context.py index 7189f31..1b17b2b 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -111,9 +111,7 @@ def __init__( self.embedding_dim = embedding_dim self.continuous_vars = [k for k, v in context_vars.items() if v[0] == "continuous"] - self.categorical_vars = {k: v[1] for k, v in context_vars.items() if k not in self.continuous_vars} - 
print(self.continuous_vars, "CONT VARS") - print(self.categorical_vars, "CAT VARS") + self.categorical_vars = {k: v[1] for k, v in context_vars.items() if v[0] == "categorical"} self.context_embeddings = nn.ModuleDict( { name: nn.Embedding(num_categories, embedding_dim) @@ -258,4 +256,174 @@ def forward(self, context_vars): # The training step will need to distinguish between them all_outputs = {**classification_logits, **regression_outputs} - return embedding, all_outputs \ No newline at end of file + return embedding, all_outputs + + +@register_context_module("dynamic_cnn") +class DynamicContextModule(BaseContextModule): + """ + Context module for processing dynamic (time series) context variables. + Uses 1D convolutions to encode time series sequences into embeddings. + """ + + def __init__( + self, + context_vars: dict[str, int], + embedding_dim: int, + seq_len: int = None, + ): + """ + Initialize DynamicContextModule. + + Args: + context_vars: Mapping of variable names to category counts (for categorical time series) + or None (for numeric time series). Format: {name: [type, num_categories]} + embedding_dim: Size of embedding vectors. + seq_len: Sequence length of time series context variables. 
+ """ + super().__init__() + self.embedding_dim = embedding_dim + + # Separate categorical and numeric time series + self.categorical_ts_vars = { + k: v[1] for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + self.numeric_ts_vars = [ + k for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is None + ] + + # For categorical time series, use embedding + CNN + self.ts_embeddings = nn.ModuleDict({ + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in self.categorical_ts_vars.items() + }) + + # CNN encoders for each time series variable + # For categorical: input is (batch, seq_len) -> embedding -> (batch, seq_len, emb_dim) -> CNN + # For numeric: input is (batch, seq_len) -> CNN + self.ts_encoders = nn.ModuleDict() + + for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars: + # 1D CNN to encode time series: (batch, channels, seq_len) -> (batch, embedding_dim) + encoder = nn.Sequential( + nn.Conv1d(embedding_dim if name in self.categorical_ts_vars else 1, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(64, 128, kernel_size=3, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool1d(1), # Global average pooling + nn.Flatten(), + nn.Linear(128, embedding_dim), + ) + self.ts_encoders[name] = encoder + + # Mixing MLP to combine all time series embeddings + total_dim = embedding_dim * (len(self.categorical_ts_vars) + len(self.numeric_ts_vars)) + if total_dim > 0: + self.mixing_mlp = nn.Sequential( + nn.Linear(total_dim, 128), + nn.ReLU(), + nn.Linear(128, embedding_dim), + ) + else: + self.mixing_mlp = nn.Identity() + + # Initialize weights with Kaiming initialization + self._initialize_weights() + + def _initialize_weights(self): + """ + Initialize weights using Kaiming (He) initialization for better training with ReLU activations. + This is particularly important for the CNN layers and Linear layers. 
+ """ + for module in self.modules(): + if isinstance(module, nn.Conv1d): + # Kaiming initialization for Conv1d layers (already default for ReLU, but make explicit) + nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Linear): + # Kaiming initialization for Linear layers (better than default Xavier for ReLU) + nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') + if module.bias is not None: + nn.init.constant_(module.bias, 0) + # Note: Embedding layers keep their default initialization (normal with std=1.0) + # which is appropriate for embeddings + + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Process dynamic (time series) context variables. + + Args: + context_vars: Dict mapping variable names to tensors. + For categorical TS: (batch, seq_len) with integer values + For numeric TS: (batch, seq_len) with float values + + Returns: + embedding: Combined embedding of shape (batch_size, embedding_dim) + outputs: Empty dict for compatibility + """ + embeddings = [] + + # Process categorical time series + for name in self.categorical_ts_vars.keys(): + if name in context_vars: + # Input: (batch, seq_len) with integer indices + ts_data = context_vars[name] # (batch, seq_len) + # Check for NaN/Inf in input + if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): + raise ValueError(f"NaN/Inf detected in categorical time series input '{name}'") + # Embed: (batch, seq_len) -> (batch, seq_len, embedding_dim) + embedded = self.ts_embeddings[name](ts_data) + # Transpose for CNN: (batch, embedding_dim, seq_len) + embedded = embedded.transpose(1, 2) + # Check for NaN after embedding + if torch.isnan(embedded).any() or torch.isinf(embedded).any(): + raise ValueError(f"NaN/Inf detected after embedding for '{name}'") + # Encode: (batch, embedding_dim, seq_len) -> (batch, 
embedding_dim) + encoded = self.ts_encoders[name](embedded) + # Check for NaN after encoding + if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + raise ValueError(f"NaN/Inf detected after encoding for '{name}'") + embeddings.append(encoded) + + # Process numeric time series + for name in self.numeric_ts_vars: + if name in context_vars: + # Input: (batch, seq_len) with float values + ts_data = context_vars[name] # (batch, seq_len) + # Ensure numeric time series are float type (not long/int) + if not ts_data.is_floating_point(): + ts_data = ts_data.float() + # Check for NaN/Inf in input + if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): + raise ValueError(f"NaN/Inf detected in numeric time series input '{name}'") + # Replace NaN/Inf with zeros to prevent propagation + ts_data = torch.where(torch.isfinite(ts_data), ts_data, torch.zeros_like(ts_data)) + # Add channel dimension: (batch, 1, seq_len) + ts_data = ts_data.unsqueeze(1) + # Encode: (batch, 1, seq_len) -> (batch, embedding_dim) + encoded = self.ts_encoders[name](ts_data) + # Check for NaN after encoding + if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + raise ValueError(f"NaN/Inf detected after encoding numeric TS '{name}'") + embeddings.append(encoded) + + if not embeddings: + # No dynamic context variables, return zero embedding + batch_size = next(iter(context_vars.values())).size(0) if context_vars else 1 + embedding = torch.zeros(batch_size, self.embedding_dim, device=next(iter(context_vars.values())).device if context_vars else None) + return embedding, {} + + # Combine all time series embeddings + combined = torch.cat(embeddings, dim=1) # (batch, total_dim) + # Check for NaN before mixing + if torch.isnan(combined).any() or torch.isinf(combined).any(): + raise ValueError(f"NaN/Inf detected in combined embeddings before mixing MLP") + embedding = self.mixing_mlp(combined) # (batch, embedding_dim) + # Check for NaN after mixing + if torch.isnan(embedding).any() or 
torch.isinf(embedding).any(): + raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP") + + return embedding, {} \ No newline at end of file diff --git a/cents/models/context_registry.py b/cents/models/context_registry.py index 8222e92..f497341 100644 --- a/cents/models/context_registry.py +++ b/cents/models/context_registry.py @@ -24,12 +24,13 @@ def decorator(cls): return decorator -def get_context_module_cls(key: str) -> type: +def get_context_module_cls(key: str, subkey: str = None) -> type: """ - Fetch the context module class for `key`. Raises if not found. + Fetch the context module class for `key` (and optionally `subkey`). Raises if not found. Args: - key: The name of the context module to retrieve. + key: The name of the context module to retrieve (e.g., "default", "dynamic"). + subkey: Optional subkey for two-part registration (e.g., "mlp", "cnn"). Returns: The context module class. @@ -37,11 +38,20 @@ def get_context_module_cls(key: str) -> type: Raises: ValueError: If the key is not found in the registry. """ + # Try two-part key first if subkey is provided + if subkey is not None: + two_part_key = f"{key}_{subkey}" + if two_part_key in _CONTEXT_MODULE_REGISTRY: + return _CONTEXT_MODULE_REGISTRY[two_part_key] + + # Try single key try: return _CONTEXT_MODULE_REGISTRY[key] except KeyError: + available = list(_CONTEXT_MODULE_REGISTRY.keys()) raise ValueError( - f"Unknown context module '{key}'. Available: {list(_CONTEXT_MODULE_REGISTRY.keys())}" + f"Unknown context module '{key}'" + (f" with subkey '{subkey}'" if subkey else "") + + f". 
Available: {available}" ) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 22b40a6..bbf3134 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -140,6 +140,7 @@ def __init__(self, cfg: DictConfig): # Get continuous variables from config to distinguish them in loss computation self.continuous_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] + self.categorical_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor @@ -303,7 +304,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: if var_name in self.continuous_context_vars: loss = F.mse_loss(outputs, labels.float()) - else: + elif var_name in self.categorical_context_vars: loss = self.auxiliary_loss(outputs, labels) cond_loss += loss.mean() diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index c54be5b..78fb1d8 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -12,7 +12,7 @@ from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel -from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule, DynamicContextModule # Import to trigger registration from cents.models.context_registry import get_context_module_cls from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model @@ -136,7 +136,8 @@ class _NormalizerModule(nn.Module): def __init__( self, - cond_module: nn.Module, + static_cond_module: nn.Module = None, + dynamic_cond_module: nn.Module = None, hidden_dim: int = 512, time_series_dims: int = 2, do_scale: bool = True, @@ -144,15 +145,36 @@ def __init__( ): """ Args: - cond_module: ContextModule instance for 
embedding context variables. + static_cond_module: ContextModule instance for static context variables (categorical + continuous). + dynamic_cond_module: ContextModule instance for dynamic context variables (time_series). hidden_dim: Hidden dimension size for the stats head. time_series_dims: Number of time series dimensions. do_scale: Whether to include scaling predictions. stats_head_type: Type of stats head to use (from registry). """ super().__init__() - self.cond_module = cond_module - self.embedding_dim = cond_module.embedding_dim + self.static_cond_module = static_cond_module + self.dynamic_cond_module = dynamic_cond_module + + # Determine embedding dimension from available modules + if static_cond_module is not None: + self.embedding_dim = static_cond_module.embedding_dim + elif dynamic_cond_module is not None: + self.embedding_dim = dynamic_cond_module.embedding_dim + else: + raise ValueError("At least one of static_cond_module or dynamic_cond_module must be provided") + + # If both modules exist, combine their embeddings + if static_cond_module is not None and dynamic_cond_module is not None: + # Combine embeddings from both modules + combined_dim = static_cond_module.embedding_dim + dynamic_cond_module.embedding_dim + self.combine_mlp = nn.Sequential( + nn.Linear(combined_dim, self.embedding_dim), + nn.ReLU(), + ) + else: + self.combine_mlp = None + # Use registry to get the stats head class StatsHeadCls = get_stats_head_cls(stats_head_type) self.stats_head = StatsHeadCls( @@ -162,25 +184,68 @@ def __init__( do_scale=do_scale, ) - def forward(self, cat_vars_dict: dict): + def forward(self, context_vars_dict: dict): """ Compute normalization parameters from categorical context. Args: - cat_vars_dict: Mapping of context variable names to label tensors. + context_vars_dict: Mapping of context variable names to label tensors. 
+ Static vars: single values (categorical: long, continuous: float) + Dynamic vars: time series sequences (batch, seq_len) Returns: Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped). """ - # Ensure all tensors in the dict are on the same device and properly connected - # This helps with DataLoader multiprocessing issues - device = next(self.cond_module.parameters()).device - cat_vars_dict = { - k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in cat_vars_dict.items() - } + embeddings = [] + + # Process static context variables + if self.static_cond_module is not None: + # Filter static context variables + static_vars = { + k: v for k, v in context_vars_dict.items() + if k not in getattr(self, '_dynamic_var_names', []) + } + if static_vars: + device = next(self.static_cond_module.parameters()).device + static_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in static_vars.items() + } + static_embedding, _ = self.static_cond_module(static_vars) + embeddings.append(static_embedding) + + # Process dynamic context variables + if self.dynamic_cond_module is not None: + # Filter dynamic context variables + dynamic_var_names = getattr(self, '_dynamic_var_names', []) + dynamic_vars = { + k: v for k, v in context_vars_dict.items() + if k in dynamic_var_names + } + if dynamic_vars: + device = next(self.dynamic_cond_module.parameters()).device + dynamic_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in dynamic_vars.items() + } + dynamic_embedding, _ = self.dynamic_cond_module(dynamic_vars) + # Check for NaN in dynamic embedding + if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): + raise ValueError( + f"NaN/Inf detected in dynamic embedding. 
" + f"Dynamic vars: {list(dynamic_vars.keys())}" + ) + embeddings.append(dynamic_embedding) + + # Combine embeddings if both exist + if len(embeddings) == 2: + combined = torch.cat(embeddings, dim=1) + embedding = self.combine_mlp(combined) + elif len(embeddings) == 1: + embedding = embeddings[0] + else: + raise ValueError("No context variables provided") - embedding, _ = self.cond_module(cat_vars_dict) return self.stats_head(embedding) @@ -214,35 +279,69 @@ def __init__( # Get continuous variables from config if specified continuous_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "continuous"] categorical_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "categorical"] - self.context_vars = categorical_vars + continuous_vars + dynamic_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "time_series"] + + self.static_context_vars = categorical_vars + continuous_vars + self.dynamic_context_vars = dynamic_vars + self.context_vars = self.static_context_vars + self.dynamic_context_vars self.time_series_cols = dataset_cfg.time_series_columns[ : dataset_cfg.time_series_dims ] self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale + self.seq_len = dataset_cfg.seq_len # Get context config - context_cfg = get_context_config() - context_module_type = context_cfg.dynamic_context.type - stats_head_type = context_cfg.normalizer.stats_head_type + # context_cfg = get_context_config() + + self.static_module_type = self.dataset.static_module_type + self.dynamic_module_type = self.dataset.dynamic_module_type + self.stats_head_type = self.dataset.stats_head_type - # Use registry to get the context module class - ContextModuleCls = get_context_module_cls(context_module_type) - # Create context module - it will be stored in normalizer_model.cond_module - context_module = ContextModuleCls( - self.dataset_cfg.context_vars, - 256, - ) + # Create static context module (for categorical + 
continuous) + static_context_module = None + if self.static_context_vars: + StaticContextModuleCls = get_context_module_cls(self.static_module_type) + # Filter context_vars to only static ones + static_context_vars_dict = { + k: v for k, v in self.dataset_cfg.context_vars.items() + if k in self.static_context_vars + } + static_context_module = StaticContextModuleCls( + static_context_vars_dict, + 256, + ) + + # Create dynamic context module (for time_series) + dynamic_context_module = None + if self.dynamic_context_vars and self.dynamic_module_type is not None: + DynamicContextModuleCls = get_context_module_cls("dynamic", self.dynamic_module_type) + # Filter context_vars to only dynamic ones + dynamic_context_vars_dict = { + k: v for k, v in self.dataset_cfg.context_vars.items() + if k in self.dynamic_context_vars + } + dynamic_context_module = DynamicContextModuleCls( + dynamic_context_vars_dict, + 256, + seq_len=self.seq_len, + ) self.normalizer_model = _NormalizerModule( - cond_module=context_module, + static_cond_module=static_context_module, + dynamic_cond_module=dynamic_context_module, hidden_dim=512, time_series_dims=self.time_series_dims, do_scale=self.do_scale, - stats_head_type=stats_head_type, + stats_head_type=self.stats_head_type, ) - self.context_module = self.normalizer_model.cond_module + # Store dynamic var names for filtering in forward + self.normalizer_model._dynamic_var_names = self.dynamic_context_vars + # For backward compatibility, expose the static context module + self.context_module = self.normalizer_model.static_cond_module + # Expose the dynamic context module at top level so it shows in model summary + self.dynamic_cond_module = self.normalizer_model.dynamic_cond_module # Will be populated in setup() self.group_stats = {} @@ -332,8 +431,8 @@ def training_step(self, batch, batch_idx: int): Returns: loss tensor. 
""" - cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped = self(cat_vars_dict) + context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch + pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped = self(context_vars_dict) # Use standard MSE loss for mu loss_mu = F.mse_loss(pred_mu, mu_t) @@ -442,8 +541,8 @@ def train_dataloader(self): ds, batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, - num_workers=8, # Reduce workers to avoid synchronization issues - persistent_workers=True, + num_workers=4, # Use fewer workers to reduce overhead + persistent_workers=False, # Disable to avoid multiprocessing cleanup issues pin_memory=torch.cuda.is_available(), # Helps with GPU transfer prefetch_factor=2, # Reduce prefetch to avoid memory issues ) @@ -453,12 +552,29 @@ def _compute_group_stats(self) -> dict: Compute per-group (context combination) statistics from raw data. Returns: - Mapping from context tuple to (mu_array, std_array, zmin_array, zmax_array). + Mapping from context tuple to (mu_array, std_array, zmin_array, zmax_array, dynamic_ctx_dict). 
""" df = self.dataset.data.copy() grouped_stats = {} - for group_vals, group_df in df.groupby(self.context_vars): + for group_vals, group_df in df.groupby(self.static_context_vars): dimension_points = [[] for _ in range(self.time_series_dims)] + # Store dynamic context variables (time series) for this group + # We'll use the first row's dynamic context as representative + dynamic_ctx_dict = {} + if self.dynamic_context_vars: + first_row = group_df.iloc[0] + for var_name in self.dynamic_context_vars: + if var_name in first_row: + # Get the time series sequence + ts_data = first_row[var_name] + if isinstance(ts_data, np.ndarray): + dynamic_ctx_dict[var_name] = ts_data + elif isinstance(ts_data, list): + dynamic_ctx_dict[var_name] = np.array(ts_data) + else: + # If it's a scalar, repeat it to match seq_len + dynamic_ctx_dict[var_name] = np.full(self.seq_len, ts_data) + for _, row in group_df.iterrows(): for d, col_name in enumerate(self.time_series_cols): arr = np.array(row[col_name], dtype=np.float32).flatten() @@ -494,6 +610,7 @@ def _compute_group_stats(self) -> dict: std_array, z_min_array, z_max_array, + dynamic_ctx_dict, ) return grouped_stats @@ -502,15 +619,16 @@ def _create_training_dataset(self) -> Dataset: Build an internal Dataset yielding true stats for each context group. Returns: - PyTorch Dataset of samples (cat_vars_dict, mu, sigma, zmin, zmax). + PyTorch Dataset of samples (context_vars_dict, mu, sigma, zmin, zmax). """ data_tuples = [ - (ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr) + (ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr, dynamic_ctx_dict) for ctx_tuple, ( mu_arr, sigma_arr, zmin_arr, zmax_arr, + dynamic_ctx_dict, ) in self.group_stats.items() ] @@ -521,11 +639,13 @@ class _TrainSet(Dataset): Adapter Dataset to wrap group_stats tuples for DataLoader. 
""" - def __init__(self, samples, context_vars, do_scale, continuous_vars): + def __init__(self, samples, static_context_vars, dynamic_context_vars, do_scale, continuous_vars, dataset_cfg): self.samples = samples - self.context_vars = context_vars + self.static_context_vars = static_context_vars + self.dynamic_context_vars = dynamic_context_vars self.do_scale = do_scale self.continuous_vars = continuous_vars + self.dataset_cfg = dataset_cfg def __len__(self) -> int: return len(self.samples) @@ -538,27 +658,52 @@ def __getitem__(self, idx: int): idx: Index of the sample. Returns: - cat_vars_dict: Tensor dict of context labels. + context_vars_dict: Tensor dict of context labels (static + dynamic). mu_t: True mean tensor. sigma_t: True std tensor. zmin_t: True min z-score tensor or None. zmax_t: True max z-score tensor or None. """ - ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] - cat_vars_dict = {} + ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr, dynamic_ctx_dict = self.samples[idx] + context_vars_dict = {} - for i, var_name in enumerate(self.context_vars): + # Process static context variables + for i, var_name in enumerate(self.static_context_vars): if var_name in self.continuous_vars: - cat_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.float32) + context_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.float32) else: - cat_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.long) + context_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.long) + + # Process dynamic context variables (time series) + for var_name in self.dynamic_context_vars: + if var_name in dynamic_ctx_dict: + ts_data = dynamic_ctx_dict[var_name] + # Convert to tensor + if isinstance(ts_data, np.ndarray): + # Check if it's categorical (integer) or numeric (float) + var_info = self.dataset_cfg.context_vars.get(var_name, None) + if var_info and var_info[1] is not None: + # Categorical time series + context_vars_dict[var_name] = 
torch.from_numpy(ts_data).long() + else: + # Numeric time series + context_vars_dict[var_name] = torch.from_numpy(ts_data).float() + else: + # Fallback: convert to array first + ts_array = np.array(ts_data) + var_info = self.dataset_cfg.context_vars.get(var_name, None) + if var_info and var_info[1] is not None: + context_vars_dict[var_name] = torch.from_numpy(ts_array).long() + else: + context_vars_dict[var_name] = torch.from_numpy(ts_array).float() + mu_t = torch.from_numpy(mu_arr).float() sigma_t = torch.from_numpy(sigma_arr).float() zmin_t = torch.from_numpy(zmin_arr).float() if self.do_scale else None zmax_t = torch.from_numpy(zmax_arr).float() if self.do_scale else None - return cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t + return context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t - return _TrainSet(data_tuples, self.context_vars, self.do_scale, continuous_vars) + return _TrainSet(data_tuples, self.static_context_vars, self.dynamic_context_vars, self.do_scale, continuous_vars, self.dataset_cfg) def transform(self, df: pd.DataFrame) -> pd.DataFrame: """ @@ -595,7 +740,16 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: for v in self.context_vars: if v in continuous_vars: ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + elif v in self.dynamic_context_vars: + # Dynamic (time series) variable + if v in categorical_ts: + # Categorical time series - keep as long + ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + else: + # Numeric time series - convert to float32 + ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) else: + # Static categorical variable ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) mu, sigma, zmin, zmax, _ = self(ctx) mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() @@ -642,12 +796,23 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + # Get 
categorical time series from dataset if available + categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): ctx = {} for v in self.context_vars: if v in continuous_vars: + # Static continuous variable ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + elif v in self.dynamic_context_vars: + # Dynamic (time series) variable + if v in categorical_ts: + # Categorical time series - keep as long + ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + else: + # Numeric time series - convert to float32 + ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) else: ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) mu, sigma, zmin, zmax, _ = self(ctx) diff --git a/cents/utils/config_loader.py b/cents/utils/config_loader.py index 7ebd1f1..7882107 100644 --- a/cents/utils/config_loader.py +++ b/cents/utils/config_loader.py @@ -32,6 +32,8 @@ def _coerce_scalar(value: str): v = value.strip() if v.lower() in ("true", "false"): return v.lower() == "true" + if v.lower() in ("null", "none"): + return None try: return int(v) except ValueError: diff --git a/cents/utils/utils.py b/cents/utils/utils.py index d26eede..fb108a5 100644 --- a/cents/utils/utils.py +++ b/cents/utils/utils.py @@ -13,8 +13,9 @@ def _ckpt_name( dims: int, *, ext: str = "ckpt", - context_module_type: str = None, - stats_head_type: str = None + static_module_type: str = None, + stats_head_type: str = None, + dynamic_module_type: str = None, ) -> str: """ Generate checkpoint filename with optional context_module_type and stats_head_type. 
@@ -32,12 +33,15 @@ def _ckpt_name( """ parts = [dataset, model, f"dim{dims}"] - if context_module_type: - parts.append(f"ctx{context_module_type}") + if static_module_type: + parts.append(f"ctx{static_module_type}") if stats_head_type: parts.append(f"stats{stats_head_type}") + if dynamic_module_type: + parts.append(f"dyn{dynamic_module_type}") + return "_".join(parts) + f".{ext}" @@ -67,6 +71,7 @@ def get_normalizer_training_config(): return OmegaConf.load(config_path) _context_config_path = None +_context_overrides = [] def set_context_config_path(path: str): @@ -80,9 +85,22 @@ def set_context_config_path(path: str): global _context_config_path _context_config_path = path + +def set_context_overrides(overrides: list): + """ + Set overrides to apply to the context configuration. + + Args: + overrides: List of override strings (e.g., ["static_context.type=mlp", "dynamic_context.type=cnn"]) + """ + global _context_overrides + _context_overrides = overrides if overrides else [] + + def get_context_config(path: str = None): """ Load the context configuration from config/context/default.yaml or a custom path. + Overrides can be applied if set via set_context_overrides(). Args: path: Optional path to a custom context config file. 
If None, uses the path @@ -103,7 +121,16 @@ def get_context_config(path: str = None): "context", "default.yaml", ) - return OmegaConf.load(config_path) + + cfg = OmegaConf.load(config_path) + + # Apply overrides if any + if _context_overrides: + from cents.utils.config_loader import apply_overrides + cfg = apply_overrides(cfg, _context_overrides) + print(f"Applied context config overrides: {_context_overrides}") + + return cfg def get_default_trainer_config(): diff --git a/scripts/train.py b/scripts/train.py index eb52646..c52e812 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,7 +5,7 @@ from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset from cents.trainer import Trainer -from cents.utils.utils import set_context_config_path +from cents.utils.utils import set_context_config_path, set_context_overrides from pytorch_lightning.callbacks import EarlyStopping import warnings import argparse @@ -21,6 +21,10 @@ def main(args) -> None: if args.context_config_path: set_context_config_path(args.context_config_path) + # Set context config overrides if provided + if args.context_overrides: + set_context_overrides(args.context_overrides) + # Skip heavy processing for DDP compatibility if args.dataset == "pecanstreet": @@ -80,6 +84,8 @@ def main(args) -> None: parser.add_argument("--enable_checkpointing", type=bool, default=True) parser.add_argument("--context-config-path", type=str, default=None, help="Path to custom context config YAML file (optional)") + parser.add_argument("--context-overrides", type=str, nargs="*", default=[], + help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn')") args = parser.parse_args() main(args) From 11e8f46c4c74378f53a19ad5b83e0f43500107d9 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 20 Jan 2026 12:21:37 -0500 Subject: [PATCH 18/50] Tranformer context, distributed normalizer training --- cents/config/context/default.yaml 
| 24 +++ cents/config/trainer/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 6 +- cents/datasets/timeseries_dataset.py | 1 + cents/models/context.py | 251 ++++++++++++++++++++++++- cents/models/normalizer.py | 11 +- 6 files changed, 286 insertions(+), 9 deletions(-) create mode 100644 cents/config/context/default.yaml diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml new file mode 100644 index 0000000..7835592 --- /dev/null +++ b/cents/config/context/default.yaml @@ -0,0 +1,24 @@ +# Context configuration +# This file defines the context modules used across the codebase + +# Static context: used by generative models (ACGAN, Diffusion_TS) for conditioning +static_context: + type: sep_mlp # Context module type (e.g., "mlp", "sep_mlp") + # Future parameters can be added here: + # n_layers: 2 + # hidden_dim: 256 + +# Normalizer: stats head configuration for the normalizer +normalizer: + stats_head_type: mlp # Stats head type (e.g., "mlp") + # Future parameters can be added here: + # n_layers: 3 + # hidden_dim: 512 + +# Dynamic context: context module used by the normalizer for time series context variables +dynamic_context: + type: cnn # Context module type for dynamic context (e.g., "cnn") + # Future parameters can be added here: + # n_layers: 2 + # hidden_dim: 256 + diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 07b1b4e..25196b8 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -5,7 +5,7 @@ strategy: ddp_find_unused_parameters_false gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 -max_epochs: 2000 +max_epochs: 5000 base_lr: 1e-4 eval_after_training: False diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 2250311..975649d 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,11 +1,11 @@ -strategy: auto +strategy: 
ddp_find_unused_parameters_true accelerator: gpu -devices: 1 +devices: auto precision: 16-mixed log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 1000 +n_epochs: 2000 batch_size: 8192 lr: 1e-5 gradient_clip_val: 1.0 diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index e23005e..34f00c2 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -675,6 +675,7 @@ def _init_normalizer(self) -> None: # train and cache a single state dict print("[Cents] Training normalizer…") + print(f"[Cents] devices: {ncfg.devices}") trainer = pl.Trainer( max_epochs=ncfg.n_epochs, accelerator=ncfg.accelerator, diff --git a/cents/models/context.py b/cents/models/context.py index 1b17b2b..94a75ac 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -119,6 +119,9 @@ def __init__( } ) + print(self.continuous_vars, "CONT VARS") + print(self.categorical_vars, "CAT VARS") + # For continuous variables, use a simple linear projection self.continuous_projections = nn.ModuleDict( { @@ -260,7 +263,7 @@ def forward(self, context_vars): @register_context_module("dynamic_cnn") -class DynamicContextModule(BaseContextModule): +class DynamicContextModule_CNN(BaseContextModule): """ Context module for processing dynamic (time series) context variables. Uses 1D convolutions to encode time series sequences into embeddings. @@ -426,4 +429,248 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, if torch.isnan(embedding).any() or torch.isinf(embedding).any(): raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP") - return embedding, {} \ No newline at end of file + return embedding, {} + + +@register_context_module("dynamic_transformer") +class DynamicContextModule_Transformer(BaseContextModule): + """ + Context module for processing dynamic (time series) context variables. + Uses Transformer encoder to encode time series sequences into embeddings. 
+ """ + + def __init__( + self, + context_vars: dict[str, int], + embedding_dim: int, + seq_len: int = None, + n_layers: int = 2, + n_heads: int = 4, + dropout: float = 0.1, + dim_feedforward: int = 256, + ): + """ + Initialize DynamicContextModule_Transformer. + + Args: + context_vars: Mapping of variable names to category counts (for categorical time series) + or None (for numeric time series). Format: {name: [type, num_categories]} + embedding_dim: Size of embedding vectors. + seq_len: Sequence length of time series context variables. + n_layers: Number of transformer encoder layers. + n_heads: Number of attention heads. + dropout: Dropout probability. + dim_feedforward: Dimension of feedforward network in transformer. + """ + super().__init__() + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + # Separate categorical and numeric time series + self.categorical_ts_vars = { + k: v[1] for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + self.numeric_ts_vars = [ + k for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is None + ] + + # For categorical time series, use embedding + self.ts_embeddings = nn.ModuleDict({ + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in self.categorical_ts_vars.items() + }) + + # For numeric time series, use linear projection to embedding_dim + self.ts_projections = nn.ModuleDict({ + name: nn.Linear(1, embedding_dim) + for name in self.numeric_ts_vars + }) + + # Positional encoding for transformer + if seq_len is not None: + self.pos_encodings = nn.ParameterDict({ + name: nn.Parameter(torch.zeros(1, seq_len, embedding_dim)) + for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + }) + else: + # If seq_len not provided, use learnable positional encoding that can adapt + self.pos_encodings = None + + # Transformer encoder for each time series variable + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + 
nhead=n_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation='gelu', + batch_first=True, + ) + + self.ts_encoders = nn.ModuleDict({ + name: nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + }) + + # Pooling layer to get fixed-size embedding from sequence + # Use learnable weighted pooling (attention pooling) + self.pooling_layers = nn.ModuleDict({ + name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.Tanh(), + nn.Linear(embedding_dim, 1, bias=False), + ) + for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + }) + + # Mixing MLP to combine all time series embeddings + total_dim = embedding_dim * (len(self.categorical_ts_vars) + len(self.numeric_ts_vars)) + if total_dim > 0: + self.mixing_mlp = nn.Sequential( + nn.Linear(total_dim, 128), + nn.ReLU(), + nn.Linear(128, embedding_dim), + ) + else: + self.mixing_mlp = nn.Identity() + + # Initialize weights + self._initialize_weights() + + def _initialize_weights(self): + """ + Initialize weights using appropriate initialization strategies. + """ + for module in self.modules(): + if isinstance(module, nn.Linear): + # Xavier initialization for transformer linear layers + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Parameter): + # Initialize positional encodings + if module.dim() == 3: # (1, seq_len, embedding_dim) + nn.init.normal_(module, std=0.02) + # Note: Embedding layers keep their default initialization + # Transformer encoder layers use their own initialization + + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Process dynamic (time series) context variables using Transformer. + + Args: + context_vars: Dict mapping variable names to tensors. 
+ For categorical TS: (batch, seq_len) with integer values + For numeric TS: (batch, seq_len) with float values + + Returns: + embedding: Combined embedding of shape (batch_size, embedding_dim) + outputs: Empty dict for compatibility + """ + embeddings = [] + + # Process categorical time series + for name in self.categorical_ts_vars.keys(): + if name in context_vars: + # Input: (batch, seq_len) with integer indices + ts_data = context_vars[name] # (batch, seq_len) + # Check for NaN/Inf in input + if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): + raise ValueError(f"NaN/Inf detected in categorical time series input '{name}'") + + # Embed: (batch, seq_len) -> (batch, seq_len, embedding_dim) + embedded = self.ts_embeddings[name](ts_data) + + # Add positional encoding if available + if self.pos_encodings is not None and name in self.pos_encodings: + seq_len_actual = embedded.size(1) + pos_enc = self.pos_encodings[name][:, :seq_len_actual, :] + embedded = embedded + pos_enc + + # Check for NaN after embedding + if torch.isnan(embedded).any() or torch.isinf(embedded).any(): + raise ValueError(f"NaN/Inf detected after embedding for '{name}'") + + # Encode with transformer: (batch, seq_len, embedding_dim) -> (batch, seq_len, embedding_dim) + encoded = self.ts_encoders[name](embedded) + + # Check for NaN after encoding + if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + raise ValueError(f"NaN/Inf detected after transformer encoding for '{name}'") + + # Pool to fixed size: (batch, seq_len, embedding_dim) -> (batch, embedding_dim) + # Use attention-based pooling + attention_weights = self.pooling_layers[name](encoded) # (batch, seq_len, 1) + attention_weights = torch.softmax(attention_weights, dim=1) + pooled = (encoded * attention_weights).sum(dim=1) # (batch, embedding_dim) + + embeddings.append(pooled) + + # Process numeric time series + for name in self.numeric_ts_vars: + if name in context_vars: + # Input: (batch, seq_len) with float values + 
ts_data = context_vars[name] # (batch, seq_len) + # Ensure numeric time series are float type (not long/int) + if not ts_data.is_floating_point(): + ts_data = ts_data.float() + + # Check for NaN/Inf in input + if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): + raise ValueError(f"NaN/Inf detected in numeric time series input '{name}'") + + # Replace NaN/Inf with zeros to prevent propagation + ts_data = torch.where(torch.isfinite(ts_data), ts_data, torch.zeros_like(ts_data)) + + # Project to embedding_dim: (batch, seq_len) -> (batch, seq_len, embedding_dim) + ts_data_expanded = ts_data.unsqueeze(-1) # (batch, seq_len, 1) + embedded = self.ts_projections[name](ts_data_expanded) # (batch, seq_len, embedding_dim) + + # Add positional encoding if available + if self.pos_encodings is not None and name in self.pos_encodings: + seq_len_actual = embedded.size(1) + pos_enc = self.pos_encodings[name][:, :seq_len_actual, :] + embedded = embedded + pos_enc + + # Check for NaN after projection + if torch.isnan(embedded).any() or torch.isinf(embedded).any(): + raise ValueError(f"NaN/Inf detected after projection for '{name}'") + + # Encode with transformer: (batch, seq_len, embedding_dim) -> (batch, seq_len, embedding_dim) + encoded = self.ts_encoders[name](embedded) + + # Check for NaN after encoding + if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + raise ValueError(f"NaN/Inf detected after transformer encoding numeric TS '{name}'") + + # Pool to fixed size: (batch, seq_len, embedding_dim) -> (batch, embedding_dim) + # Use attention-based pooling + attention_weights = self.pooling_layers[name](encoded) # (batch, seq_len, 1) + attention_weights = torch.softmax(attention_weights, dim=1) + pooled = (encoded * attention_weights).sum(dim=1) # (batch, embedding_dim) + + embeddings.append(pooled) + + if not embeddings: + # No dynamic context variables, return zero embedding + batch_size = next(iter(context_vars.values())).size(0) if context_vars else 1 + 
embedding = torch.zeros(batch_size, self.embedding_dim, device=next(iter(context_vars.values())).device if context_vars else None) + return embedding, {} + + # Combine all time series embeddings + combined = torch.cat(embeddings, dim=1) # (batch, total_dim) + # Check for NaN before mixing + if torch.isnan(combined).any() or torch.isinf(combined).any(): + raise ValueError(f"NaN/Inf detected in combined embeddings before mixing MLP") + embedding = self.mixing_mlp(combined) # (batch, embedding_dim) + # Check for NaN after mixing + if torch.isnan(embedding).any() or torch.isinf(embedding).any(): + raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP") + + return embedding, {} + + def on_after_backward(self): + unused = [n for n,p in self.named_parameters() if p.requires_grad and p.grad is None] + if unused: + print("UNUSED:", unused[:50]) \ No newline at end of file diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 78fb1d8..ec03ab9 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -12,7 +12,7 @@ from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel -from cents.models.context import MLPContextModule, SepMLPContextModule, DynamicContextModule # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule, DynamicContextModule_CNN, DynamicContextModule_Transformer # Import to trigger registration from cents.models.context_registry import get_context_module_cls from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model @@ -304,12 +304,12 @@ def __init__( if self.static_context_vars: StaticContextModuleCls = get_context_module_cls(self.static_module_type) # Filter context_vars to only static ones - static_context_vars_dict = { + self.static_context_vars_dict = { k: v for k, v in self.dataset_cfg.context_vars.items() if k in self.static_context_vars 
} static_context_module = StaticContextModuleCls( - static_context_vars_dict, + self.static_context_vars_dict, 256, ) @@ -516,6 +516,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx): param_count += 1 if param_norm.item() < 1e-8: zero_grad_count += 1 + print(name, "HAS ZERO GRAD") else: # Parameter has no gradient - this might indicate a problem if 'cond_module' in name or 'stats_head' in name: @@ -546,6 +547,10 @@ def train_dataloader(self): pin_memory=torch.cuda.is_available(), # Helps with GPU transfer prefetch_factor=2, # Reduce prefetch to avoid memory issues ) + # def on_after_backward(self): + # unused = [n for n,p in self.named_parameters() if p.requires_grad and p.grad is None] + # if unused: + # print("UNUSED:", unused[:50]) def _compute_group_stats(self) -> dict: """ From 83da97c527e28ac67dfa32248429021df7d6c378 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 20 Jan 2026 16:38:57 -0500 Subject: [PATCH 19/50] Added Gaussian NLL Objective for Normalizer Training --- cents/config/trainer/normalizer.yaml | 1 + cents/datasets/airquality.py | 4 +- cents/datasets/commercial.py | 6 ++- cents/datasets/pecanstreet.py | 4 +- cents/datasets/timeseries_dataset.py | 8 ++- cents/models/context.py | 5 -- cents/models/normalizer.py | 78 +++++++++++++++++++++++++--- scripts/train.py | 17 ++++-- 8 files changed, 102 insertions(+), 21 deletions(-) diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 975649d..a177088 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -11,6 +11,7 @@ lr: 1e-5 gradient_clip_val: 1.0 save_cycle: 5000 eval_after_training: False +loss_type: gaussian_nll # Options: "mse" or "gaussian_nll" checkpoint: save_last: False diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py index 2553f1b..06d8207 100644 --- a/cents/datasets/airquality.py +++ b/cents/datasets/airquality.py @@ -17,7 +17,8 @@ class 
AirQualityDataset(TimeSeriesDataset): def __init__(self, cfg: DictConfig = None, - overrides: Optional[List[str]] = None): + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False): """ Initializes the AirQuality Dataset. Available at: https://doi.org/10.24432/C5RK5G. @@ -61,6 +62,7 @@ def __init__(self, cfg: DictConfig = None, skip_heavy_processing=cfg.get('skip_heavy_processing', False), size=cfg.get('max_samples', None), categorical_time_series=self.categorical_time_series, + force_retrain_normalizer=force_retrain_normalizer, ) def _load_data(self): diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index 14a207b..c84c6c4 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -16,7 +16,8 @@ class CommercialDataset(TimeSeriesDataset): def __init__(self, cfg: DictConfig = None, - overrides: Optional[List[str]] = None): + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False): """ Initializes the commercial energy dataset. @@ -53,7 +54,8 @@ def __init__(self, cfg: DictConfig = None, normalize=self.cfg.normalize, scale=self.cfg.scale, skip_heavy_processing=cfg.get('skip_heavy_processing', False), - size=cfg.get('max_samples', None) + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, ) def _load_data(self): diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index 170cfef..2c551f8 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -33,6 +33,7 @@ def __init__( self, cfg: Optional[DictConfig] = None, overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, ): """ Initialize and preprocess the PecanStreet dataset. 
@@ -83,7 +84,8 @@ def __init__( normalize=self.cfg.normalize, scale=self.cfg.scale, skip_heavy_processing=cfg.get('skip_heavy_processing', False), - size=cfg.get('max_samples', None) + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, ) def _load_data(self) -> None: diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 34f00c2..f335c52 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -55,6 +55,7 @@ def __init__( skip_heavy_processing: bool = False, size: int = None, categorical_time_series: Dict[str, int] = None, + force_retrain_normalizer: bool = False, ): # Initialize basic attributes # Handle OmegaConf ListConfig objects @@ -97,6 +98,7 @@ def __init__( self.normalize = normalize self.scale = scale + self.force_retrain_normalizer = force_retrain_normalizer # Store categorical time series info self.categorical_time_series = categorical_time_series or {} @@ -658,8 +660,8 @@ def _init_normalizer(self) -> None: dataset=self, ) - # attempt to load existing state dict - if cache_path.exists(): + # attempt to load existing state dict (unless force_retrain_normalizer is True) + if cache_path.exists() and not self.force_retrain_normalizer: try: state = torch.load(cache_path, map_location="cpu") sd = state.get("state_dict", state) @@ -672,6 +674,8 @@ def _init_normalizer(self) -> None: cache_path.unlink() except OSError: pass + elif self.force_retrain_normalizer and cache_path.exists(): + print(f"[Cents] Force retrain enabled, ignoring cached normalizer at {cache_path}") # train and cache a single state dict print("[Cents] Training normalizer…") diff --git a/cents/models/context.py b/cents/models/context.py index 94a75ac..cdd7d14 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -119,9 +119,6 @@ def __init__( } ) - print(self.continuous_vars, "CONT VARS") - print(self.categorical_vars, "CAT VARS") - # For continuous variables, 
use a simple linear projection self.continuous_projections = nn.ModuleDict( { @@ -172,8 +169,6 @@ def __init__( ) def forward(self, context_vars): - #print(self.continuous_vars, "CONT VARS") - #print(context_vars, "VARS") encodings = {} # Process categorical variables (only those present in context_vars) diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index ec03ab9..ee8905c 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -299,6 +299,9 @@ def __init__( self.dynamic_module_type = self.dataset.dynamic_module_type self.stats_head_type = self.dataset.stats_head_type + # Get loss type from config (default to "mse") + self.loss_type = getattr(self.normalizer_training_cfg, "loss_type", "mse") + # Create static context module (for categorical + continuous) static_context_module = None if self.static_context_vars: @@ -420,6 +423,60 @@ def forward(self, cat_vars_dict: dict): """ return self.normalizer_model(cat_vars_dict) + def _compute_loss_mse(self, pred_mu, pred_sigma, pred_log_sigma_unclamped, mu_t, sigma_t): + """ + Compute MSE loss for mu and sigma. + + Args: + pred_mu: Predicted means + pred_sigma: Predicted standard deviations + pred_log_sigma_unclamped: Unclamped log sigma predictions + mu_t: Target means + sigma_t: Target standard deviations + + Returns: + loss_mu, loss_sigma + """ + # Use standard MSE loss for mu + loss_mu = F.mse_loss(pred_mu, mu_t) + + # Use log-space loss for sigma - this is more numerically stable + # and handles scale differences better + target_log_sigma = torch.log(sigma_t + 1e-8) # Add small epsilon to avoid log(0) + loss_sigma = F.mse_loss(pred_log_sigma_unclamped, target_log_sigma) + + return loss_mu, loss_sigma + + def _compute_loss_gaussian_nll(self, pred_mu, pred_sigma, mu_t, sigma_t): + """ + Compute Gaussian Negative Log Likelihood loss. + + Treats target mu_t as observations from N(pred_mu, pred_sigma^2). + For sigma, still uses log-space MSE since it's a scale parameter. 
+ + Args: + pred_mu: Predicted means + pred_sigma: Predicted standard deviations + mu_t: Target means (treated as observations) + sigma_t: Target standard deviations + + Returns: + loss_mu, loss_sigma + """ + # Use Gaussian NLL for mu: treat mu_t as observations from N(pred_mu, pred_sigma^2) + # GaussianNLLLoss expects: input (mean), target (observations), var (variance) + # We need to ensure variance is positive and not too small + pred_var = torch.clamp(pred_sigma ** 2, min=1e-6) + gaussian_nll = nn.GaussianNLLLoss(reduction='mean') + loss_mu = gaussian_nll(pred_mu, mu_t, pred_var) + + # For sigma, still use log-space MSE (sigma is a scale parameter, not a location) + pred_log_sigma = torch.log(pred_sigma + 1e-8) + target_log_sigma = torch.log(sigma_t + 1e-8) + loss_sigma = F.mse_loss(pred_log_sigma, target_log_sigma) + + return loss_mu, loss_sigma + def training_step(self, batch, batch_idx: int): """ Training step: regress predicted stats against true group stats. @@ -434,13 +491,20 @@ def training_step(self, batch, batch_idx: int): context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped = self(context_vars_dict) - # Use standard MSE loss for mu - loss_mu = F.mse_loss(pred_mu, mu_t) - - # Use log-space loss for sigma - this is more numerically stable - # and handles scale differences better - target_log_sigma = torch.log(sigma_t + 1e-8) # Add small epsilon to avoid log(0) - loss_sigma = F.mse_loss(pred_log_sigma_unclamped, target_log_sigma) + # Compute loss based on loss_type + if self.loss_type == "mse": + loss_mu, loss_sigma = self._compute_loss_mse( + pred_mu, pred_sigma, pred_log_sigma_unclamped, mu_t, sigma_t + ) + elif self.loss_type == "gaussian_nll": + loss_mu, loss_sigma = self._compute_loss_gaussian_nll( + pred_mu, pred_sigma, mu_t, sigma_t + ) + else: + raise ValueError( + f"Unknown loss_type: {self.loss_type}. 
" + f"Supported types: 'mse', 'gaussian_nll'" + ) total_loss = loss_mu + loss_sigma diff --git a/scripts/train.py b/scripts/train.py index c52e812..9e733b4 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -28,11 +28,20 @@ def main(args) -> None: # Skip heavy processing for DDP compatibility if args.dataset == "pecanstreet": - dataset = PecanStreetDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}, time_series_dims=1, user_group=all"]) + dataset = PecanStreetDataset( + overrides=[f"skip_heavy_processing={args.skip_heavy_processing}, time_series_dims=1, user_group=all"], + force_retrain_normalizer=args.force_retrain_normalizer + ) elif args.dataset == "commercial": - dataset = CommercialDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) + dataset = CommercialDataset( + overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], + force_retrain_normalizer=args.force_retrain_normalizer + ) elif args.dataset == "airquality": - dataset = AirQualityDataset(overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"]) + dataset = AirQualityDataset( + overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], + force_retrain_normalizer=args.force_retrain_normalizer + ) else: raise ValueError(f"Dataset {args.dataset} not supported") @@ -86,6 +95,8 @@ def main(args) -> None: help="Path to custom context config YAML file (optional)") parser.add_argument("--context-overrides", type=str, nargs="*", default=[], help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn')") + parser.add_argument("--force-retrain-normalizer", type=bool, default=False, + help="Force retraining of normalizer even if cached version exists") args = parser.parse_args() main(args) From d8c8877d83caaa3c5033142dc595a28835cf4296 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 21 Jan 2026 15:02:15 -0500 Subject: [PATCH 20/50] Resume from checkpoint --- cents/config/trainer/normalizer.yaml | 
2 +- cents/trainer.py | 21 +++++++++++++++------ scripts/train.py | 4 +++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index a177088..a42ac0e 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -11,7 +11,7 @@ lr: 1e-5 gradient_clip_val: 1.0 save_cycle: 5000 eval_after_training: False -loss_type: gaussian_nll # Options: "mse" or "gaussian_nll" +loss_type: mse # Options: "mse" or "gaussian_nll" checkpoint: save_last: False diff --git a/cents/trainer.py b/cents/trainer.py index 75764d8..6e827ee 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -70,15 +70,19 @@ def __init__( self.model = self._instantiate_model() self.pl_trainer = self._instantiate_trainer() - def fit(self) -> "Trainer": + def fit(self, ckpt_path: Optional[str] = None) -> "Trainer": """ Start training. + Args: + ckpt_path: Optional path to checkpoint file (.ckpt) to resume training from. + If provided, training will resume from this checkpoint. + Returns: Self, to allow method chaining. 
""" if self.model_type == "normalizer": - self.pl_trainer.fit(self.model) + self.pl_trainer.fit(self.model, ckpt_path=ckpt_path) else: train_loader = self.dataset.get_train_dataloader( batch_size=self.cfg.trainer.batch_size, @@ -86,7 +90,7 @@ def fit(self) -> "Trainer": num_workers=6, # Maximum for 7.5GB/10GB GPU usage persistent_workers=True, ) - self.pl_trainer.fit(self.model, train_loader, None) + self.pl_trainer.fit(self.model, train_loader, None, ckpt_path=ckpt_path) return self def get_data_generator(self) -> DataGenerator: @@ -204,15 +208,20 @@ def _instantiate_trainer(self) -> pl.Trainer: # Add context_module_type from context config from cents.utils.utils import get_context_config context_cfg = get_context_config() - context_module_type = context_cfg.static_context.type - if context_module_type: - filename_parts.append(f"ctx{context_module_type}") + static_context_module_type = context_cfg.static_context.type + if static_context_module_type: + filename_parts.append(f"ctx{static_context_module_type}") + + dynamic_context_module_type = context_cfg.dynamic_context.type + if dynamic_context_module_type: + filename_parts.append(f"dyn{dynamic_context_module_type}") # Add stats_head_type from context config stats_head_type = context_cfg.normalizer.stats_head_type if stats_head_type: filename_parts.append(f"stats{stats_head_type}") + callbacks.append( ModelCheckpoint( dirpath=self.cfg.run_dir, diff --git a/scripts/train.py b/scripts/train.py index 9e733b4..9a0e151 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -72,7 +72,7 @@ def main(args) -> None: overrides=trainer_overrides, ) - trainer.fit() + trainer.fit(ckpt_path=args.resume_from_checkpoint) if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -97,6 +97,8 @@ def main(args) -> None: help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn')") parser.add_argument("--force-retrain-normalizer", type=bool, default=False, help="Force retraining of 
normalizer even if cached version exists") + parser.add_argument("--resume-from-checkpoint", type=str, default=None, + help="Path to checkpoint file (.ckpt) to resume training from") args = parser.parse_args() main(args) From bdb78edec21c1dba4a9aa991c6b8f91226b3f1bd Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 21 Jan 2026 15:05:59 -0500 Subject: [PATCH 21/50] Added simple generate script --- cents/data_generator.py | 4 +- cents/datasets/timeseries_dataset.py | 31 ++++- cents/datasets/utils.py | 22 +-- scripts/eval_pretrained.py | 10 +- scripts/generate.py | 195 +++++++++++++++++++++++++++ 5 files changed, 240 insertions(+), 22 deletions(-) create mode 100644 scripts/generate.py diff --git a/cents/data_generator.py b/cents/data_generator.py index 31242a8..58836c2 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -57,7 +57,7 @@ def __init__( self.device = get_device(device) self.cfg = cfg self._ctx_buff: Dict[str, torch.Tensor] = {} - + self.dataset = dataset if model is not None: self.model = model.to(self.device).eval() self.normalizer = normalizer @@ -243,7 +243,7 @@ def load_from_checkpoint( self.normalizer = Normalizer( dataset_cfg=self.cfg.dataset, normalizer_training_cfg=get_normalizer_training_config(), - dataset=None, + dataset=self.dataset, ) state = torch.load(normalizer_ckpt, map_location=device) sd = state.get("state_dict", state) diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index f335c52..8e870df 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -83,10 +83,9 @@ def __init__( cfg = apply_overrides(cfg, dyn) cfg.time_series_columns = self.time_series_column_names self.numeric_context_bins = cfg.numeric_context_bins - # context_vars = self._get_context_var_dict(data) - # cfg.context_vars = context_vars self.cfg = cfg + self.context_var_dict = self.cfg.context_vars self.numeric_context_bins = self.cfg.numeric_context_bins if not 
hasattr(self, "threshold"): self.threshold = (-self.cfg.threshold, self.cfg.threshold) @@ -389,8 +388,13 @@ def _normalize_continuous_vars(self): std_val = 1.0 print(f"[Dataset] Warning: {var_name} has zero std, using std=1.0") - # Store stats for reference - self.continuous_var_stats[var_name] = {'mean': mean_val, 'std': std_val} + # Store stats for reference and for sampling in original range + self.continuous_var_stats[var_name] = { + 'mean': mean_val, + 'std': std_val, + 'min': float(values.min()), + 'max': float(values.max()), + } # Normalize the values in-place: (x - mean) / std self.data[var_name] = (values - mean_val) / std_val @@ -447,8 +451,21 @@ def sample_random_context_vars(self) -> Dict[str, torch.Tensor]: dict: Random context index tensors. """ ctx = {} - for var, n in self._get_context_var_dict(self.data).items(): - ctx[var] = torch.randint(0, n, (), dtype=torch.long) + + for var, info in self.context_var_dict.items(): + if info[0] == "categorical": + ctx[var] = torch.randint(0, info[1], (), dtype=torch.long) + elif info[0] == "continuous": + # Sample from original [min, max] then z-score to match training data + stats = getattr(self, "continuous_var_stats", {}).get(var) + if stats is not None and "min" in stats and "max" in stats: + x_raw = np.random.uniform(stats["min"], stats["max"]) + x_norm = (x_raw - stats["mean"]) / stats["std"] + ctx[var] = torch.tensor(x_norm, dtype=torch.float32) + else: + raise ValueError(f"Continuous variable {var} has no stats") + else: + raise ValueError(f"Invalid context variable type: {info[0]}") return ctx def get_context_var_combination_rarities( @@ -653,6 +670,8 @@ def _init_normalizer(self) -> None: dynamic_module_type=self.dynamic_module_type, ) + print(f"[Cents] cache_path: {cache_path}") + ncfg = get_normalizer_training_config() self._normalizer = Normalizer( dataset_cfg=self.cfg, diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 3b12810..0aa2cfc 100644 --- a/cents/datasets/utils.py +++ 
b/cents/datasets/utils.py @@ -239,18 +239,22 @@ def convert_generated_data_to_df( n_samples = data_np.shape[0] - if decode: - if mapping is None: - raise ValueError("Mapping must be provided when decode=True.") - cond_vars = { - var: mapping[var][code.item()] for var, code in context_vars.items() - } - else: - cond_vars = {var: int(code.item()) for var, code in context_vars.items()} + def _get_code_at(code: Any, i: int) -> Any: + if isinstance(code, torch.Tensor) and code.dim() == 1 and code.shape[0] == n_samples: + return code[i].item() + return code.item() if isinstance(code, torch.Tensor) else code records = [] for i in range(n_samples): - record = cond_vars.copy() + record = {} + for var, code in context_vars.items(): + v = _get_code_at(code, i) + if decode: + if mapping is None: + raise ValueError("Mapping must be provided when decode=True.") + record[var] = mapping[var][v] + else: + record[var] = v if isinstance(v, float) else int(v) record["timeseries"] = data_np[i] records.append(record) diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index bcfba38..b75a459 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -34,15 +34,15 @@ def main(args) -> None: if args.context_config_path: set_context_config_path(args.context_config_path) - model_ckpt = "cents/outputs/diffusion_ts_commercial_all/2025-11-13_19-50-40/commercial_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" + model_ckpt = "cents/outputs/diffusion_ts_pecanstreet_all/2026-01-20_13-25-56/pecanstreet_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) print("Loading dataset...") - # dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES) - dataset = CommercialDataset(overrides = DATASET_OVERRIDES) + dataset = PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES) + # dataset = CommercialDataset(overrides = DATASET_OVERRIDES) - normalizer_ckpt = 
HOME / ".cache/cents/checkpoints/commercial/normalizer/commercial_normalizer_dim1_ctxsep_mlp_statsmlp.ckpt" + normalizer_ckpt = HOME / ".cache/cents/checkpoints/pecanstreet/normalizer/pecanstreet_normalizer_dim1_ctxsep_mlp_statsmlp.ckpt" # Build a minimal cfg for evaluator and generator eval_cfg = load_yaml("cents/config/evaluator/default.yaml") top_cfg = load_yaml("cents/config/config.yaml") @@ -63,7 +63,7 @@ def main(args) -> None: cfg.eval_disentanglement = eval_cfg.get("eval_disentanglement", True) cfg.job_name = eval_cfg.get("job_name", "default_job") cfg.save_results = True - cfg.save_dir = HOME / f"cents/outputs/diffusion_ts_commercial_all/2025-11-13_19-50-40/eval" + cfg.save_dir = HOME / f"cents/outputs/diffusion_ts_pecanstreet_all/2026-01-20_13-25-56/eval" if not os.path.exists(cfg.save_dir): diff --git a/scripts/generate.py b/scripts/generate.py new file mode 100644 index 0000000..f40fa8d --- /dev/null +++ b/scripts/generate.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +""" +Generate synthetic time series samples from a trained model. + +Supports: + - Random context: sample context from the dataset's support (including continuous). + - Explicit context: provide context as JSON (categorical: int codes; continuous: z-scored floats). + - Output to Parquet (default) or CSV. 
+""" + +import argparse +import json +import logging +import os +from pathlib import Path + +import torch +from omegaconf import OmegaConf + +from cents.data_generator import DataGenerator +from cents.datasets.pecanstreet import PecanStreetDataset +from cents.datasets.commercial import CommercialDataset +from cents.datasets.airquality import AirQualityDataset +from cents.datasets.utils import convert_generated_data_to_df +from cents.utils.config_loader import load_yaml +from cents.utils.utils import set_context_config_path + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +DATASET_OVERRIDES = ["max_samples=10000", "skip_heavy_processing=True"] +PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] + + +def _load_dataset(name: str, overrides: list): + if name == "pecanstreet": + return PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES + (overrides or [])) + if name == "commercial": + return CommercialDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + if name == "airquality": + return AirQualityDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + raise ValueError(f"Dataset {name} not supported. Use: pecanstreet, commercial, airquality.") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate synthetic time series from a trained model.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-ckpt", + type=str, + required=True, + help="Path to model checkpoint (.ckpt or .pt).", + ) + parser.add_argument( + "--normalizer-ckpt", + type=str, + default=None, + help="Path to normalizer checkpoint. If omitted, output stays in normalized space.", + ) + parser.add_argument( + "--model-type", + type=str, + default="diffusion_ts", + help="Model type (e.g. 
diffusion_ts) used to load the checkpoint.", + ) + parser.add_argument( + "--dataset", + type=str, + default="pecanstreet", + choices=("pecanstreet", "commercial", "airquality"), + help="Dataset name (must match the one used to train the model).", + ) + parser.add_argument( + "-n", + "--num-samples", + type=int, + default=100, + help="Number of samples to generate.", + ) + parser.add_argument( + "-o", + "--out", + type=str, + default="samples.parquet", + help="Output path for generated samples.", + ) + parser.add_argument( + "--format", + type=str, + choices=("parquet", "csv"), + default="parquet", + help="Output format. Parquet preserves array columns better.", + ) + parser.add_argument( + "--random-context", + action="store_true", + help="Sample context randomly from the dataset support (categorical and continuous).", + ) + parser.add_argument( + "--context", + type=str, + default=None, + help='Explicit context as JSON, e.g. \'{"weekday":0,"month":3}\'. ' + "Categorical: int codes. Continuous: z-scored (normalized) floats.", + ) + parser.add_argument( + "--context-config-path", + type=str, + default=None, + help="Path to custom context config YAML (optional).", + ) + parser.add_argument( + "--dataset-overrides", + type=str, + nargs="*", + default=[], + help="Extra dataset overrides, e.g. 
time_series_dims=1.", + ) + parser.add_argument( + "--no-ema", + action="store_true", + help="Disable EMA sampling (EMA is used by default when present in the checkpoint).", + ) + args = parser.parse_args() + + use_random = args.random_context + use_explicit = args.context is not None and args.context.strip() != "" + if not use_random and not use_explicit: + parser.error("Provide either --random-context or --context (JSON).") + if use_random and use_explicit: + parser.error("Provide only one of --random-context or --context.") + + if args.context_config_path: + set_context_config_path(args.context_config_path) + + overrides = list(args.dataset_overrides) if args.dataset_overrides else [] + + logging.info("Loading dataset %s...", args.dataset) + dataset = _load_dataset(args.dataset, overrides) + cfg = OmegaConf.create({}) + cfg.model = load_yaml(Path("cents/config/model") / f"{args.model_type}.yaml") + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + cfg.model.use_ema_sampling = not args.no_ema + + logging.info("Setting up DataGenerator (model_type=%s)...", args.model_type) + gen = DataGenerator(model_type=args.model_type, dataset=dataset, cfg=cfg) + gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) + # Ensure EMA setting is applied to the config used by the model at generate time + target = getattr(gen.model, "cfg", None) or gen.cfg + if target is not None and hasattr(target, "model"): + target.model.use_ema_sampling = cfg.model.use_ema_sampling + gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) + + if use_random: + # Sample a new random context for each of the n samples + contexts = [dataset.sample_random_context_vars() for _ in range(args.num_samples)] + ctx_batch = { + k: torch.stack([c[k] for c in contexts]).to(gen.device) + for k in contexts[0].keys() + } + logging.info("Generating %d samples with %d random contexts...", args.num_samples, args.num_samples) + with torch.no_grad(): + ts 
= gen.model.generate(ctx_batch) + df = convert_generated_data_to_df(ts, ctx_batch, decode=False) + else: + context_dict = json.loads(args.context) + for k, v in context_dict.items(): + if isinstance(v, float): + context_dict[k] = v + else: + context_dict[k] = int(v) + gen.set_context(**context_dict) + logging.info("Generating %d samples with context %s...", args.num_samples, context_dict) + df = gen.generate(args.num_samples) + + if gen.normalizer is not None: + df = gen.normalizer.inverse_transform(df) + logging.info("Inverse-normalized outputs to original scale.") + else: + logging.warning("No normalizer loaded; outputs are in normalized space.") + + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + if args.format == "parquet": + df.to_parquet(out, index=False) + else: + df.to_csv(out, index=False) + logging.info("Wrote %d samples to %s", len(df), out.resolve()) + + +if __name__ == "__main__": + main() From fef7f882c82f0acd99149a77d83940683b5c83f5 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 21 Jan 2026 18:19:17 -0500 Subject: [PATCH 22/50] changed eval script to use commandline args --- cents/config/context/default.yaml | 2 +- scripts/eval_pretrained.py | 233 ++++++++++++++++++++++-------- 2 files changed, 172 insertions(+), 63 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index 7835592..d2585ac 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -17,7 +17,7 @@ normalizer: # Dynamic context: context module used by the normalizer for time series context variables dynamic_context: - type: cnn # Context module type for dynamic context (e.g., "cnn") + type: null # Context module type for dynamic context (e.g., "cnn") # Future parameters can be added here: # n_layers: 2 # hidden_dim: 256 diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index b75a459..0b30637 100644 --- a/scripts/eval_pretrained.py +++ 
b/scripts/eval_pretrained.py @@ -1,95 +1,204 @@ import logging -from datetime import datetime -from typing import override +import os +from pathlib import Path -# import wandb from omegaconf import OmegaConf +import argparse from cents.data_generator import DataGenerator from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset +from cents.datasets.airquality import AirQualityDataset from cents.eval.eval import Evaluator from cents.utils.config_loader import load_yaml from cents.utils.utils import set_context_config_path -from pathlib import Path -import torch -import os -import argparse -MODEL_KEY = "diffusion_ts" -DATASET_OVERRIDES = [ - "max_samples=10000", - "skip_heavy_processing=True" -] +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +DATASET_OVERRIDES = ["max_samples=10000", "skip_heavy_processing=True"] +PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] -PECAN_OVERRIDES = [ - "time_series_dims=1", - "user_group=all" -] -HOME = Path.home() +def _load_dataset(name: str, overrides: list): + """Load a dataset by name with optional overrides.""" + if name == "pecanstreet": + return PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES + (overrides or [])) + if name == "commercial": + return CommercialDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + if name == "airquality": + return AirQualityDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + raise ValueError(f"Dataset {name} not supported. 
Use: pecanstreet, commercial, airquality.") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Evaluate a trained model using comprehensive metrics.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-ckpt", + type=str, + required=True, + help="Path to model checkpoint (.ckpt or .pt).", + ) + parser.add_argument( + "--normalizer-ckpt", + type=str, + default=None, + help="Path to normalizer checkpoint. If omitted, evaluation uses normalized space.", + ) + parser.add_argument( + "--model-type", + type=str, + default="diffusion_ts", + help="Model type (e.g. diffusion_ts) used to load the checkpoint.", + ) + parser.add_argument( + "--dataset", + type=str, + default="pecanstreet", + choices=("pecanstreet", "commercial", "airquality"), + help="Dataset name (must match the one used to train the model).", + ) + parser.add_argument( + "--dataset-overrides", + type=str, + nargs="*", + default=[], + help="Extra dataset overrides, e.g. time_series_dims=1.", + ) + parser.add_argument( + "--save-dir", + type=str, + default=None, + help="Directory to save evaluation results. If None, uses default location based on model checkpoint path.", + ) + parser.add_argument( + "--job-name", + type=str, + default=None, + help="Job name for evaluation run. 
If None, uses default from evaluator config.", + ) + parser.add_argument( + "--evaluator-config", + type=str, + default="cents/config/evaluator/default.yaml", + help="Path to evaluator config YAML file.", + ) + parser.add_argument( + "--config", + type=str, + default="cents/config/config.yaml", + help="Path to main config YAML file.", + ) + parser.add_argument( + "--no-ema", + action="store_true", + help="Disable EMA sampling (EMA is used by default when present in checkpoint).", + ) + parser.add_argument( + "--eval-pv-shift", + action="store_true", + help="Enable PV shift evaluation.", + ) + parser.add_argument( + "--no-eval-metrics", + action="store_true", + help="Disable evaluation metrics computation.", + ) + parser.add_argument( + "--no-eval-context-sparse", + action="store_true", + help="Disable sparse context evaluation.", + ) + parser.add_argument( + "--no-eval-disentanglement", + action="store_true", + help="Disable disentanglement evaluation.", + ) + parser.add_argument( + "--no-save-results", + action="store_true", + help="Disable saving evaluation results.", + ) + parser.add_argument( + "--context-config-path", + type=str, + default=None, + help="Path to custom context config YAML file (optional).", + ) + args = parser.parse_args() -def main(args) -> None: # Set custom context config path if provided if args.context_config_path: set_context_config_path(args.context_config_path) + + logging.info("Loading dataset %s...", args.dataset) + overrides = list(args.dataset_overrides) if args.dataset_overrides else [] + if args.dataset == "pecanstreet" and "time_series_dims" not in str(overrides): + overrides = overrides + ["time_series_dims=1", "user_group=all"] - model_ckpt = "cents/outputs/diffusion_ts_pecanstreet_all/2026-01-20_13-25-56/pecanstreet_diffusion_ts_dim1_ctxsep_mlp_statsmlp.ckpt" - logging.basicConfig( - level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" - ) - print("Loading dataset...") - dataset = 
PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES) - # dataset = CommercialDataset(overrides = DATASET_OVERRIDES) + dataset = _load_dataset(args.dataset, overrides) - normalizer_ckpt = HOME / ".cache/cents/checkpoints/pecanstreet/normalizer/pecanstreet_normalizer_dim1_ctxsep_mlp_statsmlp.ckpt" - # Build a minimal cfg for evaluator and generator - eval_cfg = load_yaml("cents/config/evaluator/default.yaml") - top_cfg = load_yaml("cents/config/config.yaml") + # Load configs + eval_cfg = load_yaml(args.evaluator_config) + top_cfg = load_yaml(args.config) + cfg = OmegaConf.create({}) cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) cfg.device = top_cfg.get("device", "auto") - cfg.model = OmegaConf.create(OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{MODEL_KEY}.yaml"), resolve=True)) - cfg.dataset = OmegaConf.create( - OmegaConf.to_container(dataset.cfg, resolve=True) - ) - # Enable EMA sampling to use the EMA weights from checkpoint - cfg.model.use_ema_sampling = True - cfg.eval_pv_shift = eval_cfg.get("eval_pv_shift", False) - cfg.eval_metrics = eval_cfg.get("eval_metrics", True) - cfg.eval_context_sparse = eval_cfg.get("eval_context_sparse", True) - cfg.save_results = eval_cfg.get("save_results", False) - cfg.eval_disentanglement = eval_cfg.get("eval_disentanglement", True) - cfg.job_name = eval_cfg.get("job_name", "default_job") - cfg.save_results = True - cfg.save_dir = HOME / f"cents/outputs/diffusion_ts_pecanstreet_all/2026-01-20_13-25-56/eval" - - + cfg.model = OmegaConf.create( + OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{args.model_type}.yaml"), resolve=True) + ) + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + + # Set EMA sampling + cfg.model.use_ema_sampling = not args.no_ema + + # Set evaluation flags (use config defaults if not overridden) + cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) + cfg.eval_metrics = 
False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) + cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) + cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) + cfg.save_results = False if args.no_save_results else True + + # Set job name + cfg.job_name = args.job_name if args.job_name else eval_cfg.get("job_name", "default_job") + + # Set save directory + if args.save_dir: + cfg.save_dir = Path(args.save_dir) + else: + # Default: use model checkpoint directory + /eval + model_ckpt_path = Path(args.model_ckpt) + cfg.save_dir = model_ckpt_path.parent / "eval" + if not os.path.exists(cfg.save_dir): - os.makedirs(cfg.save_dir) - print("Creating Evaluation Directory") - - print("Dataset spec set. Setting up DataGenerator...") - - # Use the fixed checkpoint with DataGenerator - gen = DataGenerator(model_type = MODEL_KEY, dataset=dataset) - print("Loading checkpoint... EMA sampling enabled - will use EMA weights for generation") - gen.load_from_checkpoint(model_ckpt, normalizer_ckpt) + os.makedirs(cfg.save_dir, exist_ok=True) + logging.info("Created evaluation directory: %s", cfg.save_dir) + logging.info("Setting up DataGenerator (model_type=%s)...", args.model_type) + gen = DataGenerator(model_type=args.model_type, dataset=dataset) + + logging.info("Loading checkpoint... 
EMA sampling %s", "enabled" if cfg.model.use_ema_sampling else "disabled") + gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) + + # Ensure EMA setting is applied to the config used by the model at generate time + target = getattr(gen.model, "cfg", None) or gen.cfg + if target is not None and hasattr(target, "model"): + target.model.use_ema_sampling = cfg.model.use_ema_sampling + gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset - print("Checkpoint loaded") - - print("Evaluating model...") + logging.info("Checkpoint loaded. Starting evaluation...") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) + logging.info("Evaluation complete!") print(results) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--context-config-path", type=str, default=None, - help="Path to custom context config YAML file (optional)") - args = parser.parse_args() - main(args) + main() From 74034a44d6c8e53053e59bb61bfc82b2f17a577c Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Fri, 23 Jan 2026 10:28:57 -0500 Subject: [PATCH 23/50] UNSTABLE - fixes to dyn context in base model --- cents/models/base.py | 77 ++++++++++++++++++++--- cents/models/context.py | 18 +++++- cents/models/diffusion_ts.py | 116 +++++++++++++++++++++++++++++++---- 3 files changed, 187 insertions(+), 24 deletions(-) diff --git a/cents/models/base.py b/cents/models/base.py index 73b1f83..9f170c9 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -3,6 +3,7 @@ import pandas as pd import pytorch_lightning as pl import torch +import torch.nn as nn from omegaconf import DictConfig from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration @@ -42,17 +43,75 @@ def __init__(self, cfg: DictConfig = None): emb_dim = getattr(cfg.model, "cond_emb_dim", 256) # Get context module type from context config context_cfg = get_context_config() - 
context_module_type = context_cfg.static_context.type - # Get continuous variables from config if specified - # continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) - # Use registry to get the context module class - ContextModuleCls = get_context_module_cls(context_module_type) - self.context_module = ContextModuleCls( - cfg.dataset.context_vars, - emb_dim, - ) + static_module_type = context_cfg.static_context.type + dynamic_module_type = getattr(context_cfg.dynamic_context, "type", None) + + # Separate static and dynamic context variables + continuous_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] + categorical_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] + dynamic_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "time_series"] + + static_context_vars = categorical_vars + continuous_vars + self.dynamic_context_vars = dynamic_vars + + # Create static context module (for categorical + continuous) + self.static_context_module = None + if static_context_vars: + StaticContextModuleCls = get_context_module_cls(static_module_type) + static_context_vars_dict = { + k: v for k, v in cfg.dataset.context_vars.items() + if k in static_context_vars + } + self.static_context_module = StaticContextModuleCls( + static_context_vars_dict, + emb_dim, + ) + + # Create dynamic context module (for time_series) + self.dynamic_context_module = None + if self.dynamic_context_vars and dynamic_module_type is not None: + DynamicContextModuleCls = get_context_module_cls("dynamic", dynamic_module_type) + dynamic_context_vars_dict = { + k: v for k, v in cfg.dataset.context_vars.items() + if k in self.dynamic_context_vars + } + seq_len = getattr(cfg.dataset, "seq_len", None) + if seq_len is None: + raise ValueError("seq_len must be specified in cfg.dataset for dynamic context modules") + self.dynamic_context_module = DynamicContextModuleCls( + dynamic_context_vars_dict, + emb_dim, + 
seq_len=seq_len, + ) + + # Determine embedding dimension and create combine MLP if both exist + if self.static_context_module is not None: + self.embedding_dim = self.static_context_module.embedding_dim + elif self.dynamic_context_module is not None: + self.embedding_dim = self.dynamic_context_module.embedding_dim + else: + raise ValueError("At least one of static_context_module or dynamic_context_module must be provided") + + # If both modules exist, create combine MLP + if self.static_context_module is not None and self.dynamic_context_module is not None: + combined_dim = self.static_context_module.embedding_dim + self.dynamic_context_module.embedding_dim + self.combine_mlp = nn.Sequential( + nn.Linear(combined_dim, self.embedding_dim), + nn.ReLU(), + ) + else: + self.combine_mlp = None + + # For backward compatibility, expose static_context_module as context_module + # (but subclasses should use static_context_module and dynamic_context_module directly) + self.context_module = self.static_context_module else: + self.static_context_module = None + self.dynamic_context_module = None self.context_module = None + self.combine_mlp = None + self.dynamic_context_vars = [] + self.embedding_dim = getattr(cfg.model, "cond_emb_dim", 256) if cfg is not None else 256 @abstractmethod def forward(self, *args, **kwargs): diff --git a/cents/models/context.py b/cents/models/context.py index cdd7d14..f9a341d 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -56,7 +56,11 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int): self.classification_heads = nn.ModuleDict( { - var_name: nn.Linear(embedding_dim, num_categories) + var_name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.ReLU(), + nn.Linear(embedding_dim, num_categories) + ) for var_name, num_categories in context_vars.items() } ) @@ -155,7 +159,11 @@ def __init__( self.classification_heads = nn.ModuleDict( { - var_name: nn.Linear(embedding_dim, num_categories) + 
var_name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.ReLU(), + nn.Linear(embedding_dim, num_categories) + ) for var_name, num_categories in self.categorical_vars.items() } ) @@ -163,7 +171,11 @@ def __init__( # Regression heads for continuous variables (output single value for MSE loss) self.regression_heads = nn.ModuleDict( { - var_name: nn.Linear(embedding_dim, 1) + var_name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.ReLU(), + nn.Linear(embedding_dim, 1) + ) for var_name in self.continuous_vars } ) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index bbf3134..e0a72b9 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -53,7 +53,9 @@ def __init__(self, cfg: DictConfig): self.context_reconstruction_loss_weight = ( cfg.model.context_reconstruction_loss_weight ) - _ = self.context_module + # Verify context modules are initialized (static, dynamic, or both) + if not hasattr(self, 'static_context_module') and not hasattr(self, 'dynamic_context_module'): + raise ValueError("At least one context module (static or dynamic) must be initialized") # linear layer for denoised output self.fc = nn.Linear( @@ -142,6 +144,72 @@ def __init__(self, cfg: DictConfig): self.continuous_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] self.categorical_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] + def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict]: + """ + Get combined context embedding from static and/or dynamic context modules. 
+ + Args: + context_vars: Dict of context tensors (static: single values, dynamic: time series) + + Returns: + embedding: Combined embedding tensor of shape (batch_size, embedding_dim) + all_logits: Dict of classification/regression logits from both modules + """ + embeddings = [] + all_logits = {} + + # Process static context variables + if self.static_context_module is not None: + # Filter static context variables + static_vars = { + k: v for k, v in context_vars.items() + if k not in getattr(self, 'dynamic_context_vars', []) + } + if static_vars: + device = next(self.static_context_module.parameters()).device + static_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in static_vars.items() + } + static_embedding, static_logits = self.static_context_module(static_vars) + embeddings.append(static_embedding) + all_logits.update(static_logits) + + # Process dynamic context variables + if self.dynamic_context_module is not None: + # Filter dynamic context variables + dynamic_var_names = getattr(self, 'dynamic_context_vars', []) + dynamic_vars = { + k: v for k, v in context_vars.items() + if k in dynamic_var_names + } + if dynamic_vars: + device = next(self.dynamic_context_module.parameters()).device + dynamic_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in dynamic_vars.items() + } + dynamic_embedding, dynamic_logits = self.dynamic_context_module(dynamic_vars) + # Check for NaN in dynamic embedding + if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): + raise ValueError( + f"NaN/Inf detected in dynamic embedding. 
" + f"Dynamic vars: {list(dynamic_vars.keys())}" + ) + embeddings.append(dynamic_embedding) + all_logits.update(dynamic_logits) + + # Combine embeddings if both exist + if len(embeddings) == 2: + combined = torch.cat(embeddings, dim=1) + embedding = self.combine_mlp(combined) + elif len(embeddings) == 1: + embedding = embeddings[0] + else: + raise ValueError("No context variables provided") + + return embedding, all_logits + def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor ) -> torch.Tensor: @@ -218,7 +286,7 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di """ b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - embedding, cond_classification_logits = self.context_module(context_vars) + embedding, cond_classification_logits = self._get_context_embedding(context_vars) # Check embedding for NaN/Inf and extreme values # if embedding.isnan().any() or embedding.isinf().any(): @@ -301,7 +369,6 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: for var_name, outputs in cond_class_logits.items(): labels = cond_batch[var_name] - if var_name in self.continuous_context_vars: loss = F.mse_loss(outputs, labels.float()) elif var_name in self.categorical_context_vars: @@ -314,7 +381,9 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # print(loss) # print(outputs.mean(), labels.mean()) - h, _ = self.context_module(cond_batch) + cond_loss /= len(cond_class_logits) + + h, _ = self._get_context_embedding(cond_batch) tc_term = ( self.cfg.model.tc_loss_weight * total_correlation(h) if self.cfg.model.tc_loss_weight > 0.0 @@ -325,12 +394,35 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term ) - # # Check for NaN in total loss - # if torch.isnan(total_loss) or torch.isinf(total_loss): - # raise ValueError( - # f"NaN/Inf detected in total_loss at 
batch {batch_idx}. " - # f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" - # ) + # Check for NaN in total loss + if torch.isnan(total_loss) or torch.isinf(total_loss): + raise ValueError( + f"NaN/Inf detected in total_loss at batch {batch_idx}. " + f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss.item():.6f}, tc_term: {tc_term.item():.6f}" + ) + + # Debug: Check if context module parameters are getting gradients + # (only log occasionally to avoid spam) + if batch_idx % 50 == 0: + context_params_with_grad = [] + context_params_no_grad = [] + if self.static_context_module is not None: + for name, param in self.static_context_module.named_parameters(): + if param.requires_grad: + if param.grad is not None: + grad_norm = param.grad.norm().item() + context_params_with_grad.append((name, grad_norm)) + else: + context_params_no_grad.append(name) + + if context_params_no_grad: + print(f"[Warning] {len(context_params_no_grad)} context module parameters have no gradients!") + print(f" No grad params: {context_params_no_grad[:5]}...") + if context_params_with_grad: + avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) + max_grad_norm = max(g[1] for g in context_params_with_grad) + print(f"[Debug] Context module gradients: avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}") + self.log_dict( { "train_loss": total_loss.item(), @@ -466,7 +558,7 @@ def sample(self, shape: Tuple[int, int, int], context_vars: dict) -> torch.Tenso Generated samples tensor. """ x = torch.randn(shape, device=self.device) - embedding, _ = self.context_module(context_vars) + embedding, _ = self._get_context_embedding(context_vars) for t in reversed(range(self.num_timesteps)): x = self.p_sample(x, t, embedding) return x @@ -479,7 +571,7 @@ def fast_sample( Faster sampling using a reduced number of timesteps. 
""" x = torch.randn(shape, device=self.device) - embedding, _ = self.context_module(context_vars) + embedding, _ = self._get_context_embedding(context_vars) times = torch.linspace( -1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1 ) From 288ab180550f7b11ba2ecb8789efac4b6d6c490b Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 26 Jan 2026 12:23:56 -0500 Subject: [PATCH 24/50] Tracking tools for gradient flow to context --- cents/models/context.py | 5 ++- cents/models/diffusion_ts.py | 68 ++++++++++++++++++++++++------------ scripts/train.py | 17 +++++---- 3 files changed, 61 insertions(+), 29 deletions(-) diff --git a/cents/models/context.py b/cents/models/context.py index f9a341d..e0e5471 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -213,7 +213,8 @@ def forward(self, context_vars): embeddings = [] # Apply init MLPs to categorical variables for name, layer in self.init_mlps.items(): - embeddings.append(layer(encodings[name])) + if name in encodings: + embeddings.append(layer(encodings[name])) # Apply init MLPs to continuous variables for name, layer in self.continuous_init_mlps.items(): @@ -254,12 +255,14 @@ def forward(self, context_vars): classification_logits = { var_name: head(embedding) for var_name, head in self.classification_heads.items() + if var_name in context_vars } # Regression outputs for continuous variables regression_outputs = { var_name: head(embedding).squeeze(-1) # Remove last dim to get (batch_size,) for var_name, head in self.regression_heads.items() + if var_name in context_vars } # Combine both into a single dict for backward compatibility diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index e0a72b9..a1b42a7 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -401,28 +401,6 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss.item():.6f}, tc_term: {tc_term.item():.6f}" 
) - # Debug: Check if context module parameters are getting gradients - # (only log occasionally to avoid spam) - if batch_idx % 50 == 0: - context_params_with_grad = [] - context_params_no_grad = [] - if self.static_context_module is not None: - for name, param in self.static_context_module.named_parameters(): - if param.requires_grad: - if param.grad is not None: - grad_norm = param.grad.norm().item() - context_params_with_grad.append((name, grad_norm)) - else: - context_params_no_grad.append(name) - - if context_params_no_grad: - print(f"[Warning] {len(context_params_no_grad)} context module parameters have no gradients!") - print(f" No grad params: {context_params_no_grad[:5]}...") - if context_params_with_grad: - avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) - max_grad_norm = max(g[1] for g in context_params_with_grad) - print(f"[Debug] Context module gradients: avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}") - self.log_dict( { "train_loss": total_loss.item(), @@ -461,6 +439,52 @@ def on_train_start(self) -> None: update_every=self.cfg.model.ema_update_interval, ) + def on_after_backward(self) -> None: + """ + Check gradients after backward pass but before optimizer step. + This is the right place to inspect gradients before they're zeroed. 
+ """ + # Get current batch index from trainer + if not hasattr(self.trainer, 'global_step'): + return + + batch_idx = self.trainer.global_step + + # Debug: Check if context module parameters are getting gradients + # Check AFTER backward pass but BEFORE optimizer step (only log occasionally) + if batch_idx % 50 == 0: + context_params_with_grad = [] + context_params_no_grad = [] + if self.static_context_module is not None: + for name, param in self.static_context_module.named_parameters(): + if param.requires_grad: + if param.grad is not None: + grad_norm = param.grad.norm().item() + # Check for NaN/Inf gradients + if torch.isnan(param.grad).any() or torch.isinf(param.grad).any(): + print(f"[Warning] NaN/Inf gradients detected in {name}") + else: + context_params_with_grad.append((name, grad_norm)) + else: + context_params_no_grad.append(name) + + if context_params_no_grad: + # Group by variable name to identify which context variables are missing + missing_vars = set() + for param_name in context_params_no_grad: + # Extract variable name from parameter name (e.g., "context_embeddings.year.weight" -> "year") + parts = param_name.split('.') + if len(parts) >= 2 and parts[0] in ['context_embeddings', 'init_mlps']: + missing_vars.add(parts[1]) + print(f"[Warning] {len(context_params_no_grad)} context module parameters have no gradients!") + if missing_vars: + print(f" Missing context variables: {sorted(missing_vars)}") + print(f" No grad params (sample): {context_params_no_grad[:5]}...") + if context_params_with_grad: + avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) + max_grad_norm = max(g[1] for g in context_params_with_grad) + print(f"[Debug] Context module gradients: avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}") + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: """ Apply EMA update after each batch end. 
diff --git a/scripts/train.py b/scripts/train.py index 9a0e151..f3e551d 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -84,21 +84,26 @@ def main(args) -> None: parser.add_argument("--dataset", type=str, default="pecanstreet") parser.add_argument("--epochs", type=int, default=5000) parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--wandb-enabled", type=bool, default=False) + parser.add_argument("--wandb-enabled", action="store_true", + help="Enable Weights and Biases logging") parser.add_argument("--wandb-project", type=str, default="cents") parser.add_argument("--wandb-entity", type=str, default=None) - parser.add_argument("--eval_after_training", type=bool, default=True) - parser.add_argument("--skip_heavy_processing", type=bool, default=True) + parser.add_argument("--eval_after_training", action="store_true", + help="Evaluate after training") + parser.add_argument("--skip_heavy_processing", action="store_true", + help="Skip heavy processing of dataset") parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false") - parser.add_argument("--enable_checkpointing", type=bool, default=True) + parser.add_argument("--enable_checkpointing", action="store_true", + help="Enable checkpointing") parser.add_argument("--context-config-path", type=str, default=None, help="Path to custom context config YAML file (optional)") parser.add_argument("--context-overrides", type=str, nargs="*", default=[], help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn')") - parser.add_argument("--force-retrain-normalizer", type=bool, default=False, + parser.add_argument("--force-retrain-normalizer", action="store_true", help="Force retraining of normalizer even if cached version exists") parser.add_argument("--resume-from-checkpoint", type=str, default=None, - help="Path to checkpoint file (.ckpt) to resume training from") + help="Path to checkpoint file (.ckpt) to resume training 
from", + ) args = parser.parse_args() main(args) From ddbc53c48cb3c4f77d6c47f6b9b629c535d5ebbf Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 2 Feb 2026 08:50:37 -0500 Subject: [PATCH 25/50] Vehicle Training Runs --- cents/datasets/timeseries_dataset.py | 7 +- cents/models/base.py | 3 +- cents/models/context.py | 26 ++++- cents/models/diffusion_ts.py | 97 +++++++++---------- cents/models/model_utils.py | 1 - cents/models/normalizer.py | 139 +++++++++++++-------------- scripts/train.py | 6 ++ 7 files changed, 145 insertions(+), 134 deletions(-) diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 8e870df..dd5a853 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -123,6 +123,7 @@ def __init__( if self.normalize: self._init_normalizer() cache_path = self._get_normalization_cache_path() + print("CACHE PATH", cache_path) if cache_path.exists(): print(f"[{'DDP Subprocess' if is_ddp_subprocess else 'Main Process'}] Loading pre-normalized data from cache") with open(cache_path, 'rb') as f: @@ -243,13 +244,7 @@ def get_train_dataloader( """ continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] - # for col in continuous_vars: - # if col in self.data.columns: - # print(self.data[col].mean()) self._normalize_continuous_vars() - # for col in continuous_vars: - # if col in self.data.columns: - # print(self.data[col].mean()) return DataLoader( self, batch_size=batch_size, shuffle=shuffle, num_workers=8, persistent_workers=persistent_workers ) diff --git a/cents/models/base.py b/cents/models/base.py index 9f170c9..de03b79 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -70,6 +70,7 @@ def __init__(self, cfg: DictConfig = None): # Create dynamic context module (for time_series) self.dynamic_context_module = None if self.dynamic_context_vars and dynamic_module_type is not None: + num_ts_steps = getattr(cfg.dataset, "num_ts_steps", None) 
DynamicContextModuleCls = get_context_module_cls("dynamic", dynamic_module_type) dynamic_context_vars_dict = { k: v for k, v in cfg.dataset.context_vars.items() @@ -81,7 +82,7 @@ def __init__(self, cfg: DictConfig = None): self.dynamic_context_module = DynamicContextModuleCls( dynamic_context_vars_dict, emb_dim, - seq_len=seq_len, + seq_len=seq_len if num_ts_steps is None else num_ts_steps, ) # Determine embedding dimension and create combine MLP if both exist diff --git a/cents/models/context.py b/cents/models/context.py index e0e5471..315022c 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -188,7 +188,6 @@ def forward(self, context_vars): if name in context_vars: encodings[name] = layer(context_vars[name]) - #print(encodings, "ENCODINGS") # Process continuous variables (only those present in context_vars) for name, layer in self.continuous_projections.items(): if name in context_vars: @@ -251,7 +250,6 @@ def forward(self, context_vars): f"min={context_matrix.min():.4f}, max={context_matrix.max():.4f}" ) - #print(embedding, "post mixing") classification_logits = { var_name: head(embedding) for var_name, head in self.classification_heads.items() @@ -615,6 +613,12 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, attention_weights = torch.softmax(attention_weights, dim=1) pooled = (encoded * attention_weights).sum(dim=1) # (batch, embedding_dim) + # Normalize pooled embedding to prevent accumulation of large values + # Layer normalization: normalize across embedding dimension + pooled_mean = pooled.mean(dim=1, keepdim=True) # (batch, 1) + pooled_std = pooled.std(dim=1, keepdim=True) + 1e-8 # (batch, 1) + pooled = (pooled - pooled_mean) / pooled_std + embeddings.append(pooled) # Process numeric time series @@ -633,8 +637,14 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, # Replace NaN/Inf with zeros to prevent propagation ts_data = torch.where(torch.isfinite(ts_data), ts_data, 
torch.zeros_like(ts_data)) + # Normalize input to prevent numerical overflow + # Compute per-sample statistics to normalize each time series independently + ts_mean = ts_data.mean(dim=1, keepdim=True) # (batch, 1) + ts_std = ts_data.std(dim=1, keepdim=True) + 1e-8 # (batch, 1) - add epsilon to prevent division by zero + ts_data_normalized = (ts_data - ts_mean) / ts_std + # Project to embedding_dim: (batch, seq_len) -> (batch, seq_len, embedding_dim) - ts_data_expanded = ts_data.unsqueeze(-1) # (batch, seq_len, 1) + ts_data_expanded = ts_data_normalized.unsqueeze(-1) # (batch, seq_len, 1) embedded = self.ts_projections[name](ts_data_expanded) # (batch, seq_len, embedding_dim) # Add positional encoding if available @@ -660,6 +670,12 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, attention_weights = torch.softmax(attention_weights, dim=1) pooled = (encoded * attention_weights).sum(dim=1) # (batch, embedding_dim) + # Normalize pooled embedding to prevent accumulation of large values + # Layer normalization: normalize across embedding dimension + pooled_mean = pooled.mean(dim=1, keepdim=True) # (batch, 1) + pooled_std = pooled.std(dim=1, keepdim=True) + 1e-8 # (batch, 1) + pooled = (pooled - pooled_mean) / pooled_std + embeddings.append(pooled) if not embeddings: @@ -674,9 +690,9 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, if torch.isnan(combined).any() or torch.isinf(combined).any(): raise ValueError(f"NaN/Inf detected in combined embeddings before mixing MLP") embedding = self.mixing_mlp(combined) # (batch, embedding_dim) - # Check for NaN after mixing + # Check for NaN after mixing and normalization if torch.isnan(embedding).any() or torch.isinf(embedding).any(): - raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP") + raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP and normalization") return embedding, {} diff --git a/cents/models/diffusion_ts.py 
b/cents/models/diffusion_ts.py index a1b42a7..34bab68 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -284,43 +284,54 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di rec_loss: Reconstruction loss tensor. cond_logits: Classification logits dict from context module. """ + # Check input x for extreme values + # if x.abs().max() > 100.0: + # print(f"[Warning] Input x has extreme values: min={x.min():.4f}, max={x.max():.4f}, " + # f"mean={x.mean():.4f}, std={x.std():.4f}, shape={x.shape}") + + if torch.isnan(x).any() or torch.isinf(x).any(): + raise ValueError(f"NaN/Inf detected in input x. Shape: {x.shape}, " + f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") + b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) embedding, cond_classification_logits = self._get_context_embedding(context_vars) - # Check embedding for NaN/Inf and extreme values - # if embedding.isnan().any() or embedding.isinf().any(): - # raise ValueError( - # f"NaN/Inf detected in embedding from context module. " - # f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}" - # ) - - # Clamp extreme values to prevent numerical instability in transformer - # Don't fully normalize as that would change the learned embedding scale - # Just clip extreme outliers that could cause issues in attention/Fourier operations - # embedding_clamped = torch.clamp(embedding, min=-50.0, max=50.0) - - # # Log if clamping occurred (for debugging) - # if (embedding != embedding_clamped).any(): - # n_clamped = (embedding != embedding_clamped).sum().item() - # print(f"[Warning] Clamped {n_clamped} embedding values. 
" - # f"Original range: [{embedding.min():.4f}, {embedding.max():.4f}], " - # f"Clamped range: [{embedding_clamped.min():.4f}, {embedding_clamped.max():.4f}]") + # Check embedding for NaN/Inf + if embedding.isnan().any() or embedding.isinf().any(): + raise ValueError( + f"NaN/Inf detected in embedding from context module. " + f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}, " + f"shape: {embedding.shape}, min: {embedding.min()}, max: {embedding.max()}" + ) - # embedding_normalized = embedding_clamped + # Embedding should now be normalized by the context module (mean=0, std=1 per sample) + # Check that values are in reasonable range + if embedding.abs().max() > 100.0: + print(f"[Warning] Embedding has large values despite normalization: " + f"min={embedding.min():.4f}, max={embedding.max():.4f}, " + f"mean={embedding.mean():.4f}, std={embedding.std():.4f}") + # Check diffusion schedule parameters noise = torch.randn_like(x) + x_noisy = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - - # if x_noisy.isnan().any(): - # raise ValueError("NaN detected in x_noisy") + + if x_noisy.isnan().any() or x_noisy.isinf().any(): + raise ValueError(f"NaN/Inf detected in x_noisy. Shape: {x_noisy.shape}, " + f"NaN count: {torch.isnan(x_noisy).sum()}, Inf count: {torch.isinf(x_noisy).sum()}") # Use normalized embedding for concatenation embedding_expanded = embedding.unsqueeze(1).repeat(1, self.seq_len, 1) c = torch.cat([x_noisy, embedding_expanded], dim=-1) + + if c.isnan().any() or c.isinf().any(): + raise ValueError(f"NaN/Inf detected in concatenated input 'c'. 
" + f"Shape: {c.shape}, x_noisy stats: min={x_noisy.min():.4f}, max={x_noisy.max():.4f}, " + f"embedding stats: min={embedding.min():.4f}, max={embedding.max():.4f}") # if c.isnan().any() or c.isinf().any(): # raise ValueError( @@ -332,21 +343,8 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # f"max={embedding.max():.4f}" # ) - # if t.isnan().any(): - # raise ValueError("NaN detected in timestep 't'") - trend, season = self.model(c, t, padding_masks=None) - # if trend.isnan().any(): - # print("trend") - - # if season.isnan().any(): - # print("season") x_recon = self.fc(trend + season) - # if x_recon.isnan().any(): - # print("X RECON") - # if x.isnan().any(): - # print("x") - # print("REC LOSS", x_recon, x) rec_loss = self.recon_loss_fn(x_recon, x) return rec_loss, cond_classification_logits @@ -367,21 +365,22 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: cond_loss = 0.0 - for var_name, outputs in cond_class_logits.items(): - labels = cond_batch[var_name] - if var_name in self.continuous_context_vars: - loss = F.mse_loss(outputs, labels.float()) - elif var_name in self.categorical_context_vars: - loss = self.auxiliary_loss(outputs, labels) + # for var_name, outputs in cond_class_logits.items(): + # labels = cond_batch[var_name] + # if var_name in self.continuous_context_vars: + # loss = F.mse_loss(outputs, labels.float()) + # elif var_name in self.categorical_context_vars: + # loss = self.auxiliary_loss(outputs, labels) - cond_loss += loss.mean() + # cond_loss += loss.mean() - # if var_name in self.continuous_context_vars: - # print(var_name) - # print(loss) - # print(outputs.mean(), labels.mean()) + # # if var_name in self.continuous_context_vars: + # # print(var_name) + # # print(loss) + # # print(outputs.mean(), labels.mean()) - cond_loss /= len(cond_class_logits) + + # cond_loss /= len(cond_class_logits) h, _ = self._get_context_embedding(cond_batch) tc_term = ( @@ -398,14 +397,14 @@ def 
training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: if torch.isnan(total_loss) or torch.isinf(total_loss): raise ValueError( f"NaN/Inf detected in total_loss at batch {batch_idx}. " - f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss.item():.6f}, tc_term: {tc_term.item():.6f}" + f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" ) self.log_dict( { "train_loss": total_loss.item(), "rec_loss": rec_loss.item(), - "cond_loss": cond_loss.item(), + # "cond_loss": cond_loss.item(), "tc_loss": tc_term, }, prog_bar=True, diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index 304cca4..28c1b61 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -134,7 +134,6 @@ def forward(self, x): x: [batch size, sequence length, embed dim] output: [batch size, sequence length, embed dim] """ - # print(x.shape) x = x + self.pe return self.dropout(x) diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index ee8905c..c9ba91d 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -277,11 +277,11 @@ def __init__( self.dataset = dataset # Get continuous variables from config if specified - continuous_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "continuous"] - categorical_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "categorical"] + self.continuous_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "continuous"] + self.categorical_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "categorical"] dynamic_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "time_series"] - self.static_context_vars = categorical_vars + continuous_vars + self.static_context_vars = self.categorical_vars + self.continuous_vars self.dynamic_context_vars = dynamic_vars self.context_vars = self.static_context_vars + self.dynamic_context_vars @@ -291,6 +291,7 @@ def 
__init__( self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale self.seq_len = dataset_cfg.seq_len + self.num_ts_steps = getattr(dataset_cfg, "num_ts_steps", None) # For dynamic context length # Get context config # context_cfg = get_context_config() @@ -315,7 +316,7 @@ def __init__( self.static_context_vars_dict, 256, ) - + # Create dynamic context module (for time_series) dynamic_context_module = None if self.dynamic_context_vars and self.dynamic_module_type is not None: @@ -325,10 +326,12 @@ def __init__( k: v for k, v in self.dataset_cfg.context_vars.items() if k in self.dynamic_context_vars } + # Use num_ts_steps for dynamic context length if available, otherwise seq_len + dynamic_seq_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len dynamic_context_module = DynamicContextModuleCls( dynamic_context_vars_dict, 256, - seq_len=self.seq_len, + seq_len=dynamic_seq_len, ) self.normalizer_model = _NormalizerModule( @@ -347,7 +350,7 @@ def __init__( self.dynamic_cond_module = self.normalizer_model.dynamic_cond_module # Will be populated in setup() - self.group_stats = {} + self.sample_stats = [] self._verify_parameters() def _verify_parameters(self): @@ -372,9 +375,10 @@ def _verify_parameters(self): def setup(self, stage: Optional[str] = None): """ - Lightning hook: compute group statistics before training. + Lightning hook: prepare training data before training. 
""" - self.group_stats = self._compute_group_stats() + # Compute per-sample statistics - no grouping needed + self.sample_stats = self._compute_per_sample_stats() # Log initial predictions to check if model is in the right ballpark if stage == "fit" or stage is None: @@ -611,44 +615,56 @@ def train_dataloader(self): pin_memory=torch.cuda.is_available(), # Helps with GPU transfer prefetch_factor=2, # Reduce prefetch to avoid memory issues ) - # def on_after_backward(self): - # unused = [n for n,p in self.named_parameters() if p.requires_grad and p.grad is None] - # if unused: - # print("UNUSED:", unused[:50]) - def _compute_group_stats(self) -> dict: + def _compute_per_sample_stats(self) -> list: """ - Compute per-group (context combination) statistics from raw data. + Compute statistics for each individual sample. + This allows the model to learn context → normalization_params for all context types + (categorical, continuous, and dynamic) without requiring grouping. Returns: - Mapping from context tuple to (mu_array, std_array, zmin_array, zmax_array, dynamic_ctx_dict). 
+ List of tuples: (context_vars_dict, mu_array, std_array, zmin_array, zmax_array) """ df = self.dataset.data.copy() - grouped_stats = {} - for group_vals, group_df in df.groupby(self.static_context_vars): - dimension_points = [[] for _ in range(self.time_series_dims)] - # Store dynamic context variables (time series) for this group - # We'll use the first row's dynamic context as representative + sample_stats = [] + continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + + for idx, row in df.iterrows(): + context_vars_dict = {} + + # Process static context variables (categorical + continuous) + for var_name in self.static_context_vars: + if var_name in row: + if var_name in continuous_vars: + context_vars_dict[var_name] = torch.tensor(row[var_name], dtype=torch.float32) + else: + context_vars_dict[var_name] = torch.tensor(row[var_name], dtype=torch.long) + + # Process dynamic context variables (time series) dynamic_ctx_dict = {} - if self.dynamic_context_vars: - first_row = group_df.iloc[0] - for var_name in self.dynamic_context_vars: - if var_name in first_row: - # Get the time series sequence - ts_data = first_row[var_name] - if isinstance(ts_data, np.ndarray): - dynamic_ctx_dict[var_name] = ts_data - elif isinstance(ts_data, list): - dynamic_ctx_dict[var_name] = np.array(ts_data) - else: - # If it's a scalar, repeat it to match seq_len - dynamic_ctx_dict[var_name] = np.full(self.seq_len, ts_data) + for var_name in self.dynamic_context_vars: + # Check for both the original name and context_ prefix (vehicle dataset uses context_ prefix) + ts_data = None + if var_name in row: + ts_data = row[var_name] + + if ts_data is not None: + if isinstance(ts_data, np.ndarray): + dynamic_ctx_dict[var_name] = ts_data + elif isinstance(ts_data, list): + dynamic_ctx_dict[var_name] = np.array(ts_data) + else: + # If it's a scalar, repeat it to match the appropriate length + # Use num_ts_steps if available (for dynamic context), otherwise seq_len + 
context_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len + dynamic_ctx_dict[var_name] = np.full(context_len, ts_data) + + # Compute statistics from this sample's target time series + dimension_points = [] + for d, col_name in enumerate(self.time_series_cols): + arr = np.array(row[col_name], dtype=np.float32).flatten() + dimension_points.append(arr) - for _, row in group_df.iterrows(): - for d, col_name in enumerate(self.time_series_cols): - arr = np.array(row[col_name], dtype=np.float32).flatten() - dimension_points[d].append(arr) - dimension_points = [np.concatenate(d, axis=0) for d in dimension_points] mu_array = np.array( [pts.mean() for pts in dimension_points], dtype=np.float32 ) @@ -673,47 +689,34 @@ def _compute_group_stats(self) -> dict: ) else: z_min_array = z_max_array = None - - grouped_stats[tuple(group_vals)] = ( + + sample_stats.append(( + context_vars_dict, + dynamic_ctx_dict, mu_array, std_array, z_min_array, z_max_array, - dynamic_ctx_dict, - ) - return grouped_stats + )) + + return sample_stats def _create_training_dataset(self) -> Dataset: """ - Build an internal Dataset yielding true stats for each context group. + Build an internal Dataset yielding per-sample statistics. Returns: PyTorch Dataset of samples (context_vars_dict, mu, sigma, zmin, zmax). """ - data_tuples = [ - (ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr, dynamic_ctx_dict) - for ctx_tuple, ( - mu_arr, - sigma_arr, - zmin_arr, - zmax_arr, - dynamic_ctx_dict, - ) in self.group_stats.items() - ] - - continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] - class _TrainSet(Dataset): """ - Adapter Dataset to wrap group_stats tuples for DataLoader. + Adapter Dataset to wrap per-sample statistics for DataLoader. 
""" - def __init__(self, samples, static_context_vars, dynamic_context_vars, do_scale, continuous_vars, dataset_cfg): + def __init__(self, samples, dynamic_context_vars, do_scale, dataset_cfg): self.samples = samples - self.static_context_vars = static_context_vars self.dynamic_context_vars = dynamic_context_vars self.do_scale = do_scale - self.continuous_vars = continuous_vars self.dataset_cfg = dataset_cfg def __len__(self) -> int: @@ -733,17 +736,9 @@ def __getitem__(self, idx: int): zmin_t: True min z-score tensor or None. zmax_t: True max z-score tensor or None. """ - ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr, dynamic_ctx_dict = self.samples[idx] - context_vars_dict = {} - - # Process static context variables - for i, var_name in enumerate(self.static_context_vars): - if var_name in self.continuous_vars: - context_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.float32) - else: - context_vars_dict[var_name] = torch.tensor(ctx_tuple[i], dtype=torch.long) + context_vars_dict, dynamic_ctx_dict, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] - # Process dynamic context variables (time series) + # Process dynamic context variables (time series) - convert to tensors for var_name in self.dynamic_context_vars: if var_name in dynamic_ctx_dict: ts_data = dynamic_ctx_dict[var_name] @@ -772,7 +767,7 @@ def __getitem__(self, idx: int): zmax_t = torch.from_numpy(zmax_arr).float() if self.do_scale else None return context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t - return _TrainSet(data_tuples, self.static_context_vars, self.dynamic_context_vars, self.do_scale, continuous_vars, self.dataset_cfg) + return _TrainSet(self.sample_stats, self.dynamic_context_vars, self.do_scale, self.dataset_cfg) def transform(self, df: pd.DataFrame) -> pd.DataFrame: """ diff --git a/scripts/train.py b/scripts/train.py index f3e551d..9e861b3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,6 +4,7 @@ from cents.datasets.pecanstreet import PecanStreetDataset 
from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset +from cents.datasets.vehicle import VehicleDataset from cents.trainer import Trainer from cents.utils.utils import set_context_config_path, set_context_overrides from pytorch_lightning.callbacks import EarlyStopping @@ -42,6 +43,11 @@ def main(args) -> None: overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], force_retrain_normalizer=args.force_retrain_normalizer ) + elif args.dataset == "vehicle": + dataset = VehicleDataset( + overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], + force_retrain_normalizer=args.force_retrain_normalizer + ) else: raise ValueError(f"Dataset {args.dataset} not supported") From a89902845f6b2461459e27b2bed197f5283c29c2 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 2 Feb 2026 13:47:55 -0500 Subject: [PATCH 26/50] Added vehicle files ; fixes for bininng numerics --- cents/config/dataset/commercial.yaml | 1 - cents/config/dataset/pecanstreet.yaml | 6 +- cents/config/dataset/vehicle.yaml | 24 ++ cents/config/trainer/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 6 +- cents/datasets/timeseries_dataset.py | 26 +- cents/datasets/utils.py | 1 + cents/datasets/vehicle.py | 148 +++++++++++ cents/models/context.py | 24 +- cents/models/diffusion_ts.py | 34 +-- cents/models/normalizer.py | 332 ++++++++++++++++++++----- cents/trainer.py | 1 + scripts/train.py | 2 +- 13 files changed, 505 insertions(+), 102 deletions(-) create mode 100644 cents/config/dataset/vehicle.yaml create mode 100644 cents/datasets/vehicle.py diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index a60d82b..07c601c 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -15,7 +15,6 @@ time_series_columns: "energy_meter" data_columns: ["dataid","energy_meter","timestamp"] metadata_columns: ["building_id", "site_id", 
"primaryspaceusage", "sqft", "yearbuilt"] numeric_context_bins: 5 -numeric_cols: ["sqft", "yearbuilt"] # Columns to bin as numeric reduce_cardinality: False diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 0e63035..2d16ace 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -15,6 +15,8 @@ data_columns: ["dataid","local_15min","car1","grid","solar"] metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 +normalizer_stats_mode: group + context_vars: month: ["categorical", 12] @@ -24,5 +26,5 @@ context_vars: car1: ["categorical", 2] city: ["categorical", 7] state: ["categorical", 3] - total_square_footage: ["continuous", null] - house_construction_year: ["continuous", null] \ No newline at end of file + total_square_footage: ["categorical", null] + house_construction_year: ["categorical", null] \ No newline at end of file diff --git a/cents/config/dataset/vehicle.yaml b/cents/config/dataset/vehicle.yaml new file mode 100644 index 0000000..672d928 --- /dev/null +++ b/cents/config/dataset/vehicle.yaml @@ -0,0 +1,24 @@ +name: vehicle +normalize: False +scale: False +use_learned_normalizer: True +seq_len: 15 # 1.5 s +time_series_dims: 6 +shuffle: True +skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) +max_samples: null # Limit dataset size (null = use all data) +path: "./data/vehicle" +time_series_columns: ["Acceleration_pedal_depth", "Vehicle_speed", "Brake_pedal_depth", "Vehicle_acceleration", "VCU_MotTqCmd", "MCU_MotActTq"] +data_columns: ["Time", "Acceleration_pedal_depth", "Vehicle_speed", "Brake_pedal_depth", "Vehicle_acceleration", "VCU_MotTqCmd", "MCU_MotActTq"] +num_ts_steps: 5 #.5s +numeric_context_bins: null + +context_vars: + context_Acceleration_pedal_depth: ["time_series", null] + 
context_Vehicle_speed: ["time_series", null] + context_Brake_pedal_depth: ["time_series", null] + context_Vehicle_acceleration: ["time_series", null] + context_VCU_MotTqCmd: ["time_series", null] + context_MCU_MotActTq: ["time_series", null] + + diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 25196b8..18dbd68 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -1,7 +1,7 @@ precision: "16-mixed" accelerator: auto devices: auto -strategy: ddp_find_unused_parameters_false +strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index a42ac0e..1e78c27 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,13 +1,13 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: auto +devices: 1 precision: 16-mixed log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 n_epochs: 2000 -batch_size: 8192 -lr: 1e-5 +batch_size: 4096 +lr: 3e-4 gradient_clip_val: 1.0 save_cycle: 5000 eval_after_training: False diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index dd5a853..78545d5 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -59,15 +59,12 @@ def __init__( ): # Initialize basic attributes # Handle OmegaConf ListConfig objects - if isinstance(time_series_column_names, ListConfig): + if not isinstance(time_series_column_names, list): time_series_column_names = list(time_series_column_names) if isinstance(context_var_column_names, ListConfig): context_var_column_names = list(context_var_column_names) - self.time_series_column_names = ( - time_series_column_names - if isinstance(time_series_column_names, list) - else [time_series_column_names] - ) + + self.time_series_column_names = 
time_series_column_names self.time_series_dims = self.cfg.time_series_dims self.context_vars = context_var_column_names or [] self.seq_len = seq_len @@ -86,9 +83,12 @@ def __init__( self.cfg = cfg self.context_var_dict = self.cfg.context_vars + self.numeric_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "categorical" and v[1] is None] self.numeric_context_bins = self.cfg.numeric_context_bins - if not hasattr(self, "threshold"): - self.threshold = (-self.cfg.threshold, self.cfg.threshold) + + for k in self.numeric_cols: + self.context_var_dict[k] = ["categorical", self.numeric_context_bins] + if not hasattr(self, "name"): self.name = "custom" @@ -123,9 +123,8 @@ def __init__( if self.normalize: self._init_normalizer() cache_path = self._get_normalization_cache_path() - print("CACHE PATH", cache_path) if cache_path.exists(): - print(f"[{'DDP Subprocess' if is_ddp_subprocess else 'Main Process'}] Loading pre-normalized data from cache") + print(f"[{'DDP Subprocess' if is_ddp_subprocess else 'Main Process'}] Loading pre-normalized data from cache", cache_path) with open(cache_path, 'rb') as f: self.data = pickle.load(f) else: @@ -347,16 +346,17 @@ def _encode_context_vars( """ continuous_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "continuous"] time_series_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "time_series"] - numeric_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "categorical" and v[1] == None] encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, bins=self.numeric_context_bins, - numeric_cols=numeric_cols, + numeric_cols=self.numeric_cols, continuous_vars=continuous_vars, time_series_cols=time_series_cols, categorical_time_series=self.categorical_time_series, ) + print(mapping, "mapping") + print("ranges encode context vars") return encoded_data, mapping @@ -700,7 +700,7 @@ def _init_normalizer(self) -> None: devices=ncfg.devices, 
strategy=ncfg.strategy, log_every_n_steps=ncfg.log_every_n_steps, - logger=True, + logger=False, ) trainer.fit(self._normalizer) torch.save(self._normalizer.state_dict(), cache_path) diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 0aa2cfc..4fe06f7 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -167,6 +167,7 @@ def encode_context_variables( elif col in time_series_cols or col in continuous_vars: continue elif numeric_cols and col in numeric_cols: + print("ENCODING NUMERIC COL", col) # Numeric column: Perform binning # Handle NaN values by filling with median before binning if encoded_data[col].isna().all(): diff --git a/cents/datasets/vehicle.py b/cents/datasets/vehicle.py new file mode 100644 index 0000000..693478e --- /dev/null +++ b/cents/datasets/vehicle.py @@ -0,0 +1,148 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class VehicleDataset(TimeSeriesDataset): + """ + Dataset class for Vehicle time series data. + + Handles loading, preprocessin, including normalization and context variables. + + Attributes: + cfg (DictConfig): Hydra config for the dataset. + name (str): Dataset name. + normalize (bool): Whether to apply normalization. + """ + + def __init__( + self, + cfg: Optional[DictConfig] = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + ): + """ + Initialize and preprocess the Vehicle dataset. + + Loads metadata and timeseries CSVs, then applies filtering, + grouping, user-subsetting, and calls the base class for + further preprocessing (normalization, merging, rarity flags). 
+ + Args: + cfg (Optional[DictConfig]): Override Hydra config; if None, + load from `config/dataset/vehicle.yaml`. + overrides (Optional[List[str]]): Dot-path `key=value` overrides + applied to the config via `apply_overrides`. + + Raises: + FileNotFoundError: If required CSV files are missing. + """ + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "vehicle.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + self.time_series_dims = cfg.time_series_dims + self.num_ts_steps = cfg.num_ts_steps + self.seq_len = self.cfg.seq_len + + self._load_data() + + ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] + + + super().__init__( + data=self.data, + time_series_column_names=ts_cols, + context_var_column_names=list(self.cfg.context_vars.keys()), + seq_len=self.cfg.seq_len, + normalize=self.cfg.normalize, + scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, + ) + + def _load_data(self) -> None: + """ + Load the vehicle signal CSV (`vehicle_signal_data.csv`) from the configured data path. + + Populates self.data DataFrame. + + Raises: + FileNotFoundError: If any required CSV file is missing. + """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + self.data = pd.read_csv(os.path.join(path, "vehicle_signal_data.csv")) + + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + ''' + Sort rows by raw `Time` and assemble windows of num_ts_steps context steps followed by seq_len target steps. + + Args: + data (pd.DataFrame): Raw vehicle signal rows, including a `Time` column. + + Returns: + pd.DataFrame: One row per window, with array-valued time series columns. + ''' + + # Assemble sequences of length seq_len, each with a prefix context of length self.num_ts_steps + # Time remains raw seconds, do not convert to timestamps.
+ time_series_cols = self.cfg.time_series_columns[: self.time_series_dims] + context_var_names = list(self.cfg.context_vars.keys()) + data = data.sort_values("Time").reset_index(drop=True) # ensure increasing raw seconds + + # Only build full (context+target) window sequences that fit fully within data + total_window = self.num_ts_steps + self.seq_len + rolling_idxs = ( + pd.Series(np.arange(len(data))) + .rolling(window=total_window) + .apply(lambda x: x[0], raw=True) + .dropna() + .index + ) + + # Preallocate arrays for sequences, for efficiency + out = {col: [] for col in time_series_cols} + for cvar in context_var_names: + out[f"{cvar}"] = [] + out["context_time"] = [] + + for idx in rolling_idxs: + window_slice = data.iloc[idx - total_window + 1 : idx + 1] + context_slice = window_slice.iloc[:self.num_ts_steps] + target_slice = window_slice.iloc[self.num_ts_steps:] + + # Store target sequences + for col in time_series_cols: + out[col].append(target_slice[col].to_numpy()) + + # Store context as array(s) + for cvar in time_series_cols: + # Context for each variable over the context window + if cvar in context_slice.columns: + out[f"context_{cvar}"].append(context_slice[cvar].to_numpy()) + else: + out[f"context_{cvar}"].append([None] *self.num_ts_steps) + + # Optionally, keep the raw "Time" for context window (useful to recover absolute position/relative time) + out["context_time"].append(context_slice["Time"].to_numpy()) + + out_df = pd.DataFrame(out) + return out_df + \ No newline at end of file diff --git a/cents/models/context.py b/cents/models/context.py index 315022c..f11cded 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -39,10 +39,9 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int): """ super().__init__() self.embedding_dim = embedding_dim - self.context_embeddings = nn.ModuleDict( { - name: nn.Embedding(num_categories, embedding_dim) + name: nn.Embedding(num_categories[1], embedding_dim) for name, 
num_categories in context_vars.items() } ) @@ -59,7 +58,7 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int): var_name: nn.Sequential( nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), - nn.Linear(embedding_dim, num_categories) + nn.Linear(embedding_dim, num_categories[1]) ) for var_name, num_categories in context_vars.items() } @@ -78,20 +77,37 @@ def forward( embedding (Tensor): Combined embedding of shape (batch_size, embedding_dim). classification_logits (Dict[str, Tensor]): Logits per variable, each of shape (batch_size, num_categories). - """ + """ +# # At start of forward, before any embedding(context_vars[name]) +# for name in context_vars: +# t = context_vars[name] +# if t.dtype in (torch.long, torch.int): +# t_cpu = t.detach().cpu() +# print(f"{name}: shape={t_cpu.shape}, min={t_cpu.min().item()}, max={t_cpu.max().item()}") +# print(context_vars.keys(), self.context_embeddings.keys(), "context_vars and context_embeddings") embeddings = [ layer(context_vars[name]) for name, layer in self.context_embeddings.items() ] + # print("max", embeddings[0].max(), "min", embeddings[0].min(), "mean", embeddings[0].mean(), "std", embeddings[0].std(), "nan", embeddings[0].isnan().sum(), "inf", embeddings[0].isinf().sum()) + # print(embeddings, "embeddings") + context_matrix = torch.cat(embeddings, dim=1) embedding = self.mlp(context_matrix) + # print("max", embedding.max(), "min", embedding.min(), "mean", embedding.mean(), "std", embedding.std(), "nan", embedding.isnan().sum(), "inf", embedding.isinf().sum()) + + # print(embedding, "embedding") + classification_logits = { var_name: head(embedding) for var_name, head in self.classification_heads.items() } + # print(classification_logits, "classification_logits") + return embedding, classification_logits + @register_context_module("default", "sep_mlp") class SepMLPContextModule(BaseContextModule): def __init__( diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 
34bab68..21767bc 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -206,8 +206,7 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict elif len(embeddings) == 1: embedding = embeddings[0] else: - raise ValueError("No context variables provided") - + raise ValueError("No context variables provided") return embedding, all_logits def predict_noise_from_start( @@ -288,7 +287,6 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # if x.abs().max() > 100.0: # print(f"[Warning] Input x has extreme values: min={x.min():.4f}, max={x.max():.4f}, " # f"mean={x.mean():.4f}, std={x.std():.4f}, shape={x.shape}") - if torch.isnan(x).any() or torch.isinf(x).any(): raise ValueError(f"NaN/Inf detected in input x. Shape: {x.shape}, " f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") @@ -296,7 +294,6 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) embedding, cond_classification_logits = self._get_context_embedding(context_vars) - # Check embedding for NaN/Inf if embedding.isnan().any() or embedding.isinf().any(): raise ValueError( @@ -314,25 +311,20 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # Check diffusion schedule parameters noise = torch.randn_like(x) - x_noisy = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - if x_noisy.isnan().any() or x_noisy.isinf().any(): raise ValueError(f"NaN/Inf detected in x_noisy. 
Shape: {x_noisy.shape}, " f"NaN count: {torch.isnan(x_noisy).sum()}, Inf count: {torch.isinf(x_noisy).sum()}") - # Use normalized embedding for concatenation embedding_expanded = embedding.unsqueeze(1).repeat(1, self.seq_len, 1) c = torch.cat([x_noisy, embedding_expanded], dim=-1) - if c.isnan().any() or c.isinf().any(): raise ValueError(f"NaN/Inf detected in concatenated input 'c'. " f"Shape: {c.shape}, x_noisy stats: min={x_noisy.min():.4f}, max={x_noisy.max():.4f}, " f"embedding stats: min={embedding.min():.4f}, max={embedding.max():.4f}") - # if c.isnan().any() or c.isinf().any(): # raise ValueError( # f"NaN/Inf detected in concatenated input 'c'. " @@ -342,7 +334,6 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # f"std={embedding.std():.4f}, min={embedding.min():.4f}, " # f"max={embedding.max():.4f}" # ) - trend, season = self.model(c, t, padding_masks=None) x_recon = self.fc(trend + season) rec_loss = self.recon_loss_fn(x_recon, x) @@ -365,14 +356,14 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: cond_loss = 0.0 - # for var_name, outputs in cond_class_logits.items(): - # labels = cond_batch[var_name] - # if var_name in self.continuous_context_vars: - # loss = F.mse_loss(outputs, labels.float()) - # elif var_name in self.categorical_context_vars: - # loss = self.auxiliary_loss(outputs, labels) + for var_name, outputs in cond_class_logits.items(): + labels = cond_batch[var_name] + if var_name in self.continuous_context_vars: + loss = F.mse_loss(outputs, labels.float()) + elif var_name in self.categorical_context_vars: + loss = self.auxiliary_loss(outputs, labels) - # cond_loss += loss.mean() + cond_loss += loss.mean() # # if var_name in self.continuous_context_vars: # # print(var_name) @@ -380,7 +371,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # # print(outputs.mean(), labels.mean()) - # cond_loss /= len(cond_class_logits) + cond_loss /= len(cond_class_logits) h, _ = 
self._get_context_embedding(cond_batch) tc_term = ( @@ -404,7 +395,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: { "train_loss": total_loss.item(), "rec_loss": rec_loss.item(), - # "cond_loss": cond_loss.item(), + "cond_loss": cond_loss.item(), "tc_loss": tc_term, }, prog_bar=True, @@ -477,8 +468,9 @@ def on_after_backward(self) -> None: missing_vars.add(parts[1]) print(f"[Warning] {len(context_params_no_grad)} context module parameters have no gradients!") if missing_vars: - print(f" Missing context variables: {sorted(missing_vars)}") - print(f" No grad params (sample): {context_params_no_grad[:5]}...") + pass + # print(f" Missing context variables: {sorted(missing_vars)}") + # print(f" No grad params (sample): {context_params_no_grad[:5]}...") if context_params_with_grad: avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) max_grad_norm = max(g[1] for g in context_params_with_grad) diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index c9ba91d..9efad1f 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -61,33 +61,55 @@ def __init__( # This helps with training stability self._initialize_output_layer() - def _initialize_output_layer(self): - """Initialize the output layer to reasonable starting values.""" - # Get the last linear layer - output_layer = self.net[-1] + def _initialize_output_layer(self, init_sigma: float = 1.0): + """ + Initialize the last layer so log_sigma starts around log(init_sigma). + Assumes outputs are later reshaped as out.view(B, K, D) where K=2 or 4. + Therefore log_sigma lives in the SECOND block: indices [D:2D) in the flattened vector. 
+ """ + assert init_sigma > 0.0, "init_sigma must be > 0" + D = self.time_series_dims + K = 4 if self.do_scale else 2 + + out_layer = self.net[-1] + if not isinstance(out_layer, nn.Linear): + raise RuntimeError("Expected last module in self.net to be nn.Linear") + with torch.no_grad(): - # Initialize all weights with small values - # nn.init.xavier_uniform_(output_layer.weight, gain=1.0) - - # Initialize all biases to zero first - # nn.init.zeros_(output_layer.bias) - - # For log_sigma outputs (indices 1, 3, 5, ...), initialize bias to small negative - # This makes exp(log_sigma) start around 0.1-1.0 - if self.do_scale: - # Pattern: mu, log_sigma, z_min, z_max for each dimension - for dim_idx in range(self.time_series_dims): - # log_sigma is at index 1 + 4*dim_idx - log_sigma_idx = 1 + 4 * dim_idx - # Initialize to 3.0: exp(3.0) ≈ 20, closer to typical sigma ~27 - output_layer.bias[log_sigma_idx].fill_(3.0) - else: - # Pattern: mu, log_sigma for each dimension - for dim_idx in range(self.time_series_dims): - # log_sigma is at index 1 + 2*dim_idx - log_sigma_idx = 1 + 2 * dim_idx - # Initialize to 3.0: exp(3.0) ≈ 20, closer to typical sigma ~27 - output_layer.bias[log_sigma_idx].fill_(3.0) + # Reasonable default: keep biases at 0, then set log_sigma bias block. + nn.init.zeros_(out_layer.bias) + + # Optional: small weights so output starts near the bias. + # (If you prefer PyTorch defaults, comment this out.) 
+ nn.init.xavier_uniform_(out_layer.weight, gain=0.01) + + # Set log_sigma block bias + log_sigma_bias = np.log(init_sigma) + start = 1 * D + end = 2 * D + out_layer.bias[start:end].fill_(log_sigma_bias) + + # If you want, you can also bias z_min/z_max to plausible values here when do_scale=True + # Example (optional): + # if self.do_scale: + # out_layer.bias[2*D:3*D].fill_(-2.0) # z_min + # out_layer.bias[3*D:4*D].fill_( 2.0) # z_max + + # Sanity: ensure output dimension matches expectation + expected_out = K * D + if out_layer.out_features != expected_out: + raise ValueError( + f"Output layer out_features={out_layer.out_features}, expected {expected_out} (K={K}, D={D})." + ) + + @staticmethod + def _soft_clamp_tanh(x: torch.Tensor, bound: float) -> torch.Tensor: + """ + Smoothly maps x into [-bound, bound] using tanh. + """ + if bound <= 0: + raise ValueError("bound must be > 0") + return bound * torch.tanh(x / bound) def forward(self, z: torch.Tensor): """ @@ -121,10 +143,7 @@ def forward(self, z: torch.Tensor): # Store unclamped version for loss computation BEFORE clamping # This must be done before any operations that might break the computation graph pred_log_sigma_unclamped = pred_log_sigma - - # Clamp log_sigma to prevent exp() from producing infinity - # exp(88) ≈ 1.6e38 (near float32 max), so clamp to reasonable range - pred_log_sigma_clamped = torch.clamp(pred_log_sigma, min=-10.0, max=10.0) + pred_log_sigma_clamped = self._soft_clamp_tanh(pred_log_sigma, bound=10.0) pred_sigma = torch.exp(pred_log_sigma_clamped) return pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped @@ -309,7 +328,7 @@ def __init__( StaticContextModuleCls = get_context_module_cls(self.static_module_type) # Filter context_vars to only static ones self.static_context_vars_dict = { - k: v for k, v in self.dataset_cfg.context_vars.items() + k: v for k, v in self.dataset.context_var_dict.items() if k in self.static_context_vars } static_context_module = 
StaticContextModuleCls( @@ -378,7 +397,8 @@ def setup(self, stage: Optional[str] = None): Lightning hook: prepare training data before training. """ # Compute per-sample statistics - no grouping needed - self.sample_stats = self._compute_per_sample_stats() + mode = getattr(self.dataset_cfg, "normalizer_stats_mode", "sample") + self.sample_stats = self._build_training_samples(mode, use_quantile_scale=False) # Log initial predictions to check if model is in the right ballpark if stage == "fit" or stage is None: @@ -428,28 +448,18 @@ def forward(self, cat_vars_dict: dict): return self.normalizer_model(cat_vars_dict) def _compute_loss_mse(self, pred_mu, pred_sigma, pred_log_sigma_unclamped, mu_t, sigma_t): - """ - Compute MSE loss for mu and sigma. - - Args: - pred_mu: Predicted means - pred_sigma: Predicted standard deviations - pred_log_sigma_unclamped: Unclamped log sigma predictions - mu_t: Target means - sigma_t: Target standard deviations - - Returns: - loss_mu, loss_sigma - """ - # Use standard MSE loss for mu loss_mu = F.mse_loss(pred_mu, mu_t) - # Use log-space loss for sigma - this is more numerically stable - # and handles scale differences better - target_log_sigma = torch.log(sigma_t + 1e-8) # Add small epsilon to avoid log(0) - loss_sigma = F.mse_loss(pred_log_sigma_unclamped, target_log_sigma) + # FIX: Clamp the target log sigma to a reasonable range + # (e.g., nothing smaller than e^-5 approx 0.006) + target_log_sigma = torch.log(sigma_t + 1e-8) + target_log_sigma = torch.clamp(target_log_sigma, min=-5.0, max=10.0) + + # FIX: Use Huber Loss (SmoothL1Loss) instead of MSE for stability + loss_sigma = F.smooth_l1_loss(pred_log_sigma_unclamped, target_log_sigma) return loss_mu, loss_sigma + def _compute_loss_gaussian_nll(self, pred_mu, pred_sigma, mu_t, sigma_t): """ @@ -509,9 +519,34 @@ def training_step(self, batch, batch_idx: int): f"Unknown loss_type: {self.loss_type}. 
" f"Supported types: 'mse', 'gaussian_nll'" ) + total_loss = loss_mu + loss_sigma + + # Log prediction statistics to monitor if model is learning~ + # if batch_idx % 500000 == 0: # Log every 100 batches to avoid spam + # with torch.no_grad(): + # # Debug: Check shapes and actual errors + # print(f"\n[Batch {batch_idx}] Debug Loss Computation:") + # print(f" pred_mu shape: {pred_mu.shape}, mu_t shape: {mu_t.shape}") + # print(f" pred_mu mean: {pred_mu.mean().item():.4f}, mu_t mean: {mu_t.mean().item():.4f}") + # print(f" pred_mu range: [{pred_mu.min().item():.4f}, {pred_mu.max().item():.4f}]") + # print(f" mu_t range: [{mu_t.min().item():.4f}, {mu_t.max().item():.4f}]") + # mu_errors = (pred_mu - mu_t).abs() + # print(f" mu errors: mean={mu_errors.mean().item():.4f}, max={mu_errors.max().item():.4f}, min={mu_errors.min().item():.4f}") + # mu_squared_errors = (pred_mu - mu_t) ** 2 + # print(f" mu squared errors: mean={mu_squared_errors.mean().item():.4f}, max={mu_squared_errors.max().item():.4f}") + # print(f" loss_mu (computed): {loss_mu.item():.4f}") + # print(f" loss_mu (manual mean): {mu_squared_errors.mean().item():.4f}") + + # self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=False) + # self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=False) + # self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=False) + # self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=False) + # self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=False) + # self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=False) + if self.do_scale: if torch.isnan(pred_z_min).any() or torch.isnan(pred_z_max).any(): raise ValueError( @@ -542,12 +577,34 @@ def training_step(self, batch, batch_idx: int): # Log prediction statistics to monitor if model is learning if batch_idx % 100 == 0: # Log every 100 batches to avoid spam with torch.no_grad(): - self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=False) 
- self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=False) - self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=False) - self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=False) - self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=False) - self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=False) + # Log shapes (as number of elements for logging purposes) + self.log("pred_mu_num_elements", pred_mu.numel(), on_step=True, on_epoch=False) + self.log("mu_t_num_elements", mu_t.numel(), on_step=True, on_epoch=False) + self.log("pred_mu_batch_size", pred_mu.shape[0] if len(pred_mu.shape) > 0 else 1, on_step=True, on_epoch=False) + self.log("pred_mu_dims", pred_mu.shape[1] if len(pred_mu.shape) > 1 else 1, on_step=True, on_epoch=False) + + # Log ranges + self.log("pred_mu_min", pred_mu.min(), on_step=True, on_epoch=False) + self.log("pred_mu_max", pred_mu.max(), on_step=True, on_epoch=False) + self.log("mu_t_min", mu_t.min(), on_step=True, on_epoch=False) + self.log("mu_t_max", mu_t.max(), on_step=True, on_epoch=False) + + # Log error statistics + mu_errors = (pred_mu - mu_t).abs() + mu_squared_errors = (pred_mu - mu_t) ** 2 + self.log("mu_error_mean", mu_errors.mean(), on_step=True, on_epoch=False) + self.log("mu_error_max", mu_errors.max(), on_step=True, on_epoch=False) + self.log("mu_error_min", mu_errors.min(), on_step=True, on_epoch=False) + self.log("mu_squared_error_mean", mu_squared_errors.mean(), on_step=True, on_epoch=False) + self.log("mu_squared_error_max", mu_squared_errors.max(), on_step=True, on_epoch=False) + + # Log existing statistics + self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=True) + self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=True) + self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=True) + self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=True) + self.log("target_mu_mean", mu_t.mean(), 
on_step=True, on_epoch=True) + self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=True) return total_loss @@ -891,3 +948,166 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: arr = z * (sigma[d] + 1e-8) + mu[d] df_out.at[i, col] = arr return df_out + + def _build_training_samples( + self, + mode: str = "sample", # "sample" or "group" + group_vars: Optional[list[str]] = None, + use_quantile_scale: bool = False, # if True: use q01/q99 instead of min/max for zlow/zhigh + q_low: float = 0.01, + q_high: float = 0.99, + ) -> list: + """ + Build training samples for the normalizer. + + Returns a list of tuples: + (context_vars_dict, dynamic_ctx_dict, mu_array, std_array, zlow_array, zhigh_array) + + - mode="sample": one tuple per row + - mode="group": one tuple per group (grouped by group_vars) + + Notes: + - group mode is only well-defined for *static* variables. For dynamic context, + this function will raise unless you explicitly exclude them from grouping. + - continuous vars: if you keep them continuous (float), grouping by them is usually pointless. + In group mode, we therefore ignore continuous vars by default unless you explicitly put them in group_vars. 
+ """ + assert mode in {"sample", "group"}, f"mode must be 'sample' or 'group', got {mode}" + + df = self.dataset.data.copy() + + # Identify context types + continuous_vars = set(getattr(self.dataset_cfg, "continuous_context_vars", None) or []) + dynamic_vars = set(self.dynamic_context_vars) # time_series context vars + static_vars = [v for v in self.static_context_vars] # categorical + continuous + + # Default grouping vars: static categorical only (exclude continuous + dynamic) + if group_vars is None: + group_vars = [v for v in static_vars if (v not in continuous_vars and v not in dynamic_vars)] + + # Sanity: grouping by dynamic vars is almost always wrong (huge keys, high cardinality) + bad = [v for v in group_vars if v in dynamic_vars] + if bad: + raise ValueError( + f"group_vars contains dynamic(time_series) vars {bad}. " + f"Remove them or use mode='sample'." + ) + + # Helper: compute stats from a single row's time series columns + def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]: + dim_points = [] + for d, col_name in enumerate(self.time_series_cols): + arr = np.asarray(row[col_name], dtype=np.float32).reshape(-1) + dim_points.append(arr) + + mu = np.array([x.mean() for x in dim_points], dtype=np.float32) + std = np.array([x.std() + 1e-8 for x in dim_points], dtype=np.float32) + + if not self.do_scale: + return mu, std, None, None + + zlow = np.zeros(self.time_series_dims, dtype=np.float32) + zhigh = np.zeros(self.time_series_dims, dtype=np.float32) + + for i, (x, m, s) in enumerate(zip(dim_points, mu, std)): + z = (x - m) / s + if use_quantile_scale: + zlow[i] = np.quantile(z, q_low).astype(np.float32) + zhigh[i] = np.quantile(z, q_high).astype(np.float32) + else: + zlow[i] = z.min().astype(np.float32) + zhigh[i] = z.max().astype(np.float32) + + return mu, std, zlow, zhigh + + samples = [] + + if mode == "sample": + for _, row in df.iterrows(): + context_vars_dict = {} + + # static vars + for v in 
static_vars: + if v not in row: + continue + if v in continuous_vars: + context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32) + else: + context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long) + + # dynamic vars (store separately; TrainSet will tensorize them) + dynamic_ctx_dict = {} + for v in self.dynamic_context_vars: + if v not in row: + continue + ts_data = row[v] + if isinstance(ts_data, np.ndarray): + dynamic_ctx_dict[v] = ts_data + elif isinstance(ts_data, list): + dynamic_ctx_dict[v] = np.array(ts_data) + else: + # scalar -> repeat + L = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len + dynamic_ctx_dict[v] = np.full(L, ts_data) + + mu, std, zlow, zhigh = _row_stats(row) + + samples.append((context_vars_dict, dynamic_ctx_dict, mu, std, zlow, zhigh)) + + return samples + + # mode == "group" + # Build grouped stats by aggregating all points from rows in each group. + # Context dict comes from group key values. + grouped = df.groupby(group_vars, dropna=False) + + for group_key, gdf in grouped: + # group_key can be scalar or tuple depending on #group_vars + if len(group_vars) == 1: + group_key = (group_key,) + + context_vars_dict = {} + for i, v in enumerate(group_vars): + # group vars should be categorical by default; cast to long. + # if user explicitly included a continuous var in group_vars, keep float. 
+ if v in continuous_vars: + context_vars_dict[v] = torch.tensor(group_key[i], dtype=torch.float32) + else: + context_vars_dict[v] = torch.tensor(group_key[i], dtype=torch.long) + + dynamic_ctx_dict = {} # undefined for grouping; keep empty + + # Aggregate raw points per dim + dim_points = [[] for _ in range(self.time_series_dims)] + for _, row in gdf.iterrows(): + for d, col_name in enumerate(self.time_series_cols): + arr = np.asarray(row[col_name], dtype=np.float32).reshape(-1) + dim_points[d].append(arr) + + dim_points = [np.concatenate(xs, axis=0) if len(xs) else np.zeros((0,), dtype=np.float32) + for xs in dim_points] + + mu = np.array([x.mean() if x.size else 0.0 for x in dim_points], dtype=np.float32) + std = np.array([x.std() + 1e-8 if x.size else 1.0 for x in dim_points], dtype=np.float32) + + if self.do_scale: + zlow = np.zeros(self.time_series_dims, dtype=np.float32) + zhigh = np.zeros(self.time_series_dims, dtype=np.float32) + for i, (x, m, s) in enumerate(zip(dim_points, mu, std)): + if x.size == 0: + zlow[i], zhigh[i] = -2.0, 2.0 + continue + z = (x - m) / s + if use_quantile_scale: + zlow[i] = np.quantile(z, q_low).astype(np.float32) + zhigh[i] = np.quantile(z, q_high).astype(np.float32) + else: + zlow[i] = z.min().astype(np.float32) + zhigh[i] = z.max().astype(np.float32) + else: + zlow = zhigh = None + + samples.append((context_vars_dict, dynamic_ctx_dict, mu, std, zlow, zhigh)) + + return samples + diff --git a/cents/trainer.py b/cents/trainer.py index 6e827ee..5374944 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -90,6 +90,7 @@ def fit(self, ckpt_path: Optional[str] = None) -> "Trainer": num_workers=6, # Maximum for 7.5GB/10GB GPU usage persistent_workers=True, ) + print(f"[Cents] Training model on {len(train_loader)} batches") self.pl_trainer.fit(self.model, train_loader, None, ckpt_path=ckpt_path) return self diff --git a/scripts/train.py b/scripts/train.py index 9e861b3..09cc81e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ 
-98,7 +98,7 @@ def main(args) -> None: help="Evaluate after training") parser.add_argument("--skip_heavy_processing", action="store_true", help="Skip heavy processing of dataset") - parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false") + parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_true") parser.add_argument("--enable_checkpointing", action="store_true", help="Enable checkpointing") parser.add_argument("--context-config-path", type=str, default=None, From eff190cc0d0e1a28f2c116cb959c68ddd0cc4c6a Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 3 Feb 2026 17:14:11 -0500 Subject: [PATCH 27/50] Added additional diffusion training objectives ; snr ; revised EMA --- cents/config/context/default.yaml | 2 +- cents/config/model/diffusion_ts.yaml | 6 +- cents/config/trainer/diffusion_ts.yaml | 4 +- cents/datasets/timeseries_dataset.py | 3 - cents/eval/eval.py | 12 +- cents/models/diffusion_ts.py | 298 ++++++++++++++++----- cents/models/normalizer.py | 70 ++--- scripts/eval_pretrained.py | 18 ++ tests/test_configs/model/diffusion_ts.yaml | 3 + 9 files changed, 286 insertions(+), 130 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index d2585ac..6327799 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -3,7 +3,7 @@ # Static context: used by generative models (ACGAN, Diffusion_TS) for conditioning static_context: - type: sep_mlp # Context module type (e.g., "mlp", "sep_mlp") + type: mlp # Context module type (e.g., "mlp", "sep_mlp") # Future parameters can be added here: # n_layers: 2 # hidden_dim: 256 diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 7cb76c1..6406413 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -11,6 +11,9 @@ n_steps: 1000 sampling_timesteps: 1000 sampling_batch_size: 4096 loss_type: l1 #l2 
+training_objective: eps +loss_weighting: default +min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 @@ -24,4 +27,5 @@ reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 -use_ema_sampling: True \ No newline at end of file +use_ema_sampling: True +k_bins: 20 \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 18dbd68..961d3c1 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -17,8 +17,8 @@ checkpoint: lr_scheduler_params: factor: 0.5 - patience: 200 - min_lr: 1.0e-5 + patience: 50 + min_lr: 1.0e-6 threshold: 1.0e-1 threshold_mode: rel verbose: false diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 78545d5..be6f943 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -241,7 +241,6 @@ def get_train_dataloader( Returns: DataLoader: Configured data loader. 
""" - continuous_vars = getattr(self.cfg, "continuous_context_vars", None) or [] self._normalize_continuous_vars() return DataLoader( @@ -355,8 +354,6 @@ def _encode_context_vars( time_series_cols=time_series_cols, categorical_time_series=self.categorical_time_series, ) - print(mapping, "mapping") - print("ranges encode context vars") return encoded_data, mapping diff --git a/cents/eval/eval.py b/cents/eval/eval.py index c6f336d..3b9334b 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -348,12 +348,12 @@ def evaluate_subset( """ dataset.data = dataset.get_combined_rarity() real_data_subset = dataset.data.iloc[indices].reset_index(drop=True) - context_vars = { - name: torch.tensor( - real_data_subset[name].values, dtype=torch.long, device=self.device - ) - for name in dataset.context_vars - } + continuous_vars = getattr(dataset, "continuous_vars", []) + context_vars = {} + for name in dataset.context_vars: + vals = real_data_subset[name].values + dtype = torch.float32 if name in continuous_vars else torch.long + context_vars[name] = torch.tensor(vals, dtype=dtype, device=self.device) generated_ts = model.generate(context_vars).cpu().numpy() if generated_ts.ndim == 2: diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 21767bc..d5af5d6 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -10,6 +10,8 @@ from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau from tqdm.auto import tqdm +from contextlib import contextmanager + from cents.models.base import GenerativeModel from cents.models.model_utils import ( @@ -30,6 +32,18 @@ class Diffusion_TS(GenerativeModel): Uses a Transformer backbone to predict and denoise time series over discrete diffusion timesteps. Supports EMA smoothing and configurable beta schedules. + + Training objective (config model.training_objective): x0, epsilon, or v. + - x0: predict clean sample; loss = L1/L2(model_out, x_clean). 
+ - epsilon: predict noise; loss = L1/L2(pred_epsilon, noise); pred_epsilon derived from model x0. + - v: v-parameterization; loss = L1/L2(pred_v, true_v); pred_v derived from model x0. + The network always outputs x0 (fc layer); sampling uses x0 and q_posterior unchanged. + Variance-debugging reference: + (a) Sampling: model predicts x0; reverse step uses q_posterior(x0,x_t,t); + noise term = sqrt(posterior_variance). Same beta schedule in train and sample. + (b) Normalization: per-group (context) z-score + optional [zmin,zmax] scale; + denorm must use the same normalizer.inverse_transform. Run + scripts/check_diffusion_consistency.py to verify norm/denorm identity. """ def __init__(self, cfg: DictConfig): @@ -96,6 +110,13 @@ def __init__(self, cfg: DictConfig): ) self.fast_sampling = self.sampling_timesteps < self.num_timesteps self.loss_type = cfg.model.loss_type + self.training_objective = getattr( + cfg.model, "training_objective", "x0" + ).lower() + if self.training_objective not in ("x0", "eps", "v"): + raise ValueError( + f"training_objective must be one of x0, eps, v; got {self.training_objective}" + ) # register buffers for diffusion coefficients self.register_buffer("betas", betas) @@ -127,7 +148,18 @@ def __init__(self, cfg: DictConfig): self.register_buffer("posterior_mean_coef1", pmc1) self.register_buffer("posterior_mean_coef2", pmc2) - lw = torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100 + # Loss weighting: default (legacy), uniform, snr, or min_snr (via compute_snr_weights) + loss_weighting = getattr(cfg.model, "loss_weighting", "default") + min_snr_gamma = getattr(cfg.model, "min_snr_gamma", 5.0) + if loss_weighting == "default": + lw = torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas.clamp(min=1e-8) / 100 + else: + lw = self.compute_snr_weights( + alphas_cumprod, + loss_weighting=loss_weighting, + objective=self.training_objective, + gamma=min_snr_gamma, + ) self.register_buffer("loss_weight", lw) # choose 
reconstruction loss @@ -246,6 +278,78 @@ def predict_start_from_noise( - self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) * noise ) + def predict_start_from_v( + self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + """ + Reconstruct x0 from x_t and v-parameterization. + v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x0 => x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v + """ + return ( + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x_t + - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * v + ) + + def predict_noise_from_v( + self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + """ + Reconstruct epsilon from x_t and v-parameterization. + v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x0 => epsilon = sqrt(1 - alpha_bar_t) * x_t + sqrt(alpha_bar_t) * v + """ + return ( + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_t + + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * v + ) + + + def compute_snr_weights( + alphas_cumprod: torch.Tensor, + *, + loss_weighting: str, + objective: str, + gamma: float = 5.0, + ) -> torch.Tensor: + """ + SNR-based loss weighting per timestep. + + Args: + alphas_cumprod: Cumulative product of alphas, shape (n_steps,). + loss_weighting: "uniform" | "snr" | "min_snr". + objective: "eps" | "x0" | "v" — must match training objective. + gamma: Cap for SNR when loss_weighting == "min_snr". + + Returns: + Weight tensor same shape as alphas_cumprod. 
+ """ + snr = alphas_cumprod / (1.0 - alphas_cumprod).clamp(min=1e-8) + + if loss_weighting == "uniform": + return torch.ones_like(snr) + + if loss_weighting == "snr": + if objective == "eps": + return 1.0 / (snr + 1.0) + elif objective == "x0": + return snr / (snr + 1.0) + elif objective == "v": + return 1.0 / torch.sqrt(snr + 1.0) + else: + raise ValueError(objective) + + if loss_weighting == "min_snr": + snr_c = torch.minimum(snr, torch.full_like(snr, gamma)) + if objective == "eps": + return snr_c / snr.clamp(min=1e-8) + elif objective == "x0": + return snr_c + elif objective == "v": + return snr_c / (snr + 1.0) + else: + raise ValueError(objective) + + raise ValueError(loss_weighting) + def q_posterior( self, x_start: torch.Tensor, @@ -292,7 +396,8 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") b = x.shape[0] - t = torch.randint(0, self.num_timesteps, (b,), device=self.device) + # t = torch.randint(0, self.num_timesteps, (b,), device=self.device) + t = self.stratified_timesteps(b, self.num_timesteps, self.cfg.model.k_bins, device=self.device) embedding, cond_classification_logits = self._get_context_embedding(context_vars) # Check embedding for NaN/Inf if embedding.isnan().any() or embedding.isinf().any(): @@ -335,8 +440,27 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # f"max={embedding.max():.4f}" # ) trend, season = self.model(c, t, padding_masks=None) - x_recon = self.fc(trend + season) - rec_loss = self.recon_loss_fn(x_recon, x) + x_start_pred = self.fc(trend + season) + # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) + if self.training_objective == "x0": + loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") + elif self.training_objective == "epsilon": + pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) + loss_per_elem 
= self.recon_loss_fn(pred_noise, noise, reduction="none") + else: # v + pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) + pred_v = ( + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * pred_noise + - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_start_pred + ) + true_v = ( + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * noise + - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x + ) + loss_per_elem = self.recon_loss_fn(pred_v, true_v, reduction="none") + rec_loss = ( + self.loss_weight[t].view(-1, 1, 1) * loss_per_elem + ).mean() return rec_loss, cond_classification_logits def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: @@ -371,7 +495,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # # print(outputs.mean(), labels.mean()) - cond_loss /= len(cond_class_logits) + # cond_loss /= len(cond_class_logits) h, _ = self._get_context_embedding(cond_batch) tc_term = ( @@ -399,6 +523,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: "tc_loss": tc_term, }, prog_bar=True, + sync_dist=True, ) return total_loss @@ -471,10 +596,10 @@ def on_after_backward(self) -> None: pass # print(f" Missing context variables: {sorted(missing_vars)}") # print(f" No grad params (sample): {context_params_no_grad[:5]}...") - if context_params_with_grad: - avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) - max_grad_norm = max(g[1] for g in context_params_with_grad) - print(f"[Debug] Context module gradients: avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}") + # if context_params_with_grad: + # avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) + # max_grad_norm = max(g[1] for g in context_params_with_grad) + # print(f"[Debug] Context module gradients: avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}") def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: """ @@ -609,6 
+734,18 @@ def fast_sample( x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise return x + @contextmanager + def ema_scope(self): + if hasattr(self, "_ema") and self._ema and getattr(self.cfg.model, "use_ema_sampling", False): + self._ema.store(self.model.parameters()) + self._ema.copy_to(self.model.parameters()) + try: + yield + finally: + self._ema.restore(self.model.parameters()) + else: + yield + def generate(self, context_vars: dict) -> torch.Tensor: """ Public entry to generate conditioned samples in batches. @@ -623,48 +760,29 @@ def generate(self, context_vars: dict) -> torch.Tensor: total = len(next(iter(context_vars.values()))) generated_samples = [] - for start_idx in tqdm( - range(0, total, bs), - unit="seq", - desc="[CENTS] Generating samples", - leave=True, - ): - end_idx = min(start_idx + bs, total) - batch_context_vars = { - var_name: var_tensor[start_idx:end_idx] - for var_name, var_tensor in context_vars.items() - } - current_bs = end_idx - start_idx - shape = (current_bs, self.seq_len, self.time_series_dims) - - with torch.no_grad(): - if getattr(self.cfg.model, "use_ema_sampling", False): - self._ensure_ema_helper() - if hasattr(self, "_ema") and self._ema: - original_model = self.model - self.model = self._ema.ema_model - try: - if self.fast_sampling: - samples = self.fast_sample(shape, batch_context_vars) - else: - samples = self.sample(shape, batch_context_vars) - finally: - # Restore original model - self.model = original_model + with self.ema_scope(): + for start_idx in tqdm( + range(0, total, bs), + unit="seq", + desc="[CENTS] Generating samples", + leave=True, + ): + end_idx = min(start_idx + bs, total) + batch_context_vars = { + var_name: var_tensor[start_idx:end_idx] + for var_name, var_tensor in context_vars.items() + } + current_bs = end_idx - start_idx + shape = (current_bs, self.seq_len, self.time_series_dims) + + with torch.no_grad(): + if self.fast_sampling: + samples = self.fast_sample(shape, batch_context_vars) 
else: - samples = ( - self.fast_sample(shape, batch_context_vars) - if self.fast_sampling - else self.sample(shape, batch_context_vars) - ) - else: - samples = ( - self.fast_sample(shape, batch_context_vars) - if self.fast_sampling - else self.sample(shape, batch_context_vars) - ) + samples = self.sample(shape, batch_context_vars) + - generated_samples.append(samples) + generated_samples.append(samples.cpu()) return torch.cat(generated_samples, dim=0) @@ -679,41 +797,81 @@ def _ensure_ema_helper(self) -> None: beta=self.cfg.model.ema_decay, update_every=self.cfg.model.ema_update_interval, ) + def stratified_timesteps(self, batch_size: int, num_timesteps: int, k_bins: int, device=None) -> torch.Tensor: + device = device or "cpu" + k_bins = min(k_bins, batch_size) + edges = torch.linspace(0, num_timesteps, steps=k_bins + 1, device=device) + + # sample one t per bin + u = torch.rand(k_bins, device=device) + t_bins = (edges[:-1] + u * (edges[1:] - edges[:-1])).floor().clamp_(0, num_timesteps - 1).long() + + # repeat to fill batch, then shuffle + reps = (batch_size + k_bins - 1) // k_bins + t = t_bins.repeat(reps)[:batch_size] + t = t[torch.randperm(batch_size, device=device)] + return t + class EMA(nn.Module): """ Exponential Moving Average (EMA) helper for model parameters. - - Maintains a shadow copy of the model weights that are updated - via EMA every `update_every` steps. """ - - def __init__(self, model: nn.Module, beta: float, update_every: int): - """ - Args: - model: Base model to copy for EMA tracking. - beta: EMA decay rate (0 < beta < 1). - update_every: Frequency (in steps) to apply EMA update. 
- """ + def __init__(self, model: nn.Module, beta: float = 0.9999, update_every: int = 10): super().__init__() - self.model = copy.deepcopy(model) - self.ema_model = self.model.eval() self.beta = beta self.update_every = update_every self.step = 0 - for param in self.ema_model.parameters(): - param.requires_grad = False + + # CRITICAL FIX 1: self.ema_model is the ONLY deepcopy. + # It holds the shadow weights. + self.ema_model = copy.deepcopy(model) + self.ema_model.eval() + self.ema_model.requires_grad_(False) + + # CRITICAL FIX 2: We keep a reference to the LIVE model (not a copy) + # so we can grab the latest trained weights during update(). + self.source_model = model + + # Buffer to store temporary weights for the context manager + self.collected_params = [] def update(self) -> None: """ - Perform an EMA update of the shadow model parameters. - Called typically at end of each training batch. + Update the shadow parameters using the source model's current weights. """ self.step += 1 if self.step % self.update_every != 0: return + with torch.no_grad(): - for ema_p, model_p in zip( - self.ema_model.parameters(), self.model.parameters() - ): - ema_p.data.mul_(self.beta).add_(model_p.data, alpha=1.0 - self.beta) + # Zip the shadow model (ema) against the live model (source) + for ema_p, src_p in zip(self.ema_model.parameters(), self.source_model.parameters()): + # ema_new = beta * ema_old + (1 - beta) * current_weight + ema_p.data.mul_(self.beta).add_(src_p.data, alpha=1.0 - self.beta) + + def store(self, parameters): + """ + Save the current parameters (of the live model) to a temporary list. + Used by the context manager to back up weights before swapping. + """ + self.collected_params = [param.clone().cpu() for param in parameters] + + def restore(self, parameters): + """ + Restore the saved parameters back to the live model. 
+ """ + if not self.collected_params: + raise RuntimeError("No parameters stored to restore.") + + for param, saved_param in zip(parameters, self.collected_params): + param.data.copy_(saved_param.data.to(param.device)) + + self.collected_params = [] # Clear memory + + def copy_to(self, parameters): + """ + Copy the EMA shadow weights INTO the live model parameters. + """ + for param, ema_param in zip(parameters, self.ema_model.parameters()): + param.data.copy_(ema_param.data.to(param.device)) \ No newline at end of file diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 9efad1f..613a991 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -523,30 +523,6 @@ def training_step(self, batch, batch_idx: int): total_loss = loss_mu + loss_sigma - - # Log prediction statistics to monitor if model is learning~ - # if batch_idx % 500000 == 0: # Log every 100 batches to avoid spam - # with torch.no_grad(): - # # Debug: Check shapes and actual errors - # print(f"\n[Batch {batch_idx}] Debug Loss Computation:") - # print(f" pred_mu shape: {pred_mu.shape}, mu_t shape: {mu_t.shape}") - # print(f" pred_mu mean: {pred_mu.mean().item():.4f}, mu_t mean: {mu_t.mean().item():.4f}") - # print(f" pred_mu range: [{pred_mu.min().item():.4f}, {pred_mu.max().item():.4f}]") - # print(f" mu_t range: [{mu_t.min().item():.4f}, {mu_t.max().item():.4f}]") - # mu_errors = (pred_mu - mu_t).abs() - # print(f" mu errors: mean={mu_errors.mean().item():.4f}, max={mu_errors.max().item():.4f}, min={mu_errors.min().item():.4f}") - # mu_squared_errors = (pred_mu - mu_t) ** 2 - # print(f" mu squared errors: mean={mu_squared_errors.mean().item():.4f}, max={mu_squared_errors.max().item():.4f}") - # print(f" loss_mu (computed): {loss_mu.item():.4f}") - # print(f" loss_mu (manual mean): {mu_squared_errors.mean().item():.4f}") - - # self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=False) - # self.log("pred_mu_std", pred_mu.std(), on_step=True, 
on_epoch=False) - # self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=False) - # self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=False) - # self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=False) - # self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=False) - if self.do_scale: if torch.isnan(pred_z_min).any() or torch.isnan(pred_z_max).any(): raise ValueError( @@ -567,44 +543,44 @@ def training_step(self, batch, batch_idx: int): ) # Log individual components to understand what's happening - self.log("train_loss", total_loss, prog_bar=True) - self.log("loss_mu", loss_mu, on_step=True, on_epoch=True, prog_bar=False) - self.log("loss_sigma", loss_sigma, on_step=True, on_epoch=True, prog_bar=False) + self.log("train_loss", total_loss, prog_bar=True, sync_dist=True) + self.log("loss_mu", loss_mu, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) + self.log("loss_sigma", loss_sigma, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) if self.do_scale: - self.log("loss_zmin", loss_zmin, on_step=True, on_epoch=True, prog_bar=False) - self.log("loss_zmax", loss_zmax, on_step=True, on_epoch=True, prog_bar=False) + self.log("loss_zmin", loss_zmin, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) + self.log("loss_zmax", loss_zmax, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) # Log prediction statistics to monitor if model is learning if batch_idx % 100 == 0: # Log every 100 batches to avoid spam with torch.no_grad(): # Log shapes (as number of elements for logging purposes) - self.log("pred_mu_num_elements", pred_mu.numel(), on_step=True, on_epoch=False) - self.log("mu_t_num_elements", mu_t.numel(), on_step=True, on_epoch=False) + self.log("pred_mu_num_elements", pred_mu.numel(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_t_num_elements", mu_t.numel(), on_step=True, on_epoch=False, sync_dist=True) self.log("pred_mu_batch_size", 
pred_mu.shape[0] if len(pred_mu.shape) > 0 else 1, on_step=True, on_epoch=False) - self.log("pred_mu_dims", pred_mu.shape[1] if len(pred_mu.shape) > 1 else 1, on_step=True, on_epoch=False) + self.log("pred_mu_dims", pred_mu.shape[1] if len(pred_mu.shape) > 1 else 1, on_step=True, on_epoch=False, sync_dist=True) # Log ranges - self.log("pred_mu_min", pred_mu.min(), on_step=True, on_epoch=False) - self.log("pred_mu_max", pred_mu.max(), on_step=True, on_epoch=False) - self.log("mu_t_min", mu_t.min(), on_step=True, on_epoch=False) - self.log("mu_t_max", mu_t.max(), on_step=True, on_epoch=False) + self.log("pred_mu_min", pred_mu.min(), on_step=True, on_epoch=False, sync_dist=True) + self.log("pred_mu_max", pred_mu.max(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_t_min", mu_t.min(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_t_max", mu_t.max(), on_step=True, on_epoch=False, sync_dist=True) # Log error statistics mu_errors = (pred_mu - mu_t).abs() mu_squared_errors = (pred_mu - mu_t) ** 2 - self.log("mu_error_mean", mu_errors.mean(), on_step=True, on_epoch=False) - self.log("mu_error_max", mu_errors.max(), on_step=True, on_epoch=False) - self.log("mu_error_min", mu_errors.min(), on_step=True, on_epoch=False) - self.log("mu_squared_error_mean", mu_squared_errors.mean(), on_step=True, on_epoch=False) - self.log("mu_squared_error_max", mu_squared_errors.max(), on_step=True, on_epoch=False) + self.log("mu_error_mean", mu_errors.mean(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_error_max", mu_errors.max(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_error_min", mu_errors.min(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_squared_error_mean", mu_squared_errors.mean(), on_step=True, on_epoch=False, sync_dist=True) + self.log("mu_squared_error_max", mu_squared_errors.max(), on_step=True, on_epoch=False, sync_dist=True) # Log existing statistics - self.log("pred_mu_mean", pred_mu.mean(), 
on_step=True, on_epoch=True) - self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=True) - self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=True) - self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=True) - self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=True) - self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=True) + self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=True, sync_dist=True) + self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=True, sync_dist=True) + self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=True, sync_dist=True) + self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=True, sync_dist=True) + self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=True, sync_dist=True) + self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=True, sync_dist=True) return total_loss diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 0b30637..34b6823 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,6 +1,8 @@ import logging import os from pathlib import Path +import torch +import torch.nn.functional as F from omegaconf import OmegaConf import argparse @@ -193,6 +195,22 @@ def main() -> None: gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset + + with torch.no_grad(): + betas = gen.model.betas + alphas = 1 - betas + abar = torch.cumprod(alphas, dim=0) + abar_prev = F.pad(abar[:-1], (1,0), value=1.0) + + pv_expected = betas * (1 - abar_prev) / (1 - abar) + print((pv_expected - gen.model.posterior_variance).abs().max()) + + pmc1_expected = betas * abar_prev.sqrt() / (1 - abar) + pmc2_expected = (1 - abar_prev) * alphas.sqrt() / (1 - abar) + + print((pmc1_expected - gen.model.posterior_mean_coef1).abs().max()) + print((pmc2_expected - gen.model.posterior_mean_coef2).abs().max()) + 
logging.info("Checkpoint loaded. Starting evaluation...") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) diff --git a/tests/test_configs/model/diffusion_ts.yaml b/tests/test_configs/model/diffusion_ts.yaml index 1750bc9..e52666c 100644 --- a/tests/test_configs/model/diffusion_ts.yaml +++ b/tests/test_configs/model/diffusion_ts.yaml @@ -11,6 +11,9 @@ n_steps: 1000 sampling_timesteps: 1000 sampling_batch_size: 1 loss_type: l1 #l2 +training_objective: x0 +loss_weighting: min_snr +min_snr_gamma: 5.0 beta_schedule: cosine #linear n_heads: 4 mlp_hidden_times: 4 From 9f77f1d7d38d04b1ca585defcac08263d650e309 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 3 Feb 2026 22:12:30 -0500 Subject: [PATCH 28/50] Added better run tracking --- .gitignore | 1 + cents/config/model/diffusion_ts.yaml | 2 +- cents/datasets/airquality.py | 6 ++- cents/datasets/commercial.py | 6 ++- cents/datasets/pecanstreet.py | 2 + cents/datasets/timeseries_dataset.py | 23 ++++++--- cents/datasets/vehicle.py | 2 + cents/models/diffusion_ts.py | 2 + cents/trainer.py | 44 ++++++++++++++++- scripts/train.py | 72 +++++++++++++++++++++------- 10 files changed, 131 insertions(+), 29 deletions(-) diff --git a/.gitignore b/.gitignore index be2515a..14c1d4e 100644 --- a/.gitignore +++ b/.gitignore @@ -106,6 +106,7 @@ ENV/ .*.swp # Repository Specific +runs/ cents/data/* cents/data/pecanstreet/* cents/data/commercial/* diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 6406413..f26d309 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -12,7 +12,7 @@ sampling_timesteps: 1000 sampling_batch_size: 4096 loss_type: l1 #l2 training_objective: eps -loss_weighting: default +loss_weighting: snr min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py index 06d8207..b0ca2cd 100644 --- 
a/cents/datasets/airquality.py +++ b/cents/datasets/airquality.py @@ -17,8 +17,9 @@ class AirQualityDataset(TimeSeriesDataset): def __init__(self, cfg: DictConfig = None, - overrides: Optional[List[str]] = None, - force_retrain_normalizer: bool = False): + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None): """ Initializes the AirQuality Dataset. Available at: https://doi.org/10.24432/C5RK5G. @@ -63,6 +64,7 @@ def __init__(self, cfg: DictConfig = None, size=cfg.get('max_samples', None), categorical_time_series=self.categorical_time_series, force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, ) def _load_data(self): diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index c84c6c4..43b2d15 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -15,9 +15,10 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) class CommercialDataset(TimeSeriesDataset): - def __init__(self, cfg: DictConfig = None, + def __init__(self, cfg: DictConfig = None, overrides: Optional[List[str]] = None, - force_retrain_normalizer: bool = False): + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None): """ Initializes the commercial energy dataset. @@ -56,6 +57,7 @@ def __init__(self, cfg: DictConfig = None, skip_heavy_processing=cfg.get('skip_heavy_processing', False), size=cfg.get('max_samples', None), force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, ) def _load_data(self): diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index 2c551f8..c3e0022 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -34,6 +34,7 @@ def __init__( cfg: Optional[DictConfig] = None, overrides: Optional[List[str]] = None, force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, ): """ Initialize and preprocess the PecanStreet dataset. 
@@ -86,6 +87,7 @@ def __init__( skip_heavy_processing=cfg.get('skip_heavy_processing', False), size=cfg.get('max_samples', None), force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, ) def _load_data(self) -> None: diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index be6f943..63cbfba 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -56,6 +56,7 @@ def __init__( size: int = None, categorical_time_series: Dict[str, int] = None, force_retrain_normalizer: bool = False, + run_dir: Any = None, ): # Initialize basic attributes # Handle OmegaConf ListConfig objects @@ -98,7 +99,8 @@ def __init__( self.normalize = normalize self.scale = scale self.force_retrain_normalizer = force_retrain_normalizer - + self.run_dir = Path(run_dir) if run_dir is not None else None + # Store categorical time series info self.categorical_time_series = categorical_time_series or {} @@ -571,11 +573,14 @@ def _ensure_rarity_computed(self): def _get_rarity_cache_path(self) -> str: """Get cache file path for rarity features.""" import hashlib - # Create a hash based on dataset characteristics for cache key context_cfg = get_context_config() context_module_type = context_cfg.static_context.type cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{str(sorted(self.context_vars))}_{context_module_type or ''}" cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] + if self.run_dir is not None: + cache_dir = self.run_dir / "cache" / "rarity" + cache_dir.mkdir(parents=True, exist_ok=True) + return str(cache_dir / f"rarity_{cache_hash}.pkl") cache_dir = os.path.join(ROOT_DIR, "cache", "rarity") os.makedirs(cache_dir, exist_ok=True) return os.path.join(cache_dir, f"rarity_{cache_hash}.pkl") @@ -584,12 +589,15 @@ def _get_normalization_cache_path(self): """Get cache file path for normalized data.""" import hashlib from pathlib import Path - # Create hash based on dataset + normalizer 
characteristics context_cfg = get_context_config() context_module_type = context_cfg.dynamic_context.type stats_head_type = context_cfg.normalizer.stats_head_type cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}_{context_module_type or ''}_{stats_head_type or ''}" cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] + if self.run_dir is not None: + cache_dir = self.run_dir / "cache" / "normalized_data" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / f"normalized_{cache_hash}.pkl" cache_dir = Path(ROOT_DIR) / "cache" / "normalized_data" cache_dir.mkdir(parents=True, exist_ok=True) return cache_dir / f"normalized_{cache_hash}.pkl" @@ -642,9 +650,12 @@ def _init_normalizer(self) -> None: On first run, trains a new Normalizer and writes a single state dict to cache. On subsequent runs, loads that file. If loading fails, deletes the corrupted cache and retrains. """ - normalizer_dir = ( - Path.home() / ".cache" / "cents" / "checkpoints" / self.name / "normalizer" - ) + if self.run_dir is not None: + normalizer_dir = self.run_dir / "normalizer" + else: + normalizer_dir = ( + Path.home() / ".cache" / "cents" / "checkpoints" / self.name / "normalizer" + ) normalizer_dir.mkdir(parents=True, exist_ok=True) # Get context_module_type and stats_head_type from context config diff --git a/cents/datasets/vehicle.py b/cents/datasets/vehicle.py index 693478e..cbd1cb3 100644 --- a/cents/datasets/vehicle.py +++ b/cents/datasets/vehicle.py @@ -30,6 +30,7 @@ def __init__( cfg: Optional[DictConfig] = None, overrides: Optional[List[str]] = None, force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, ): """ Initialize and preprocess the Vehicle dataset. 
@@ -74,6 +75,7 @@ def __init__( skip_heavy_processing=cfg.get('skip_heavy_processing', False), size=cfg.get('max_samples', None), force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, ) def _load_data(self) -> None: diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index d5af5d6..4dd3f0b 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -304,6 +304,7 @@ def predict_noise_from_v( def compute_snr_weights( + self, alphas_cumprod: torch.Tensor, *, loss_weighting: str, @@ -524,6 +525,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: }, prog_bar=True, sync_dist=True, + on_epoch=True, ) return total_loss diff --git a/cents/trainer.py b/cents/trainer.py index 5374944..ad660f7 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -1,3 +1,4 @@ +import csv from pathlib import Path from typing import Dict, List, Optional @@ -174,6 +175,8 @@ def _compose_cfg(self, ov: List[str]) -> DictConfig: if not hasattr(cfg, "run_dir") or not cfg.run_dir: timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") cfg.run_dir = str(PKG_ROOT / "outputs" / cfg.job_name / timestamp) + # Checkpoint dir: run_dir/checkpoints so run root stays clean + cfg.checkpoint_dir = str(Path(cfg.run_dir) / "checkpoints") return cfg def _instantiate_model(self): @@ -223,15 +226,19 @@ def _instantiate_trainer(self) -> pl.Trainer: filename_parts.append(f"stats{stats_head_type}") + checkpoint_dir = getattr(self.cfg, "checkpoint_dir", None) or str(Path(self.cfg.run_dir) / "checkpoints") + Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) callbacks.append( ModelCheckpoint( - dirpath=self.cfg.run_dir, + dirpath=checkpoint_dir, filename="_".join(filename_parts), save_last=tc.checkpoint.save_last, save_on_train_epoch_end=True, ### Perhaps excessive ) ) callbacks.append(EvalAfterTraining(self.cfg, self.dataset)) + if getattr(self.cfg, "run_dir", None): + callbacks.append(LogLossToCsv(self.cfg.run_dir)) logger = False if 
getattr(self.cfg, "wandb", None) and self.cfg.wandb.enabled: logger = WandbLogger( @@ -255,6 +262,41 @@ def _instantiate_trainer(self) -> pl.Trainer: ) +class LogLossToCsv(Callback): + """Append epoch loss values to runs//train_losses.csv.""" + + def __init__(self, run_dir: str): + super().__init__() + self.run_dir = Path(run_dir) + self._csv_path = self.run_dir / "train_losses.csv" + self._header_written = False + + def _ensure_header(self, metric_names: List[str]) -> None: + if self._header_written: + return + self.run_dir.mkdir(parents=True, exist_ok=True) + with open(self._csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["epoch"] + metric_names) + self._header_written = True + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + metrics = trainer.callback_metrics + if not metrics: + return + # Filter to loss-like keys and sort for consistent column order + loss_keys = sorted(k for k in metrics if "loss" in k.lower()) + if not loss_keys: + return + self._ensure_header(loss_keys) + row = [trainer.current_epoch] + for k in loss_keys: + v = metrics[k] + row.append(float(v) if hasattr(v, "item") else float(v)) + with open(self._csv_path, "a", newline="") as f: + csv.writer(f).writerow(row) + + class EvalAfterTraining(Callback): """Run full evaluator at the *end* of training and log metrics to W&B.""" diff --git a/scripts/train.py b/scripts/train.py index 09cc81e..8006525 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,52 +1,84 @@ from datetime import datetime -import pandas as pd +import yaml +from pathlib import Path from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset from cents.datasets.vehicle import VehicleDataset from cents.trainer import Trainer -from cents.utils.utils import set_context_config_path, set_context_overrides -from pytorch_lightning.callbacks import EarlyStopping 
+from cents.utils.utils import set_context_config_path, set_context_overrides, get_context_config +from omegaconf import OmegaConf import warnings import argparse + warnings.simplefilter(action='ignore', category=FutureWarning) +PROJECT_ROOT = Path(__file__).resolve().parent.parent +RUNS_DIR = PROJECT_ROOT / "runs" + + +def _write_run_summary(run_dir: Path, run_name: str, trainer: Trainer) -> None: + """Write a summary YAML of run choices (context, model, dataset, trainer) to run_dir.""" + cfg = trainer.cfg + context_cfg = get_context_config() + summary = { + "run_name": run_name, + "run_dir": str(run_dir), + "dataset": OmegaConf.to_container(cfg.dataset, resolve=True) if hasattr(cfg, "dataset") and cfg.dataset else {}, + "model": OmegaConf.to_container(cfg.model, resolve=True) if hasattr(cfg, "model") and cfg.model else {}, + "context": OmegaConf.to_container(context_cfg, resolve=True) if context_cfg else {}, + "trainer": OmegaConf.to_container(cfg.trainer, resolve=True) if hasattr(cfg, "trainer") and cfg.trainer else {}, + } + path = run_dir / "summary.yaml" + with open(path, "w") as f: + yaml.dump(summary, f, default_flow_style=False, sort_keys=False) + print(f"[Cents] Wrote run summary to {path}") + def main(args) -> None: MODEL_NAME = args.model_name CR_LOSS_WEIGHT = args.cr_loss_weight TC_LOSS_WEIGHT = args.tc_loss_weight - + run_name = args.run_name + + # Create run directory under runs/ + RUNS_DIR.mkdir(parents=True, exist_ok=True) + run_dir = RUNS_DIR / run_name + run_dir.mkdir(parents=True, exist_ok=True) + print(f"[Cents] Run directory: {run_dir}") + # Set custom context config path if provided if args.context_config_path: set_context_config_path(args.context_config_path) - + # Set context config overrides if provided if args.context_overrides: set_context_overrides(args.context_overrides) - - # Skip heavy processing for DDP compatibility if args.dataset == "pecanstreet": dataset = PecanStreetDataset( 
overrides=[f"skip_heavy_processing={args.skip_heavy_processing}, time_series_dims=1, user_group=all"], - force_retrain_normalizer=args.force_retrain_normalizer + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), ) elif args.dataset == "commercial": dataset = CommercialDataset( overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], - force_retrain_normalizer=args.force_retrain_normalizer + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), ) elif args.dataset == "airquality": dataset = AirQualityDataset( overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], - force_retrain_normalizer=args.force_retrain_normalizer + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), ) elif args.dataset == "vehicle": dataset = VehicleDataset( overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], - force_retrain_normalizer=args.force_retrain_normalizer + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), ) else: raise ValueError(f"Dataset {args.dataset} not supported") @@ -54,16 +86,17 @@ def main(args) -> None: print("Initialized Dataset") trainer_overrides = [ + f"run_dir={run_dir}", f"trainer.max_epochs={args.epochs}", f"trainer.strategy={args.ddp_strategy}", f"trainer.devices={args.devices}", f"trainer.eval_after_training={args.eval_after_training}", f"train.accelerator={args.accelerator}", - "trainer.early_stopping.patience=100", # Stop if no improvement for 100 epochs - "trainer.early_stopping.monitor=train_loss", # Monitor training loss - "trainer.early_stopping.mode=min", # Stop when loss stops decreasing - f"trainer.enable_checkpointing={args.enable_checkpointing}", # Explicitly enable checkpointing - "trainer.logger=False", # Disable logger to see checkpoint messages + "trainer.early_stopping.patience=100", + "trainer.early_stopping.monitor=train_loss", + "trainer.early_stopping.mode=min", + 
f"trainer.enable_checkpointing={args.enable_checkpointing}", + "trainer.logger=False", f"wandb.enabled={args.wandb_enabled}", f"wandb.project={args.wandb_project}", f"wandb.entity={args.wandb_entity}", @@ -78,6 +111,8 @@ def main(args) -> None: overrides=trainer_overrides, ) + _write_run_summary(run_dir, run_name, trainer) + trainer.fit(ckpt_path=args.resume_from_checkpoint) if __name__ == "__main__": @@ -108,7 +143,10 @@ def main(args) -> None: parser.add_argument("--force-retrain-normalizer", action="store_true", help="Force retraining of normalizer even if cached version exists") parser.add_argument("--resume-from-checkpoint", type=str, default=None, - help="Path to checkpoint file (.ckpt) to resume training from", + help="Path to checkpoint file (.ckpt) to resume training from", + ) + parser.add_argument("--run-name", type=str, required=True, + help="Name of this run. A directory runs/ will be created for checkpoints, cache, and summary.", ) args = parser.parse_args() From d87d2dba039d7865f46f89de2d159c25b31e80aa Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 9 Feb 2026 13:38:41 -0500 Subject: [PATCH 29/50] Added AdaLN for stronger conditioning --- cents/config/trainer/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 2 +- cents/models/diffusion_ts.py | 145 ++++++++----------------- cents/models/model_utils.py | 137 ++++++++++++++--------- cents/models/normalizer.py | 2 +- cents/trainer.py | 2 +- scripts/train.py | 3 +- 7 files changed, 139 insertions(+), 154 deletions(-) diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 961d3c1..c881f09 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -1,7 +1,7 @@ precision: "16-mixed" accelerator: auto devices: auto -strategy: ddp_find_unused_parameters_true +strategy: ddp_find_unused_parameters_false gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 diff --git 
a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 1e78c27..f5d8daf 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,6 +1,6 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 1 +devices: 2,3 precision: 16-mixed log_every_n_steps: 1 hidden_dim: 512 diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 4dd3f0b..7cd0374 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -71,13 +71,13 @@ def __init__(self, cfg: DictConfig): if not hasattr(self, 'static_context_module') and not hasattr(self, 'dynamic_context_module'): raise ValueError("At least one context module (static or dynamic) must be initialized") - # linear layer for denoised output + # linear layer for denoised output (no longer includes embedding_dim) self.fc = nn.Linear( - self.time_series_dims + self.embedding_dim, self.time_series_dims + self.time_series_dims, self.time_series_dims ) - # Transformer backbone + # Transformer backbone (now uses AdaLN conditioning instead of input concatenation) self.model = Transformer( - n_feat=self.time_series_dims + self.embedding_dim, + n_feat=self.time_series_dims, n_channel=self.seq_len, n_layer_enc=cfg.model.n_layer_enc, n_layer_dec=cfg.model.n_layer_dec, @@ -88,6 +88,7 @@ def __init__(self, cfg: DictConfig): max_len=self.seq_len, n_embd=cfg.model.d_model, conv_params=[cfg.model.kernel_size, cfg.model.padding_size], + cond_dim=self.embedding_dim, ) # EMA helper will be initialized on train start @@ -223,11 +224,11 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict } dynamic_embedding, dynamic_logits = self.dynamic_context_module(dynamic_vars) # Check for NaN in dynamic embedding - if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): - raise ValueError( - f"NaN/Inf detected in dynamic embedding. 
" - f"Dynamic vars: {list(dynamic_vars.keys())}" - ) + # if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): + # raise ValueError( + # f"NaN/Inf detected in dynamic embedding. " + # f"Dynamic vars: {list(dynamic_vars.keys())}" + # ) embeddings.append(dynamic_embedding) all_logits.update(dynamic_logits) @@ -392,28 +393,28 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # if x.abs().max() > 100.0: # print(f"[Warning] Input x has extreme values: min={x.min():.4f}, max={x.max():.4f}, " # f"mean={x.mean():.4f}, std={x.std():.4f}, shape={x.shape}") - if torch.isnan(x).any() or torch.isinf(x).any(): - raise ValueError(f"NaN/Inf detected in input x. Shape: {x.shape}, " - f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") + # if torch.isnan(x).any() or torch.isinf(x).any(): + # raise ValueError(f"NaN/Inf detected in input x. Shape: {x.shape}, " + # f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") b = x.shape[0] # t = torch.randint(0, self.num_timesteps, (b,), device=self.device) t = self.stratified_timesteps(b, self.num_timesteps, self.cfg.model.k_bins, device=self.device) embedding, cond_classification_logits = self._get_context_embedding(context_vars) # Check embedding for NaN/Inf - if embedding.isnan().any() or embedding.isinf().any(): - raise ValueError( - f"NaN/Inf detected in embedding from context module. " - f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}, " - f"shape: {embedding.shape}, min: {embedding.min()}, max: {embedding.max()}" - ) + # if embedding.isnan().any() or embedding.isinf().any(): + # raise ValueError( + # f"NaN/Inf detected in embedding from context module. 
" + # f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}, " + # f"shape: {embedding.shape}, min: {embedding.min()}, max: {embedding.max()}" + # ) # Embedding should now be normalized by the context module (mean=0, std=1 per sample) # Check that values are in reasonable range - if embedding.abs().max() > 100.0: - print(f"[Warning] Embedding has large values despite normalization: " - f"min={embedding.min():.4f}, max={embedding.max():.4f}, " - f"mean={embedding.mean():.4f}, std={embedding.std():.4f}") + # if embedding.abs().max() > 100.0: + # print(f"[Warning] Embedding has large values despite normalization: " + # f"min={embedding.min():.4f}, max={embedding.max():.4f}, " + # f"mean={embedding.mean():.4f}, std={embedding.std():.4f}") # Check diffusion schedule parameters noise = torch.randn_like(x) @@ -421,27 +422,9 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - if x_noisy.isnan().any() or x_noisy.isinf().any(): - raise ValueError(f"NaN/Inf detected in x_noisy. Shape: {x_noisy.shape}, " - f"NaN count: {torch.isnan(x_noisy).sum()}, Inf count: {torch.isinf(x_noisy).sum()}") - # Use normalized embedding for concatenation - embedding_expanded = embedding.unsqueeze(1).repeat(1, self.seq_len, 1) - c = torch.cat([x_noisy, embedding_expanded], dim=-1) - if c.isnan().any() or c.isinf().any(): - raise ValueError(f"NaN/Inf detected in concatenated input 'c'. " - f"Shape: {c.shape}, x_noisy stats: min={x_noisy.min():.4f}, max={x_noisy.max():.4f}, " - f"embedding stats: min={embedding.min():.4f}, max={embedding.max():.4f}") - # if c.isnan().any() or c.isinf().any(): - # raise ValueError( - # f"NaN/Inf detected in concatenated input 'c'. " - # f"x_noisy stats: mean={x_noisy.mean():.4f}, std={x_noisy.std():.4f}, " - # f"min={x_noisy.min():.4f}, max={x_noisy.max():.4f}. 
" - # f"embedding stats: mean={embedding.mean():.4f}, " - # f"std={embedding.std():.4f}, min={embedding.min():.4f}, " - # f"max={embedding.max():.4f}" - # ) - trend, season = self.model(c, t, padding_masks=None) - x_start_pred = self.fc(trend + season) + # Pass embedding as cond parameter instead of concatenating to input + trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding) + x_start_pred = self.fc((trend + season).contiguous()) # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) if self.training_objective == "x0": loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") @@ -510,11 +493,11 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: ) # Check for NaN in total loss - if torch.isnan(total_loss) or torch.isinf(total_loss): - raise ValueError( - f"NaN/Inf detected in total_loss at batch {batch_idx}. " - f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" - ) + # if torch.isnan(total_loss) or torch.isinf(total_loss): + # raise ValueError( + # f"NaN/Inf detected in total_loss at batch {batch_idx}. " + # f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" + # ) self.log_dict( { @@ -556,52 +539,21 @@ def on_train_start(self) -> None: update_every=self.cfg.model.ema_update_interval, ) - def on_after_backward(self) -> None: - """ - Check gradients after backward pass but before optimizer step. - This is the right place to inspect gradients before they're zeroed. 
- """ - # Get current batch index from trainer - if not hasattr(self.trainer, 'global_step'): - return - - batch_idx = self.trainer.global_step - - # Debug: Check if context module parameters are getting gradients - # Check AFTER backward pass but BEFORE optimizer step (only log occasionally) - if batch_idx % 50 == 0: - context_params_with_grad = [] - context_params_no_grad = [] - if self.static_context_module is not None: - for name, param in self.static_context_module.named_parameters(): - if param.requires_grad: - if param.grad is not None: - grad_norm = param.grad.norm().item() - # Check for NaN/Inf gradients - if torch.isnan(param.grad).any() or torch.isinf(param.grad).any(): - print(f"[Warning] NaN/Inf gradients detected in {name}") - else: - context_params_with_grad.append((name, grad_norm)) - else: - context_params_no_grad.append(name) - - if context_params_no_grad: - # Group by variable name to identify which context variables are missing - missing_vars = set() - for param_name in context_params_no_grad: - # Extract variable name from parameter name (e.g., "context_embeddings.year.weight" -> "year") - parts = param_name.split('.') - if len(parts) >= 2 and parts[0] in ['context_embeddings', 'init_mlps']: - missing_vars.add(parts[1]) - print(f"[Warning] {len(context_params_no_grad)} context module parameters have no gradients!") - if missing_vars: - pass - # print(f" Missing context variables: {sorted(missing_vars)}") - # print(f" No grad params (sample): {context_params_no_grad[:5]}...") - # if context_params_with_grad: - # avg_grad_norm = sum(g[1] for g in context_params_with_grad) / len(context_params_with_grad) - # max_grad_norm = max(g[1] for g in context_params_with_grad) - # print(f"[Debug] Context module gradients: avg_norm={avg_grad_norm:.6f}, max_norm={max_grad_norm:.6f}") + # def on_after_backward(self) -> None: + # """ + # Check gradients after backward pass but before optimizer step. 
+ # This is the right place to inspect gradients before they're zeroed. + # """ + # # Get current batch index from trainer + # for name, p in self.named_parameters(): + # if p.grad is None: + # continue + # if p.grad.stride() != p.stride(): + # print("stride mismatch:", name, + # "param", tuple(p.shape), p.stride(), + # "grad", tuple(p.grad.shape), p.grad.stride()) + # break + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: """ @@ -655,9 +607,8 @@ def model_predictions( pred_noise: predicted noise tensor. x_start: predicted clean sample tensor. """ - c = torch.cat([x, embedding.unsqueeze(1).repeat(1, self.seq_len, 1)], dim=-1) - trend, season = self.model(c, t, padding_masks=None) - x_start = self.fc(trend + season) + trend, season = self.model(x, t, padding_masks=None, cond=embedding) + x_start = self.fc((trend + season).contiguous()) pred_noise = self.predict_noise_from_start(x, t, x_start) return pred_noise, x_start diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index 28c1b61..c6b3920 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -212,14 +212,18 @@ def forward(self, x): class Conv_MLP(nn.Module): def __init__(self, in_dim, out_dim, resid_pdrop=0.0): super().__init__() - self.sequential = nn.Sequential( - Transpose(shape=(1, 2)), - nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1), - nn.Dropout(p=resid_pdrop), - ) + self.conv = nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1) + if self.conv.weight.requires_grad: + self.conv.weight.register_hook(lambda grad: grad.contiguous()) + self.drop = nn.Dropout(p=resid_pdrop) def forward(self, x): - return self.sequential(x).transpose(1, 2) + # x: (B, T, C) + x = x.transpose(1, 2).contiguous() # (B, C, T) contiguous + x = self.conv(x) + x = self.drop(x) + return x.transpose(1, 2).contiguous() # back to (B, T, C), contiguous + class Transformer_MLP(nn.Module): @@ -264,15 +268,13 @@ def forward(self, x): class AdaLayerNorm(nn.Module): 
def __init__(self, n_embd): super().__init__() - self.emb = SinusoidalPosEmb(n_embd) + # self.emb = SinusoidalPosEmb(n_embd) self.silu = nn.SiLU() self.linear = nn.Linear(n_embd, n_embd * 2) self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False) - def forward(self, x, timestep, label_emb=None): - emb = self.emb(timestep) - if label_emb is not None: - emb = emb + label_emb + def forward(self, x, emb): + # emb: (B, n_embd) - Pre-computed and combined (time + label) emb = self.linear(self.silu(emb)).unsqueeze(1) scale, shift = torch.chunk(emb, 2, dim=2) x = self.layernorm(x) * (1 + scale) + shift @@ -324,6 +326,8 @@ def __init__(self, in_dim, out_dim, in_feat, out_feat, act): def forward(self, input): b, c, h = input.shape + if not input.is_contiguous(): + input = input.contiguous() x = self.trend(input).transpose(1, 2) trend_vals = torch.matmul(x.transpose(1, 2), self.poly_space.to(x.device)) trend_vals = trend_vals.transpose(1, 2) @@ -552,16 +556,10 @@ def __init__( activate="GELU", ): super().__init__() - self.ln1 = AdaLayerNorm(n_embd) - self.ln2 = nn.LayerNorm(n_embd) - self.attn = FullAttention( - n_embd=n_embd, - n_head=n_head, - attn_pdrop=attn_pdrop, - resid_pdrop=resid_pdrop, - ) - + self.ln2 = AdaLayerNorm(n_embd) + self.attn = FullAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + assert activate in ["GELU", "GELU2"] act = nn.GELU() if activate == "GELU" else GELU2() @@ -572,10 +570,11 @@ def __init__( nn.Dropout(resid_pdrop), ) - def forward(self, x, timestep, mask=None, label_emb=None): - a, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask) + def forward(self, x, cond_emb, mask=None): + # cond_emb is the combined time+label embedding + a, att = self.attn(self.ln1(x, cond_emb), mask=mask) x = x + a - x = x + self.mlp(self.ln2(x)) # only one really use encoder_output + x = x + self.mlp(self.ln2(x, cond_emb)) return x, att @@ -606,10 +605,10 @@ def __init__( ] ) - def forward(self, input, t, padding_masks=None, label_emb=None): + def 
forward(self, input, cond_emb, padding_masks=None): x = input - for block_idx in range(len(self.blocks)): - x, _ = self.blocks[block_idx](x, t, mask=padding_masks, label_emb=label_emb) + for block in self.blocks: + x, _ = block(x, cond_emb, mask=padding_masks) return x @@ -631,7 +630,7 @@ def __init__( super().__init__() self.ln1 = AdaLayerNorm(n_embd) - self.ln2 = nn.LayerNorm(n_embd) + self.ln2 = AdaLayerNorm(n_embd) # Changed from nn.LayerNorm to AdaLayerNorm self.attn1 = FullAttention( n_embd=n_embd, @@ -667,14 +666,24 @@ def __init__( self.proj = nn.Conv1d(n_channel, n_channel * 2, 1) self.linear = nn.Linear(n_embd, n_feat) - def forward(self, x, encoder_output, timestep, mask=None, label_emb=None): - a, att = self.attn1(self.ln1(x, timestep, label_emb), mask=mask) + def forward(self, x, encoder_output, cond_emb, mask=None): + a, att = self.attn1(self.ln1(x, cond_emb), mask=mask) x = x + a - a, att = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask) + a, att = self.attn2(self.ln1_1(x, cond_emb), encoder_output, mask=mask) x = x + a - x1, x2 = self.proj(x).chunk(2, dim=1) + + # FIX: chunk() returns views that are often non-contiguous. + # Since self.proj and self.trend use Conv1d, this causes the DDP stride mismatch. 
+ x_proj = self.proj(x) + x1, x2 = x_proj.chunk(2, dim=1) + + # Make contiguous before passing to specialized blocks + x1 = x1.contiguous() + x2 = x2.contiguous() + trend, season = self.trend(x1), self.seasonal(x2) - x = x + self.mlp(self.ln2(x)) + + x = x + self.mlp(self.ln2(x, cond_emb)) m = torch.mean(x, dim=1, keepdim=True) return x - m, self.linear(m), trend, season @@ -713,15 +722,19 @@ def __init__( ] ) - def forward(self, x, t, enc, padding_masks=None, label_emb=None): + def forward(self, x, cond_emb, enc, padding_masks=None): b, c, _ = x.shape - # att_weights = [] mean = [] - season = torch.zeros((b, c, self.d_model), device=x.device) - trend = torch.zeros((b, c, self.n_feat), device=x.device) - for block_idx in range(len(self.blocks)): - x, residual_mean, residual_trend, residual_season = self.blocks[block_idx]( - x, enc, t, mask=padding_masks, label_emb=label_emb + # Initialize accumulating tensors on the correct device + season = torch.zeros((b, c, x.shape[-1]), device=x.device) + # Note: Check if season dim is n_embd or n_feat. 
+ # FourierLayer returns same dim as input x (n_embd) + + trend = torch.zeros((b, c, self.blocks[0].linear.out_features), device=x.device) + + for block in self.blocks: + x, residual_mean, residual_trend, residual_season = block( + x, enc, cond_emb, mask=padding_masks ) season += residual_season trend += residual_trend @@ -746,11 +759,21 @@ def __init__( block_activate="GELU", max_len=2048, conv_params=None, + cond_dim=None, **kwargs ): super().__init__() self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop) self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop) + + self.cond_dim = cond_dim + if cond_dim is not None: + # Map context embedding (B, cond_dim) -> (B, n_embd) + self.cond_proj = nn.Linear(cond_dim, n_embd) + else: + self.cond_proj = None + + self.time_emb = SinusoidalPosEmb(n_embd) if conv_params is None or conv_params[0] is None: if n_feat < 32 and n_channel < 64: @@ -808,29 +831,39 @@ def __init__( n_embd, dropout=resid_pdrop, max_len=max_len ) - def forward(self, input, t, padding_masks=None, return_res=False): + def forward(self, input, t, padding_masks=None, return_res=False, cond=None): + # cond: (B, cond_dim) or None + t_emb = self.time_emb(t) + + label_emb = None + if (cond is not None) and (self.cond_proj is not None): + label_emb = self.cond_proj(cond) # (B, n_embd) + # Add them up here to pass a single vector down + total_cond_emb = t_emb + label_emb + else: + total_cond_emb = t_emb + emb = self.emb(input) inp_enc = self.pos_enc(emb) - enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks) + + enc_cond = self.encoder(inp_enc, total_cond_emb, padding_masks=padding_masks) inp_dec = self.pos_dec(emb) output, mean, trend, season = self.decoder( - inp_dec, t, enc_cond, padding_masks=padding_masks + inp_dec, total_cond_emb, enc_cond, padding_masks=padding_masks ) res = self.inverse(output) - res_m = torch.mean(res, dim=1, keepdim=True) - season_error = ( - self.combine_s(season.transpose(1, 2)).transpose(1, 2) + res - 
res_m - ) - trend = self.combine_m(mean) + res_m + trend + + # .contiguous() usage here was correct in your original code + res_m = torch.mean(res, dim=1, keepdim=True).contiguous() + combine_m_out = self.combine_m(mean).contiguous() + combine_s_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2).contiguous() + season_error = (combine_s_out + res - res_m).contiguous() + trend = (combine_m_out + res_m + trend).contiguous() if return_res: - return ( - trend, - self.combine_s(season.transpose(1, 2)).transpose(1, 2), - res - res_m, - ) + return trend, combine_s_out, res - res_m return trend, season_error diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 613a991..3f61e82 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -644,7 +644,7 @@ def train_dataloader(self): batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, num_workers=4, # Use fewer workers to reduce overhead - persistent_workers=False, # Disable to avoid multiprocessing cleanup issues + persistent_workers=True, # Disable to avoid multiprocessing cleanup issues pin_memory=torch.cuda.is_available(), # Helps with GPU transfer prefetch_factor=2, # Reduce prefetch to avoid memory issues ) diff --git a/cents/trainer.py b/cents/trainer.py index ad660f7..5fab43b 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -88,7 +88,7 @@ def fit(self, ckpt_path: Optional[str] = None) -> "Trainer": train_loader = self.dataset.get_train_dataloader( batch_size=self.cfg.trainer.batch_size, shuffle=True, - num_workers=6, # Maximum for 7.5GB/10GB GPU usage + num_workers=4, # Maximum for 7.5GB/10GB GPU usage persistent_workers=True, ) print(f"[Cents] Training model on {len(train_loader)} batches") diff --git a/scripts/train.py b/scripts/train.py index 8006525..b5c0471 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -133,7 +133,8 @@ def main(args) -> None: help="Evaluate after training") parser.add_argument("--skip_heavy_processing", 
action="store_true", help="Skip heavy processing of dataset") - parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_true") + parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false", + help="DDP strategy; use 'ddp' (no find_unused_parameters) for best perf") parser.add_argument("--enable_checkpointing", action="store_true", help="Enable checkpointing") parser.add_argument("--context-config-path", type=str, default=None, From 606be7c7e258b0a5b79c2824e6d6334bc457a4a8 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 9 Feb 2026 13:43:00 -0500 Subject: [PATCH 30/50] cleanup --- cents/config/dataset/vehicle.yaml | 24 ----- cents/datasets/vehicle.py | 150 ------------------------------ 2 files changed, 174 deletions(-) delete mode 100644 cents/config/dataset/vehicle.yaml delete mode 100644 cents/datasets/vehicle.py diff --git a/cents/config/dataset/vehicle.yaml b/cents/config/dataset/vehicle.yaml deleted file mode 100644 index 672d928..0000000 --- a/cents/config/dataset/vehicle.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: vehicle -normalize: False -scale: False -use_learned_normalizer: True -seq_len: 15 # 1.5 s -time_series_dims: 6 -shuffle: True -skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) -max_samples: null # Limit dataset size (null = use all data) -path: "./data/vehicle" -time_series_columns: ["Acceleration_pedal_depth", "Vehicle_speed", "Brake_pedal_depth", "Vehicle_acceleration", "VCU_MotTqCmd", "MCU_MotActTq"] -data_columns: ["Time", "Acceleration_pedal_depth", "Vehicle_speed", "Brake_pedal_depth", "Vehicle_acceleration", "VCU_MotTqCmd", "MCU_MotActTq"] -num_ts_steps: 5 #.5s -numeric_context_bins: null - -context_vars: - context_Acceleration_pedal_depth: ["time_series", null] - context_Vehicle_speed: ["time_series", null] - context_Brake_pedal_depth: ["time_series", null] - context_Vehicle_acceleration: ["time_series", null] - context_VCU_MotTqCmd: 
["time_series", null] - context_MCU_MotActTq: ["time_series", null] - - diff --git a/cents/datasets/vehicle.py b/cents/datasets/vehicle.py deleted file mode 100644 index cbd1cb3..0000000 --- a/cents/datasets/vehicle.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import warnings -from typing import Any, Dict, List, Optional - -import numpy as np -import pandas as pd -from omegaconf import DictConfig -from cents.utils.config_loader import load_yaml, apply_overrides - -from cents.datasets.timeseries_dataset import TimeSeriesDataset - -warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -class VehicleDataset(TimeSeriesDataset): - """ - Dataset class for Vehicle time series data. - - Handles loading, preprocessin, including normalization and context variables. - - Attributes: - cfg (DictConfig): Hydra config for the dataset. - name (str): Dataset name. - normalize (bool): Whether to apply normalization. - """ - - def __init__( - self, - cfg: Optional[DictConfig] = None, - overrides: Optional[List[str]] = None, - force_retrain_normalizer: bool = False, - run_dir: Optional[str] = None, - ): - """ - Initialize and preprocess the Vehicle dataset. - - Loads metadata and timeseries CSVs, then applies filtering, - grouping, user-subsetting, and calls the base class for - further preprocessing (normalization, merging, rarity flags). - - Args: - cfg (Optional[DictConfig]): Override Hydra config; if None, - load from `config/dataset/vehicle.yaml`. - overrides (Optional[List[str]]): Override Hydra config; if None, - load from `config/dataset/vehicle.yaml` and apply overrides. - - Raises: - FileNotFoundError: If required CSV files are missing. 
- """ - if cfg is None: - cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "vehicle.yaml")) - if overrides: - cfg = apply_overrides(cfg, overrides) - - self.cfg = cfg - self.name = cfg.name - self.normalize = cfg.normalize - self.time_series_dims = cfg.time_series_dims - self.num_ts_steps = cfg.num_ts_steps - self.seq_len = self.cfg.seq_len - - self._load_data() - - ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] - - - super().__init__( - data=self.data, - time_series_column_names=ts_cols, - context_var_column_names=list(self.cfg.context_vars.keys()), - seq_len=self.cfg.seq_len, - normalize=self.cfg.normalize, - scale=self.cfg.scale, - skip_heavy_processing=cfg.get('skip_heavy_processing', False), - size=cfg.get('max_samples', None), - force_retrain_normalizer=force_retrain_normalizer, - run_dir=run_dir, - ) - - def _load_data(self) -> None: - """ - Load . - - Populates self.data DataFrame. - - Raises: - FileNotFoundError: If any required CSV file is missing. - """ - module_dir = os.path.dirname(os.path.abspath(__file__)) - path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) - self.data = pd.read_csv(os.path.join(path, "vehicle_signal_data.csv")) - - - def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: - ''' - Convert timestamps, assemble sequences of length seq_len, and merge metadata. - - Args: - data (pd.DataFrame): Raw concatenated grid (and solar) rows. - - Returns: - pd.DataFrame: One row per sequence, with array-valued 'grid' and - ''' - - # Assemble sequences of length seq_len, each with a prefix context of length self.num_ts_steps - # Time remains raw seconds, do not convert to timestamps. 
- time_series_cols = self.cfg.time_series_columns[: self.time_series_dims] - context_var_names = list(self.cfg.context_vars.keys()) - data = data.sort_values("Time").reset_index(drop=True) # ensure increasing raw seconds - - # Only build full (context+target) window sequences that fit fully within data - total_window = self.num_ts_steps + self.seq_len - rolling_idxs = ( - pd.Series(np.arange(len(data))) - .rolling(window=total_window) - .apply(lambda x: x[0], raw=True) - .dropna() - .index - ) - - # Preallocate arrays for sequences, for efficiency - out = {col: [] for col in time_series_cols} - for cvar in context_var_names: - out[f"{cvar}"] = [] - out["context_time"] = [] - - for idx in rolling_idxs: - window_slice = data.iloc[idx - total_window + 1 : idx + 1] - context_slice = window_slice.iloc[:self.num_ts_steps] - target_slice = window_slice.iloc[self.num_ts_steps:] - - # Store target sequences - for col in time_series_cols: - out[col].append(target_slice[col].to_numpy()) - - # Store context as array(s) - for cvar in time_series_cols: - # Context for each variable over the context window - if cvar in context_slice.columns: - out[f"context_{cvar}"].append(context_slice[cvar].to_numpy()) - else: - out[f"context_{cvar}"].append([None] *self.num_ts_steps) - - # Optionally, keep the raw "Time" for context window (useful to recover absolute position/relative time) - out["context_time"].append(context_slice["Time"].to_numpy()) - - out_df = pd.DataFrame(out) - return out_df - \ No newline at end of file From bd5d13db9e0f27981f99557ab5ab414cf6706acf Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 9 Feb 2026 22:49:58 -0500 Subject: [PATCH 31/50] fixes to eval --- cents/config/context/default.yaml | 2 +- cents/config/dataset/default.yaml | 28 ++--- cents/data_generator.py | 2 + cents/datasets/timeseries_dataset.py | 12 ++ cents/eval/eval.py | 3 + cents/models/diffusion_ts.py | 58 ++++----- scripts/eval_pretrained.py | 181 +++++++++++++++++++++------ 
scripts/train.py | 35 ++++-- 8 files changed, 235 insertions(+), 86 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index 6327799..d2585ac 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -3,7 +3,7 @@ # Static context: used by generative models (ACGAN, Diffusion_TS) for conditioning static_context: - type: mlp # Context module type (e.g., "mlp", "sep_mlp") + type: sep_mlp # Context module type (e.g., "mlp", "sep_mlp") # Future parameters can be added here: # n_layers: 2 # hidden_dim: 256 diff --git a/cents/config/dataset/default.yaml b/cents/config/dataset/default.yaml index 3317216..14f65d8 100644 --- a/cents/config/dataset/default.yaml +++ b/cents/config/dataset/default.yaml @@ -1,15 +1,15 @@ -name: default -normalize: True -scale: True -use_learned_normalizer: True -shuffle: True -threshold: 6 -time_series_dims: 1 -time_series_columns: [] -seq_len: 8 -user_group: null +# name: default +# normalize: True +# scale: True +# use_learned_normalizer: True +# shuffle: True +# threshold: 6 +# time_series_dims: 1 +# time_series_columns: [] +# seq_len: 8 +# user_group: null -numeric_context_bins: 5 -context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous) -continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned) -stats_head_type: mlp +# numeric_context_bins: 5 +# context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous) +# continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned) +# stats_head_type: mlp diff --git a/cents/data_generator.py b/cents/data_generator.py index 58836c2..10f3c46 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -220,6 +220,8 @@ def load_from_checkpoint( ckpt_path, state = self._resolve_ckpt(model_ckpt) ModelCls = 
get_model_cls(self.model_type) + print(self.cfg) + if ckpt_path.suffix == ".ckpt": print(f"[Cents] Loading model from checkpoint: {ckpt_path}") self.model = ( diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 63cbfba..34e670d 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -332,6 +332,18 @@ def inverse_transform( df = self._normalizer.inverse_transform(df) return self.merge_timeseries_columns(df) if merged else df + def apply_pretrained_normalizer(self) -> None: + """ + Transform self.data with the attached pretrained normalizer (e.g. after + dataset._normalizer = data_generator.normalizer). Use when normalize=False + was set to avoid training but you still want real data in normalized space. + """ + if getattr(self, "_normalizer", None) is None: + return + df_split = self.split_timeseries(self.data.copy()) + self.data = self.merge_timeseries_columns(self._normalizer.transform(df_split)) + self.data = self.data.reset_index(drop=True) + def _encode_context_vars( self, data: pd.DataFrame ) -> Tuple[pd.DataFrame, Dict[str, Any]]: diff --git a/cents/eval/eval.py b/cents/eval/eval.py index 3b9334b..9e6d6c0 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -110,6 +110,9 @@ def evaluate_model( if data_generator.normalizer is not None: dataset._normalizer = data_generator.normalizer print("[Cents] Using pre-trained normalizer from DataGenerator") + if not getattr(dataset.cfg, "normalize", True): + dataset.apply_pretrained_normalizer() + print("[Cents] Normalized dataset with pretrained normalizer") else: if not model: model = self.get_trained_model(dataset) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 7cd0374..5c72bf9 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -562,39 +562,39 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: if hasattr(self, '_ema') and self._ema: 
self._ema.update() - def on_load_checkpoint(self, checkpoint: dict) -> None: - """ - Restore EMA weights from checkpoint after loading. - """ - super().on_load_checkpoint(checkpoint) + # def on_load_checkpoint(self, checkpoint: dict) -> None: + # """ + # Restore EMA weights from checkpoint after loading. + # """ + # super().on_load_checkpoint(checkpoint) - # Check if EMA weights exist in checkpoint - state_dict = checkpoint.get('state_dict', {}) - ema_keys = [key for key in state_dict.keys() if key.startswith('_ema.')] + # # Check if EMA weights exist in checkpoint + # state_dict = checkpoint.get('state_dict', {}) + # ema_keys = [key for key in state_dict.keys() if key.startswith('_ema.')] - if ema_keys: - if not hasattr(self, '_ema') or self._ema is None: - self._ema = EMA( - self.model, - beta=self.cfg.model.ema_decay, - update_every=self.cfg.model.ema_update_interval, - ) + # if ema_keys: + # if not hasattr(self, '_ema') or self._ema is None: + # self._ema = EMA( + # self.model, + # beta=self.cfg.model.ema_decay, + # update_every=self.cfg.model.ema_update_interval, + # ) - # Load EMA weights into the EMA helper - ema_state_dict = {} - for key, value in state_dict.items(): - if key.startswith('_ema.ema_model.'): - # Map '_ema.ema_model.*' -> 'ema_model.*' (remove the _ema prefix) - ema_key = key.replace('_ema.ema_model.', 'ema_model.') - ema_state_dict[ema_key] = value + # # Load EMA weights into the EMA helper + # ema_state_dict = {} + # for key, value in state_dict.items(): + # if key.startswith('_ema.ema_model.'): + # # Map '_ema.ema_model.*' -> 'ema_model.*' (remove the _ema prefix) + # ema_key = key.replace('_ema.ema_model.', 'ema_model.') + # ema_state_dict[ema_key] = value - if ema_state_dict: - print(f"Loading {len(ema_state_dict)} EMA weights from checkpoint") - self._ema.ema_model.load_state_dict(ema_state_dict, strict=False) - else: - raise ValueError("No EMA model weights found in checkpoint") - else: - raise ValueError("No EMA keys found in 
checkpoint") + # if ema_state_dict: + # print(f"Loading {len(ema_state_dict)} EMA weights from checkpoint") + # self._ema.ema_model.load_state_dict(ema_state_dict, strict=False) + # else: + # raise ValueError("No EMA model weights found in checkpoint") + # else: + # raise ValueError("No EMA keys found in checkpoint") @torch.no_grad() def model_predictions( diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 34b6823..b126361 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,6 +1,8 @@ import logging import os from pathlib import Path +from typing import Tuple + import torch import torch.nn.functional as F @@ -8,11 +10,12 @@ import argparse from cents.data_generator import DataGenerator +from cents.models.registry import get_model_type_from_hf_name from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset from cents.eval.eval import Evaluator -from cents.utils.config_loader import load_yaml +from cents.utils.config_loader import load_yaml, apply_overrides from cents.utils.utils import set_context_config_path logging.basicConfig( @@ -22,15 +25,66 @@ DATASET_OVERRIDES = ["max_samples=10000", "skip_heavy_processing=True"] PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] +CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" + + +def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: + """Load dataset-specific config from config/dataset/{dataset_name}.yaml and apply overrides.""" + config_path = CONFIG_DATASET_DIR / f"{dataset_name}.yaml" + if not config_path.exists(): + raise ValueError( + f"Dataset config not found for '{dataset_name}' at {config_path}. 
" + f"Available: {[p.name for p in CONFIG_DATASET_DIR.glob('*.yaml')]}" + ) + cfg = load_yaml(str(config_path)) + if overrides: + cfg = apply_overrides(cfg, overrides) + return cfg + + +def _infer_dataset_shape_from_ckpt( + ckpt_path: str, cond_emb_dim: int +) -> Tuple[int, int]: + """ + Infer seq_len and time_series_dims from a Diffusion_TS checkpoint state_dict + so the model can be built with the same architecture as when the checkpoint was saved. + """ + ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) + state_dict = ckpt.get("state_dict", ckpt) + # Keys may be "model.pos_enc.pe" (Lightning) or "pos_enc.pe" (raw) + for pe_key in ("model.pos_enc.pe", "pos_enc.pe"): + if pe_key in state_dict: + # shape (1, seq_len, d_model) + seq_len = int(state_dict[pe_key].shape[1]) + break + else: + raise ValueError( + "Could not infer seq_len from checkpoint (no pos_enc.pe key in state_dict)" + ) + # combine_s: Conv1d(n_embd, n_feat, ...) -> weight shape (n_feat, n_embd, k) + # n_feat = time_series_dims + cond_emb_dim + for cs_key in ("model.combine_s.weight", "combine_s.weight"): + if cs_key in state_dict: + n_feat = int(state_dict[cs_key].shape[0]) + time_series_dims = n_feat - cond_emb_dim + if time_series_dims < 1: + time_series_dims = 1 + break + else: + raise ValueError( + "Could not infer time_series_dims from checkpoint (no combine_s.weight in state_dict)" + ) + return seq_len, time_series_dims + -def _load_dataset(name: str, overrides: list): - """Load a dataset by name with optional overrides.""" +def _load_dataset(name: str, dataset_cfg: OmegaConf): + """Load a dataset by name using dataset-specific config (from config/dataset/{name}.yaml).""" if name == "pecanstreet": - return PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES + (overrides or [])) + return PecanStreetDataset(cfg=dataset_cfg) if name == "commercial": - return CommercialDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + return 
CommercialDataset(cfg=dataset_cfg) if name == "airquality": - return AirQualityDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + return AirQualityDataset(cfg=dataset_cfg) raise ValueError(f"Dataset {name} not supported. Use: pecanstreet, commercial, airquality.") @@ -42,20 +96,26 @@ def main() -> None: parser.add_argument( "--model-ckpt", type=str, - required=True, - help="Path to model checkpoint (.ckpt or .pt).", + default=None, + help="Path to model checkpoint (.ckpt or .pt). Required unless --model-key is set.", + ) + parser.add_argument( + "--model-key", + type=str, + default=None, + help="HuggingFace model key (e.g. Watts_2_1D). If set, model and normalizer are loaded from HF instead of --model-ckpt.", ) parser.add_argument( "--normalizer-ckpt", type=str, default=None, - help="Path to normalizer checkpoint. If omitted, evaluation uses normalized space.", + help="Path to normalizer checkpoint. If omitted, evaluation uses normalized space (or HF normalizer when using --model-key).", ) parser.add_argument( "--model-type", type=str, - default="diffusion_ts", - help="Model type (e.g. diffusion_ts) used to load the checkpoint.", + default=None, + help="Model type (e.g. diffusion_ts) used to load the checkpoint. Inferred from --model-key when loading from HF.", ) parser.add_argument( "--dataset", @@ -75,7 +135,7 @@ def main() -> None: "--save-dir", type=str, default=None, - help="Directory to save evaluation results. If None, uses default location based on model checkpoint path.", + help="Directory to save evaluation results. 
If None, uses checkpoint parent + /eval or outputs/eval/ when using --model-key.", ) parser.add_argument( "--job-name", @@ -133,67 +193,118 @@ def main() -> None: ) args = parser.parse_args() + if not args.model_ckpt and not args.model_key: + parser.error("One of --model-ckpt or --model-key is required.") + if args.model_ckpt and args.model_key: + parser.error("Use only one of --model-ckpt or --model-key.") + # Set custom context config path if provided if args.context_config_path: set_context_config_path(args.context_config_path) logging.info("Loading dataset %s...", args.dataset) - overrides = list(args.dataset_overrides) if args.dataset_overrides else [] - if args.dataset == "pecanstreet" and "time_series_dims" not in str(overrides): - overrides = overrides + ["time_series_dims=1", "user_group=all"] - - dataset = _load_dataset(args.dataset, overrides) + overrides = list(DATASET_OVERRIDES) + if args.dataset == "pecanstreet": + overrides = overrides + PECAN_OVERRIDES + # Use pretrained normalizer from checkpoint/HF: skip dataset normalizer init so it doesn't train + if args.model_key: + overrides = overrides + ["normalize=False"] + # Watts (and most pretrained) normalizers use scale=True (do_scale); match so stats_head shape loads + if args.model_key: + overrides = overrides + ["scale=True"] + if args.dataset_overrides: + overrides = overrides + list(args.dataset_overrides) + dataset_cfg = _load_dataset_config(args.dataset, overrides) + dataset = _load_dataset(args.dataset, dataset_cfg) + + # Resolve model type (from key when loading from HF, else from args) + if args.model_key: + model_type = get_model_type_from_hf_name(args.model_key) + else: + model_type = args.model_type or "diffusion_ts" # Load configs eval_cfg = load_yaml(args.evaluator_config) top_cfg = load_yaml(args.config) - + cfg = OmegaConf.create({}) cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) - cfg.device = top_cfg.get("device", "auto") + cfg.device = "cuda:0" cfg.model = 
OmegaConf.create( - OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{args.model_type}.yaml"), resolve=True) + OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) ) cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) - + + print("EVAL CONFIG:") + print(cfg) + + # When loading from a local checkpoint, infer seq_len and time_series_dims from the + # checkpoint so the model is built with the same architecture (avoids shape mismatch). + if args.model_ckpt and Path(args.model_ckpt).suffix == ".ckpt": + try: + ckpt_seq_len, ckpt_time_series_dims = _infer_dataset_shape_from_ckpt( + args.model_ckpt, cond_emb_dim=int(cfg.model.cond_emb_dim) + ) + if ckpt_seq_len != cfg.dataset.seq_len or ckpt_time_series_dims != cfg.dataset.time_series_dims: + logging.info( + "Checkpoint has seq_len=%s, time_series_dims=%s; overriding dataset config to match.", + ckpt_seq_len, ckpt_time_series_dims, + ) + cfg.dataset.seq_len = ckpt_seq_len + cfg.dataset.time_series_dims = ckpt_time_series_dims + dataset.cfg.seq_len = ckpt_seq_len + dataset.cfg.time_series_dims = ckpt_time_series_dims + dataset.seq_len = ckpt_seq_len + dataset.time_series_dims = ckpt_time_series_dims + except (KeyError, ValueError) as e: + logging.warning( + "Could not infer dataset shape from checkpoint (%s). 
Using eval dataset config; shape mismatch may occur.", + e, + ) + # Set EMA sampling cfg.model.use_ema_sampling = not args.no_ema - + # Set evaluation flags (use config defaults if not overridden) cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) cfg.eval_metrics = False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) cfg.save_results = False if args.no_save_results else True - + # Set job name cfg.job_name = args.job_name if args.job_name else eval_cfg.get("job_name", "default_job") - + # Set save directory if args.save_dir: cfg.save_dir = Path(args.save_dir) + elif args.model_key: + cfg.save_dir = Path("outputs/eval") / args.model_key else: - # Default: use model checkpoint directory + /eval model_ckpt_path = Path(args.model_ckpt) cfg.save_dir = model_ckpt_path.parent / "eval" - + if not os.path.exists(cfg.save_dir): os.makedirs(cfg.save_dir, exist_ok=True) logging.info("Created evaluation directory: %s", cfg.save_dir) - logging.info("Setting up DataGenerator (model_type=%s)...", args.model_type) - gen = DataGenerator(model_type=args.model_type, dataset=dataset) - - logging.info("Loading checkpoint... EMA sampling %s", "enabled" if cfg.model.use_ema_sampling else "disabled") - gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) - + use_hf = args.model_key is not None + if use_hf: + logging.info("Setting up DataGenerator from HuggingFace (model_key=%s)...", args.model_key) + gen = DataGenerator(model_name=args.model_key, dataset=dataset, cfg=cfg) + else: + logging.info("Setting up DataGenerator (model_type=%s)...", model_type) + gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) + logging.info("Loading checkpoint... 
EMA sampling %s", "enabled" if cfg.model.use_ema_sampling else "disabled") + gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) + # Ensure EMA setting is applied to the config used by the model at generate time target = getattr(gen.model, "cfg", None) or gen.cfg if target is not None and hasattr(target, "model"): target.model.use_ema_sampling = cfg.model.use_ema_sampling - - gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) + + # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset with torch.no_grad(): @@ -219,4 +330,4 @@ def main() -> None: if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/scripts/train.py b/scripts/train.py index b5c0471..034e718 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -8,6 +8,7 @@ from cents.datasets.vehicle import VehicleDataset from cents.trainer import Trainer from cents.utils.utils import set_context_config_path, set_context_overrides, get_context_config +from cents.utils.config_loader import load_yaml, apply_overrides from omegaconf import OmegaConf import warnings import argparse @@ -16,6 +17,21 @@ PROJECT_ROOT = Path(__file__).resolve().parent.parent RUNS_DIR = PROJECT_ROOT / "runs" +CONFIG_DATASET_DIR = PROJECT_ROOT / "cents" / "config" / "dataset" + + +def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: + """Load dataset-specific config from config/dataset/{dataset_name}.yaml and apply overrides.""" + config_path = CONFIG_DATASET_DIR / f"{dataset_name}.yaml" + if not config_path.exists(): + raise ValueError( + f"Dataset config not found for '{dataset_name}' at {config_path}. 
" + f"Available: {[p.name for p in CONFIG_DATASET_DIR.glob('*.yaml')]}" + ) + cfg = load_yaml(str(config_path)) + if overrides: + cfg = apply_overrides(cfg, overrides) + return cfg def _write_run_summary(run_dir: Path, run_name: str, trainer: Trainer) -> None: @@ -56,27 +72,33 @@ def main(args) -> None: if args.context_overrides: set_context_overrides(args.context_overrides) + # Build dataset-specific overrides (key=value list; config is loaded from config/dataset/{dataset}.yaml) + dataset_overrides = [f"skip_heavy_processing={args.skip_heavy_processing}"] + if args.dataset == "pecanstreet": + dataset_overrides.extend(["time_series_dims=1", "user_group=all"]) + dataset_cfg = _load_dataset_config(args.dataset, dataset_overrides) + if args.dataset == "pecanstreet": dataset = PecanStreetDataset( - overrides=[f"skip_heavy_processing={args.skip_heavy_processing}, time_series_dims=1, user_group=all"], + cfg=dataset_cfg, force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) elif args.dataset == "commercial": dataset = CommercialDataset( - overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], + cfg=dataset_cfg, force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) elif args.dataset == "airquality": dataset = AirQualityDataset( - overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], + cfg=dataset_cfg, force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) elif args.dataset == "vehicle": dataset = VehicleDataset( - overrides=[f"skip_heavy_processing={args.skip_heavy_processing}"], + cfg=dataset_cfg, force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) @@ -133,8 +155,7 @@ def main(args) -> None: help="Evaluate after training") parser.add_argument("--skip_heavy_processing", action="store_true", help="Skip heavy processing of dataset") - parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_false", - help="DDP strategy; 
use 'ddp' (no find_unused_parameters) for best perf") + parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_true") parser.add_argument("--enable_checkpointing", action="store_true", help="Enable checkpointing") parser.add_argument("--context-config-path", type=str, default=None, @@ -151,4 +172,4 @@ def main(args) -> None: ) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file From 0128858ae6e5bac7fccb29af379034235bcbf2fc Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 10 Feb 2026 10:13:54 -0500 Subject: [PATCH 32/50] Removed comments, config changes --- cents/config/context/default.yaml | 2 +- cents/config/trainer/normalizer.yaml | 2 -- cents/datasets/timeseries_dataset.py | 2 +- cents/models/context.py | 22 +--------------------- cents/models/normalizer.py | 7 +------ 5 files changed, 4 insertions(+), 31 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index d2585ac..6327799 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -3,7 +3,7 @@ # Static context: used by generative models (ACGAN, Diffusion_TS) for conditioning static_context: - type: sep_mlp # Context module type (e.g., "mlp", "sep_mlp") + type: mlp # Context module type (e.g., "mlp", "sep_mlp") # Future parameters can be added here: # n_layers: 2 # hidden_dim: 256 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index f5d8daf..ea62a3a 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,14 +1,12 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu devices: 2,3 -precision: 16-mixed log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 n_epochs: 2000 batch_size: 4096 lr: 3e-4 -gradient_clip_val: 1.0 save_cycle: 5000 eval_after_training: False loss_type: mse # Options: "mse" or "gaussian_nll" diff --git a/cents/datasets/timeseries_dataset.py 
b/cents/datasets/timeseries_dataset.py index 34e670d..942bd37 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -246,7 +246,7 @@ def get_train_dataloader( self._normalize_continuous_vars() return DataLoader( - self, batch_size=batch_size, shuffle=shuffle, num_workers=8, persistent_workers=persistent_workers + self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, persistent_workers=persistent_workers ) def split_timeseries(self, df: pd.DataFrame) -> pd.DataFrame: diff --git a/cents/models/context.py b/cents/models/context.py index f11cded..010528e 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -55,11 +55,7 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int): self.classification_heads = nn.ModuleDict( { - var_name: nn.Sequential( - nn.Linear(embedding_dim, embedding_dim), - nn.ReLU(), - nn.Linear(embedding_dim, num_categories[1]) - ) + var_name: nn.Linear(embedding_dim, num_categories[1]) for var_name, num_categories in context_vars.items() } ) @@ -78,34 +74,18 @@ def forward( classification_logits (Dict[str, Tensor]): Logits per variable, each of shape (batch_size, num_categories). 
""" -# # At start of forward, before any embedding(context_vars[name]) -# for name in context_vars: -# t = context_vars[name] -# if t.dtype in (torch.long, torch.int): -# t_cpu = t.detach().cpu() -# print(f"{name}: shape={t_cpu.shape}, min={t_cpu.min().item()}, max={t_cpu.max().item()}") -# print(context_vars.keys(), self.context_embeddings.keys(), "context_vars and context_embeddings") embeddings = [ layer(context_vars[name]) for name, layer in self.context_embeddings.items() ] - # print("max", embeddings[0].max(), "min", embeddings[0].min(), "mean", embeddings[0].mean(), "std", embeddings[0].std(), "nan", embeddings[0].isnan().sum(), "inf", embeddings[0].isinf().sum()) - # print(embeddings, "embeddings") - context_matrix = torch.cat(embeddings, dim=1) embedding = self.mlp(context_matrix) - # print("max", embedding.max(), "min", embedding.min(), "mean", embedding.mean(), "std", embedding.std(), "nan", embedding.isnan().sum(), "inf", embedding.isinf().sum()) - - # print(embedding, "embedding") - classification_logits = { var_name: head(embedding) for var_name, head in self.classification_heads.items() } - # print(classification_logits, "classification_logits") - return embedding, classification_logits @register_context_module("default", "sep_mlp") diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 3f61e82..4bf6e7b 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -89,12 +89,6 @@ def _initialize_output_layer(self, init_sigma: float = 1.0): end = 2 * D out_layer.bias[start:end].fill_(log_sigma_bias) - # If you want, you can also bias z_min/z_max to plausible values here when do_scale=True - # Example (optional): - # if self.do_scale: - # out_layer.bias[2*D:3*D].fill_(-2.0) # z_min - # out_layer.bias[3*D:4*D].fill_( 2.0) # z_max - # Sanity: ensure output dimension matches expectation expected_out = K * D if out_layer.out_features != expected_out: @@ -196,6 +190,7 @@ def __init__( # Use registry to get the stats head 
class StatsHeadCls = get_stats_head_cls(stats_head_type) + print(do_scale, "do_scale") self.stats_head = StatsHeadCls( embedding_dim=self.embedding_dim, hidden_dim=hidden_dim, From adb16f6d95fee4f4475416ffc8cf65ed5a3e6b8e Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Thu, 12 Feb 2026 19:48:35 -0500 Subject: [PATCH 33/50] config changes for training, eval, remove nan checks --- cents/config/model/diffusion_ts.yaml | 8 ++-- cents/config/trainer/diffusion_ts.yaml | 6 +-- cents/models/context.py | 52 ++++++++------------------ cents/models/diffusion_ts.py | 16 +------- cents/models/model_utils.py | 15 ++++++-- cents/models/normalizer.py | 26 +++++-------- scripts/eval_pretrained.py | 8 ++-- scripts/train.py | 7 ---- 8 files changed, 49 insertions(+), 89 deletions(-) diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index f26d309..3f0b69a 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -11,10 +11,10 @@ n_steps: 1000 sampling_timesteps: 1000 sampling_batch_size: 4096 loss_type: l1 #l2 -training_objective: eps -loss_weighting: snr +training_objective: x0 +loss_weighting: uniform min_snr_gamma: 5.0 -beta_schedule: cosine #linear diffusion ts paper uses linear schedule +beta_schedule: linear #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 eta: 0.0 @@ -27,5 +27,5 @@ reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 -use_ema_sampling: True +use_ema_sampling: False k_bins: 20 \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index c881f09..cfc1ec2 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -1,11 +1,11 @@ precision: "16-mixed" accelerator: auto devices: auto -strategy: ddp_find_unused_parameters_false +strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 log_every_n_steps: 1 
batch_size: 512 -max_epochs: 5000 +max_epochs: 2500 base_lr: 1e-4 eval_after_training: False @@ -17,7 +17,7 @@ checkpoint: lr_scheduler_params: factor: 0.5 - patience: 50 + patience: 200 min_lr: 1.0e-6 threshold: 1.0e-1 threshold_mode: rel diff --git a/cents/models/context.py b/cents/models/context.py index 010528e..b6fee2b 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -187,22 +187,22 @@ def forward(self, context_vars): # Process continuous variables (only those present in context_vars) for name, layer in self.continuous_projections.items(): if name in context_vars: - # Reshape to (batch_size, 1) for linear layer - # Ensure proper shape and gradient flow + # # Reshape to (batch_size, 1) for linear layer + # # Ensure proper shape and gradient flow continuous_val = context_vars[name] - # Handle different input shapes - if continuous_val.dim() == 0: - # Scalar: add batch dimension - continuous_val = continuous_val.unsqueeze(0) - elif continuous_val.dim() == 1: - # 1D tensor: add feature dimension - continuous_val = continuous_val.unsqueeze(-1) - # Ensure float type while preserving gradients - if not continuous_val.is_floating_point(): - continuous_val = continuous_val.float() - - if continuous_val.dim() == 1: - continuous_val = continuous_val.unsqueeze(-1) + # # Handle different input shapes + # if continuous_val.dim() == 0: + # # Scalar: add batch dimension + # continuous_val = continuous_val.unsqueeze(0) + # elif continuous_val.dim() == 1: + # # 1D tensor: add feature dimension + # continuous_val = continuous_val.unsqueeze(-1) + # # Ensure float type while preserving gradients + # if not continuous_val.is_floating_point(): + # continuous_val = continuous_val.float() + + # if continuous_val.dim() == 1: + # continuous_val = continuous_val.unsqueeze(-1) encodings[name] = layer(continuous_val) embeddings = [] @@ -223,28 +223,8 @@ def forward(self, context_vars): ) embeddings.append(embedding_output) - if not embeddings: - raise ValueError("No 
context variables found in context_vars dict") - context_matrix = torch.cat(embeddings, dim=1) - - # Check for NaN before mixing MLP - if torch.isnan(context_matrix).any(): - raise ValueError( - f"NaN detected in context_matrix before mixing MLP. " - f"This suggests one of the context variable embeddings contains NaN." - ) - embedding = self.mixing_mlp(context_matrix) - - # Check for NaN after mixing MLP - if torch.isnan(embedding).any(): - raise ValueError( - f"NaN detected in final embedding after mixing MLP. " - f"Context matrix stats: mean={context_matrix.mean():.4f}, " - f"std={context_matrix.std():.4f}, " - f"min={context_matrix.min():.4f}, max={context_matrix.max():.4f}" - ) classification_logits = { var_name: head(embedding) @@ -259,8 +239,6 @@ def forward(self, context_vars): if var_name in context_vars } - # Combine both into a single dict for backward compatibility - # The training step will need to distinguish between them all_outputs = {**classification_logits, **regression_outputs} return embedding, all_outputs diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 5c72bf9..f74e754 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -401,20 +401,6 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # t = torch.randint(0, self.num_timesteps, (b,), device=self.device) t = self.stratified_timesteps(b, self.num_timesteps, self.cfg.model.k_bins, device=self.device) embedding, cond_classification_logits = self._get_context_embedding(context_vars) - # Check embedding for NaN/Inf - # if embedding.isnan().any() or embedding.isinf().any(): - # raise ValueError( - # f"NaN/Inf detected in embedding from context module. 
" - # f"NaN count: {embedding.isnan().sum()}, Inf count: {embedding.isinf().sum()}, " - # f"shape: {embedding.shape}, min: {embedding.min()}, max: {embedding.max()}" - # ) - - # Embedding should now be normalized by the context module (mean=0, std=1 per sample) - # Check that values are in reasonable range - # if embedding.abs().max() > 100.0: - # print(f"[Warning] Embedding has large values despite normalization: " - # f"min={embedding.min():.4f}, max={embedding.max():.4f}, " - # f"mean={embedding.mean():.4f}, std={embedding.std():.4f}") # Check diffusion schedule parameters noise = torch.randn_like(x) @@ -428,7 +414,7 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) if self.training_objective == "x0": loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") - elif self.training_objective == "epsilon": + elif self.training_objective == "eps": pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) loss_per_elem = self.recon_loss_fn(pred_noise, noise, reduction="none") else: # v diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index c6b3920..7a07bfb 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -775,6 +775,12 @@ def __init__( self.time_emb = SinusoidalPosEmb(n_embd) + self.cond_mix_mlp = nn.Sequential( + nn.Linear(n_embd * 2, n_embd), + nn.ReLU(), + nn.Linear(n_embd, n_embd), + ) + if conv_params is None or conv_params[0] is None: if n_feat < 32 and n_channel < 64: kernel_size, padding = 1, 0 @@ -839,10 +845,13 @@ def forward(self, input, t, padding_masks=None, return_res=False, cond=None): if (cond is not None) and (self.cond_proj is not None): label_emb = self.cond_proj(cond) # (B, n_embd) # Add them up here to pass a single vector down - total_cond_emb = t_emb + label_emb + total_cond_emb = torch.concat([t_emb, label_emb], dim=1) else: - 
total_cond_emb = t_emb - + total_cond_emb = torch.concat([t_emb, torch.zeros_like(t_emb)], dim=1) + + ## Use MLP to combine t_emb and label_emb + total_cond_emb = self.cond_mix_mlp(total_cond_emb) + emb = self.emb(input) inp_enc = self.pos_enc(emb) diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 4bf6e7b..3ee535e 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -155,6 +155,7 @@ def __init__( time_series_dims: int = 2, do_scale: bool = True, stats_head_type: str = "mlp", + dynamic_var_names: list[str] = None, ): """ Args: @@ -168,7 +169,7 @@ def __init__( super().__init__() self.static_cond_module = static_cond_module self.dynamic_cond_module = dynamic_cond_module - + self.dynamic_var_names = dynamic_var_names # Determine embedding dimension from available modules if static_cond_module is not None: self.embedding_dim = static_cond_module.embedding_dim @@ -190,7 +191,6 @@ def __init__( # Use registry to get the stats head class StatsHeadCls = get_stats_head_cls(stats_head_type) - print(do_scale, "do_scale") self.stats_head = StatsHeadCls( embedding_dim=self.embedding_dim, hidden_dim=hidden_dim, @@ -293,10 +293,9 @@ def __init__( # Get continuous variables from config if specified self.continuous_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "continuous"] self.categorical_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "categorical"] - dynamic_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "time_series"] + self.dynamic_context_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "time_series"] self.static_context_vars = self.categorical_vars + self.continuous_vars - self.dynamic_context_vars = dynamic_vars self.context_vars = self.static_context_vars + self.dynamic_context_vars self.time_series_cols = dataset_cfg.time_series_columns[ @@ -318,7 +317,7 @@ def __init__( self.loss_type = getattr(self.normalizer_training_cfg, 
"loss_type", "mse") # Create static context module (for categorical + continuous) - static_context_module = None + self.static_context_module = None if self.static_context_vars: StaticContextModuleCls = get_context_module_cls(self.static_module_type) # Filter context_vars to only static ones @@ -326,13 +325,13 @@ def __init__( k: v for k, v in self.dataset.context_var_dict.items() if k in self.static_context_vars } - static_context_module = StaticContextModuleCls( + self.static_context_module = StaticContextModuleCls( self.static_context_vars_dict, 256, ) # Create dynamic context module (for time_series) - dynamic_context_module = None + self.dynamic_context_module = None if self.dynamic_context_vars and self.dynamic_module_type is not None: DynamicContextModuleCls = get_context_module_cls("dynamic", self.dynamic_module_type) # Filter context_vars to only dynamic ones @@ -342,26 +341,21 @@ def __init__( } # Use num_ts_steps for dynamic context length if available, otherwise seq_len dynamic_seq_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len - dynamic_context_module = DynamicContextModuleCls( + self.dynamic_context_module = DynamicContextModuleCls( dynamic_context_vars_dict, 256, seq_len=dynamic_seq_len, ) self.normalizer_model = _NormalizerModule( - static_cond_module=static_context_module, - dynamic_cond_module=dynamic_context_module, + static_cond_module=self.static_context_module, + dynamic_cond_module=self.dynamic_context_module, hidden_dim=512, time_series_dims=self.time_series_dims, do_scale=self.do_scale, stats_head_type=self.stats_head_type, + dynamic_var_names=self.dynamic_context_vars, ) - # Store dynamic var names for filtering in forward - self.normalizer_model._dynamic_var_names = self.dynamic_context_vars - # For backward compatibility, expose the static context module - self.context_module = self.normalizer_model.static_cond_module - # Expose the dynamic context module at top level so it shows in model summary - 
self.dynamic_cond_module = self.normalizer_model.dynamic_cond_module # Will be populated in setup() self.sample_stats = [] diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index b126361..c02c720 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -22,7 +22,7 @@ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) -DATASET_OVERRIDES = ["max_samples=10000", "skip_heavy_processing=True"] +DATASET_OVERRIDES = ["max_samples=10000"] PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" @@ -156,9 +156,9 @@ def main() -> None: help="Path to main config YAML file.", ) parser.add_argument( - "--no-ema", + "--ema", action="store_true", - help="Disable EMA sampling (EMA is used by default when present in checkpoint).", + help="Enable EMA sampling.", ) parser.add_argument( "--eval-pv-shift", @@ -264,7 +264,7 @@ def main() -> None: ) # Set EMA sampling - cfg.model.use_ema_sampling = not args.no_ema + cfg.model.use_ema_sampling = args.ema # Set evaluation flags (use config defaults if not overridden) cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) diff --git a/scripts/train.py b/scripts/train.py index 034e718..2f82cc4 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,7 +5,6 @@ from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset -from cents.datasets.vehicle import VehicleDataset from cents.trainer import Trainer from cents.utils.utils import set_context_config_path, set_context_overrides, get_context_config from cents.utils.config_loader import load_yaml, apply_overrides @@ -96,12 +95,6 @@ def main(args) -> None: force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) - elif args.dataset == "vehicle": - dataset 
= VehicleDataset( - cfg=dataset_cfg, - force_retrain_normalizer=args.force_retrain_normalizer, - run_dir=str(run_dir), - ) else: raise ValueError(f"Dataset {args.dataset} not supported") From 32dc636e566713a3ddbd9328bdc633ea81de2522 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Thu, 12 Feb 2026 20:38:12 -0500 Subject: [PATCH 34/50] Removed print statement --- scripts/eval_pretrained.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index c02c720..06dba15 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -306,22 +306,6 @@ def main() -> None: # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset - - with torch.no_grad(): - betas = gen.model.betas - alphas = 1 - betas - abar = torch.cumprod(alphas, dim=0) - abar_prev = F.pad(abar[:-1], (1,0), value=1.0) - - pv_expected = betas * (1 - abar_prev) / (1 - abar) - print((pv_expected - gen.model.posterior_variance).abs().max()) - - pmc1_expected = betas * abar_prev.sqrt() / (1 - abar) - pmc2_expected = (1 - abar_prev) * alphas.sqrt() / (1 - abar) - - print((pmc1_expected - gen.model.posterior_mean_coef1).abs().max()) - print((pmc2_expected - gen.model.posterior_mean_coef2).abs().max()) - logging.info("Checkpoint loaded. 
Starting evaluation...") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) From c5fac35696a2193c94c29f163d620b1ecf68904e Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Fri, 13 Feb 2026 10:47:58 -0500 Subject: [PATCH 35/50] Global smoothing for normalizer --- cents/config/context/default.yaml | 2 +- cents/config/dataset/commercial.yaml | 6 +- cents/config/model/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 2 +- cents/data_generator.py | 2 + cents/datasets/timeseries_dataset.py | 15 +- cents/models/diffusion_ts.py | 4 +- cents/models/model_utils.py | 13 +- cents/models/normalizer.py | 697 ++++++++------------------- scripts/eval_pretrained.py | 6 +- 10 files changed, 227 insertions(+), 522 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index 6327799..a147043 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -12,7 +12,7 @@ static_context: normalizer: stats_head_type: mlp # Stats head type (e.g., "mlp") # Future parameters can be added here: - # n_layers: 3 + n_layers: 3 # hidden_dim: 512 # Dynamic context: context module used by the normalizer for time series context variables diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 07c601c..346549b 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -16,7 +16,7 @@ data_columns: ["dataid","energy_meter","timestamp"] metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearbuilt"] numeric_context_bins: 5 reduce_cardinality: False - +normalizer_stats_mode: group context_vars: year: ["categorical", 2] @@ -24,5 +24,5 @@ context_vars: weekday: ["categorical", 7] site_id: ["categorical", 19] primaryspaceusage: ["categorical", 16] - sqft: ["continuous", null] - yearbuilt: ["continuous", null] \ No newline at end of file + sqft: ["categorical", null] + yearbuilt: ["categorical", null] \ No 
newline at end of file diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 3f0b69a..4ed648f 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -14,7 +14,7 @@ loss_type: l1 #l2 training_objective: x0 loss_weighting: uniform min_snr_gamma: 5.0 -beta_schedule: linear #linear diffusion ts paper uses linear schedule +beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 eta: 0.0 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index ea62a3a..a77f62f 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,6 +1,6 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 2,3 +devices: 2, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 diff --git a/cents/data_generator.py b/cents/data_generator.py index 10f3c46..81f7038 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -17,6 +17,7 @@ get_device, get_normalizer_training_config, parse_dims_from_name, + get_context_config, ) from cents.utils.config_loader import load_yaml, apply_overrides @@ -246,6 +247,7 @@ def load_from_checkpoint( dataset_cfg=self.cfg.dataset, normalizer_training_cfg=get_normalizer_training_config(), dataset=self.dataset, + context_cfg=get_context_config(), ) state = torch.load(normalizer_ckpt, map_location=device) sd = state.get("state_dict", state) diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 942bd37..16bbd5a 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -104,8 +104,8 @@ def __init__( # Store categorical time series info self.categorical_time_series = categorical_time_series or {} - if self.scale: - assert self.normalize, "Normalization must be enabled if scaling is enabled" + # if self.scale: + # assert self.normalize, "Normalization must be enabled 
if scaling is enabled" # Preprocess and optionally encode context self.data = self._preprocess_data(data) @@ -143,6 +143,11 @@ def __init__( self.data = self.merge_timeseries_columns(self.data) self.data = self.data.reset_index() + self.context_cfg = get_context_config() + self.dynamic_module_type = self.context_cfg.dynamic_context.type + self.static_module_type = self.context_cfg.static_context.type + self.stats_head_type = self.context_cfg.normalizer.stats_head_type + # Check if we should skip heavy processing for DDP if is_ddp_subprocess and skip_heavy_processing: @@ -671,11 +676,6 @@ def _init_normalizer(self) -> None: normalizer_dir.mkdir(parents=True, exist_ok=True) # Get context_module_type and stats_head_type from context config - context_cfg = get_context_config() - - self.dynamic_module_type = context_cfg.dynamic_context.type - self.static_module_type = context_cfg.static_context.type - self.stats_head_type = context_cfg.normalizer.stats_head_type cache_path = normalizer_dir / _ckpt_name( self.name, "normalizer", @@ -692,6 +692,7 @@ def _init_normalizer(self) -> None: dataset_cfg=self.cfg, normalizer_training_cfg=ncfg, dataset=self, + context_cfg=self.context_cfg, ) # attempt to load existing state dict (unless force_retrain_normalizer is True) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index f74e754..37447e2 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -398,8 +398,8 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di # f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") b = x.shape[0] - # t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - t = self.stratified_timesteps(b, self.num_timesteps, self.cfg.model.k_bins, device=self.device) + t = torch.randint(0, self.num_timesteps, (b,), device=self.device) + # t = self.stratified_timesteps(b, self.num_timesteps, self.cfg.model.k_bins, device=self.device) embedding, 
cond_classification_logits = self._get_context_embedding(context_vars) # Check diffusion schedule parameters diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index 7a07bfb..d3fe320 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -271,6 +271,9 @@ def __init__(self, n_embd): # self.emb = SinusoidalPosEmb(n_embd) self.silu = nn.SiLU() self.linear = nn.Linear(n_embd, n_embd * 2) + + nn.init.zeros_(self.linear.bias) + nn.init.zeros_(self.linear.weight) self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False) def forward(self, x, emb): @@ -777,7 +780,7 @@ def __init__( self.cond_mix_mlp = nn.Sequential( nn.Linear(n_embd * 2, n_embd), - nn.ReLU(), + nn.SiLU(), nn.Linear(n_embd, n_embd), ) @@ -845,12 +848,9 @@ def forward(self, input, t, padding_masks=None, return_res=False, cond=None): if (cond is not None) and (self.cond_proj is not None): label_emb = self.cond_proj(cond) # (B, n_embd) # Add them up here to pass a single vector down - total_cond_emb = torch.concat([t_emb, label_emb], dim=1) + total_cond_emb = self.cond_mix_mlp(torch.concat([t_emb, label_emb], dim=1)) else: - total_cond_emb = torch.concat([t_emb, torch.zeros_like(t_emb)], dim=1) - - ## Use MLP to combine t_emb and label_emb - total_cond_emb = self.cond_mix_mlp(total_cond_emb) + total_cond_emb = t_emb emb = self.emb(input) inp_enc = self.pos_enc(emb) @@ -864,7 +864,6 @@ def forward(self, input, t, padding_masks=None, return_res=False, cond=None): res = self.inverse(output) - # .contiguous() usage here was correct in your original code res_m = torch.mean(res, dim=1, keepdim=True).contiguous() combine_m_out = self.combine_m(mean).contiguous() combine_s_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2).contiguous() diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 3ee535e..a698e22 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -21,10 +21,6 @@ @register_stats_head("default", "mlp") class 
MLPStatsHead(nn.Module): - """ - Head module predicting summary statistics (mean, std, and optionally min/max z-scores) from context embedding. - """ - def __init__( self, embedding_dim: int, @@ -33,16 +29,6 @@ def __init__( do_scale: bool, n_layers: int = 3, ): - """ - Initializes the statistics head network. - - Args: - embedding_dim: Dimensionality of the input context embedding. - hidden_dim: Number of units in each hidden layer. - time_series_dims: Number of dimensions in the original time series. - do_scale: Whether to predict scaling min/max parameters. - n_layers: Number of hidden linear layers before the output. - """ super().__init__() self.time_series_dims = time_series_dims self.do_scale = do_scale @@ -56,77 +42,54 @@ def __init__( layers.append(nn.Linear(in_dim, out_dim)) self.net = nn.Sequential(*layers) - # Initialize the output layer properly - # For log_sigma, initialize to small negative values so exp(log_sigma) starts around 1 - # This helps with training stability self._initialize_output_layer() def _initialize_output_layer(self, init_sigma: float = 1.0): - """ - Initialize the last layer so log_sigma starts around log(init_sigma). - Assumes outputs are later reshaped as out.view(B, K, D) where K=2 or 4. - Therefore log_sigma lives in the SECOND block: indices [D:2D) in the flattened vector. - """ - assert init_sigma > 0.0, "init_sigma must be > 0" D = self.time_series_dims K = 4 if self.do_scale else 2 - out_layer = self.net[-1] - if not isinstance(out_layer, nn.Linear): - raise RuntimeError("Expected last module in self.net to be nn.Linear") with torch.no_grad(): - # Reasonable default: keep biases at 0, then set log_sigma bias block. - nn.init.zeros_(out_layer.bias) - - # Optional: small weights so output starts near the bias. - # (If you prefer PyTorch defaults, comment this out.) 
nn.init.xavier_uniform_(out_layer.weight, gain=0.01) - - # Set log_sigma block bias - log_sigma_bias = np.log(init_sigma) - start = 1 * D - end = 2 * D - out_layer.bias[start:end].fill_(log_sigma_bias) - - # Sanity: ensure output dimension matches expectation - expected_out = K * D - if out_layer.out_features != expected_out: - raise ValueError( - f"Output layer out_features={out_layer.out_features}, expected {expected_out} (K={K}, D={D})." - ) + nn.init.zeros_(out_layer.bias) + + if self.do_scale: + # 1. Initialize z_min to -2.0 + out_layer.bias[2 * D : 3 * D].fill_(-2.0) + + # 2. Initialize the RAW DELTA to ~4.0 + # Softplus(4.0) is approx 4.018. + # z_max = -2.0 + 4.018 = 2.018 (Perfect starting point) + out_layer.bias[3 * D : 4 * D].fill_(4.0) @staticmethod def _soft_clamp_tanh(x: torch.Tensor, bound: float) -> torch.Tensor: - """ - Smoothly maps x into [-bound, bound] using tanh. - """ if bound <= 0: raise ValueError("bound must be > 0") return bound * torch.tanh(x / bound) def forward(self, z: torch.Tensor): - """ - Forward pass to compute predicted statistics. - - Args: - z: Context embedding tensor of shape (batch_size, embedding_dim). - - Returns: - pred_mu: Predicted means, shape (batch_size, time_series_dims). - pred_sigma: Predicted standard deviations, shape (batch_size, time_series_dims). - pred_z_min: Predicted min z-scores, or None if do_scale=False. - pred_z_max: Predicted max z-scores, or None if do_scale=False. - pred_log_sigma_unclamped: Unclamped log_sigma for loss computation. 
- """ out = self.net(z) batch_size = out.size(0) + if self.do_scale: out = out.view(batch_size, 4, self.time_series_dims) pred_mu = out[:, 0, :] pred_log_sigma = out[:, 1, :] + + # z_min is predicted normally pred_z_min = out[:, 2, :] - pred_z_max = out[:, 3, :] + + # The 4th output is now the "raw delta" + raw_delta = out[:, 3, :] + + # Ensure the range is strictly positive (minimum 1e-4) + # F.softplus(x) = log(1 + exp(x)) + actual_range = F.softplus(raw_delta) + 1e-4 + + # Structurally guarantee z_max > z_min + pred_z_max = pred_z_min + actual_range + else: out = out.view(batch_size, 2, self.time_series_dims) pred_mu = out[:, 0, :] @@ -134,11 +97,10 @@ def forward(self, z: torch.Tensor): pred_z_min = None pred_z_max = None - # Store unclamped version for loss computation BEFORE clamping - # This must be done before any operations that might break the computation graph pred_log_sigma_unclamped = pred_log_sigma pred_log_sigma_clamped = self._soft_clamp_tanh(pred_log_sigma, bound=10.0) pred_sigma = torch.exp(pred_log_sigma_clamped) + return pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped @@ -156,20 +118,13 @@ def __init__( do_scale: bool = True, stats_head_type: str = "mlp", dynamic_var_names: list[str] = None, + n_layers: int = 3, ): - """ - Args: - static_cond_module: ContextModule instance for static context variables (categorical + continuous). - dynamic_cond_module: ContextModule instance for dynamic context variables (time_series). - hidden_dim: Hidden dimension size for the stats head. - time_series_dims: Number of time series dimensions. - do_scale: Whether to include scaling predictions. - stats_head_type: Type of stats head to use (from registry). 
- """ super().__init__() self.static_cond_module = static_cond_module self.dynamic_cond_module = dynamic_cond_module self.dynamic_var_names = dynamic_var_names + # Determine embedding dimension from available modules if static_cond_module is not None: self.embedding_dim = static_cond_module.embedding_dim @@ -180,7 +135,6 @@ def __init__( # If both modules exist, combine their embeddings if static_cond_module is not None and dynamic_cond_module is not None: - # Combine embeddings from both modules combined_dim = static_cond_module.embedding_dim + dynamic_cond_module.embedding_dim self.combine_mlp = nn.Sequential( nn.Linear(combined_dim, self.embedding_dim), @@ -189,32 +143,20 @@ def __init__( else: self.combine_mlp = None - # Use registry to get the stats head class StatsHeadCls = get_stats_head_cls(stats_head_type) self.stats_head = StatsHeadCls( embedding_dim=self.embedding_dim, hidden_dim=hidden_dim, time_series_dims=time_series_dims, do_scale=do_scale, + n_layers=n_layers, ) def forward(self, context_vars_dict: dict): - """ - Compute normalization parameters from categorical context. - - Args: - context_vars_dict: Mapping of context variable names to label tensors. - Static vars: single values (categorical: long, continuous: float) - Dynamic vars: time series sequences (batch, seq_len) - - Returns: - Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped). 
- """ embeddings = [] # Process static context variables if self.static_cond_module is not None: - # Filter static context variables static_vars = { k: v for k, v in context_vars_dict.items() if k not in getattr(self, '_dynamic_var_names', []) @@ -230,7 +172,6 @@ def forward(self, context_vars_dict: dict): # Process dynamic context variables if self.dynamic_cond_module is not None: - # Filter dynamic context variables dynamic_var_names = getattr(self, '_dynamic_var_names', []) dynamic_vars = { k: v for k, v in context_vars_dict.items() @@ -243,15 +184,11 @@ def forward(self, context_vars_dict: dict): for k, v in dynamic_vars.items() } dynamic_embedding, _ = self.dynamic_cond_module(dynamic_vars) - # Check for NaN in dynamic embedding if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): - raise ValueError( - f"NaN/Inf detected in dynamic embedding. " - f"Dynamic vars: {list(dynamic_vars.keys())}" - ) + raise ValueError(f"NaN/Inf detected in dynamic embedding.") embeddings.append(dynamic_embedding) - # Combine embeddings if both exist + # Combine embeddings if len(embeddings) == 2: combined = torch.cat(embeddings, dim=1) embedding = self.combine_mlp(combined) @@ -274,15 +211,8 @@ def __init__( dataset_cfg, normalizer_training_cfg, dataset, + context_cfg, ): - """ - Initializes the Normalizer training module. - - Args: - dataset_cfg: OmegaConf dataset config (provides context_vars, columns). - normalizer_training_cfg: Config for normalizer training (lr, batch_size). - dataset: Instance of TimeSeriesDataset containing data DataFrame. 
- """ super().__init__() self.save_hyperparameters(ignore=["dataset"]) @@ -298,29 +228,25 @@ def __init__( self.static_context_vars = self.categorical_vars + self.continuous_vars self.context_vars = self.static_context_vars + self.dynamic_context_vars - self.time_series_cols = dataset_cfg.time_series_columns[ - : dataset_cfg.time_series_dims - ] + self.time_series_cols = dataset_cfg.time_series_columns[: dataset_cfg.time_series_dims] self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale self.seq_len = dataset_cfg.seq_len - self.num_ts_steps = getattr(dataset_cfg, "num_ts_steps", None) # For dynamic context length - - # Get context config - # context_cfg = get_context_config() + self.num_ts_steps = getattr(dataset_cfg, "num_ts_steps", None) self.static_module_type = self.dataset.static_module_type self.dynamic_module_type = self.dataset.dynamic_module_type self.stats_head_type = self.dataset.stats_head_type - - # Get loss type from config (default to "mse") self.loss_type = getattr(self.normalizer_training_cfg, "loss_type", "mse") + + self.register_buffer("global_mu_mean", torch.tensor(0.0)) + self.register_buffer("global_mu_std", torch.tensor(1.0)) + self.register_buffer("global_log_sigma_mean", torch.tensor(0.0)) - # Create static context module (for categorical + continuous) + # Create static context module self.static_context_module = None if self.static_context_vars: StaticContextModuleCls = get_context_module_cls(self.static_module_type) - # Filter context_vars to only static ones self.static_context_vars_dict = { k: v for k, v in self.dataset.context_var_dict.items() if k in self.static_context_vars @@ -330,16 +256,14 @@ def __init__( 256, ) - # Create dynamic context module (for time_series) + # Create dynamic context module self.dynamic_context_module = None if self.dynamic_context_vars and self.dynamic_module_type is not None: DynamicContextModuleCls = get_context_module_cls("dynamic", self.dynamic_module_type) - # Filter 
context_vars to only dynamic ones dynamic_context_vars_dict = { k: v for k, v in self.dataset_cfg.context_vars.items() if k in self.dynamic_context_vars } - # Use num_ts_steps for dynamic context length if available, otherwise seq_len dynamic_seq_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len self.dynamic_context_module = DynamicContextModuleCls( dynamic_context_vars_dict, @@ -355,6 +279,7 @@ def __init__( do_scale=self.do_scale, stats_head_type=self.stats_head_type, dynamic_var_names=self.dynamic_context_vars, + n_layers=context_cfg.normalizer.n_layers, ) # Will be populated in setup() @@ -362,10 +287,6 @@ def __init__( self._verify_parameters() def _verify_parameters(self): - """ - Verify that all parameters including context module are registered. - This helps debug parameter counting issues. - """ all_param_names = [name for name, _ in self.named_parameters()] context_param_names = [name for name in all_param_names if 'cond_module' in name or 'context_module' in name] stats_head_param_names = [name for name in all_param_names if 'stats_head' in name] @@ -373,7 +294,6 @@ def _verify_parameters(self): if not context_param_names: raise RuntimeError( "Context module parameters not found! " - "Expected parameters with 'cond_module' in name. " f"Found parameter names: {all_param_names[:10]}..." ) @@ -385,11 +305,31 @@ def setup(self, stage: Optional[str] = None): """ Lightning hook: prepare training data before training. 
""" - # Compute per-sample statistics - no grouping needed + # Compute per-sample statistics + # Note: Using robust quantile scaling for targets to avoid outlier instability mode = getattr(self.dataset_cfg, "normalizer_stats_mode", "sample") - self.sample_stats = self._build_training_samples(mode, use_quantile_scale=False) - - # Log initial predictions to check if model is in the right ballpark + self.sample_stats = self._build_training_samples(mode, use_quantile_scale=True) + + # --- COMPUTE GLOBAL TARGET STATS FOR SCALING --- + # 1. Global Mu Stats (for Z-score scaling) + all_mus = np.concatenate([s[2] for s in self.sample_stats]) + self.target_mu_mean = torch.tensor(all_mus.mean(), dtype=torch.float32) + self.target_mu_std = torch.tensor(all_mus.std() + 1e-8, dtype=torch.float32) + + # 2. Global Sigma Stats (for Log-Space Centering) + all_sigmas_concat = np.concatenate([s[3] for s in self.sample_stats]) + # Calculate the mean of the logs (Geometric mean center) + self.target_log_sigma_mean = torch.tensor(np.log(all_sigmas_concat + 1e-8).mean(), dtype=torch.float32) + + # Register buffers so they persist with model + self.global_mu_mean.fill_(self.target_mu_mean) + self.global_mu_std.fill_(self.target_mu_std) + self.global_log_sigma_mean.fill_(self.target_log_sigma_mean) + + print(f"Global Target Stats: Mu Mean={self.target_mu_mean:.4f}, Mu Std={self.target_mu_std:.4f}") + print(f"Global Target Log Sigma Mean: {self.target_log_sigma_mean:.4f}") + + # Log initial predictions if stage == "fit" or stage is None: self._log_initial_predictions() @@ -397,189 +337,113 @@ def _log_initial_predictions(self): """Log initial model predictions to diagnose initialization issues.""" self.eval() with torch.no_grad(): - # Get a sample batch dataloader = self.train_dataloader() batch = next(iter(dataloader)) cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - # Move to device device = next(self.parameters()).device - cat_vars_dict = { - k: v.to(device) if isinstance(v, 
torch.Tensor) else v - for k, v in cat_vars_dict.items() - } + cat_vars_dict = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in cat_vars_dict.items()} mu_t = mu_t.to(device) sigma_t = sigma_t.to(device) + # Predict (Returns Real Unscaled values via Forward) pred_mu, pred_sigma, pred_z_min, pred_z_max, _ = self(cat_vars_dict) print(f"\n[Initial Predictions]") - print(f" Target mu: mean={mu_t.mean().item():.4f}, std={mu_t.std().item():.4f}, range=[{mu_t.min().item():.4f}, {mu_t.max().item():.4f}]") - print(f" Predicted mu: mean={pred_mu.mean().item():.4f}, std={pred_mu.std().item():.4f}, range=[{pred_mu.min().item():.4f}, {pred_mu.max().item():.4f}]") - print(f" Target sigma: mean={sigma_t.mean().item():.4f}, std={sigma_t.std().item():.4f}, range=[{sigma_t.min().item():.4f}, {sigma_t.max().item():.4f}]") - print(f" Predicted sigma: mean={pred_sigma.mean().item():.4f}, std={pred_sigma.std().item():.4f}, range=[{pred_sigma.min().item():.4f}, {pred_sigma.max().item():.4f}]") + print(f" Target mu: mean={mu_t.mean().item():.4f}, std={mu_t.std().item():.4f}") + print(f" Predicted mu: mean={pred_mu.mean().item():.4f}, std={pred_mu.std().item():.4f}") print(f" Initial loss_mu: {F.mse_loss(pred_mu, mu_t).item():.6f}") - print(f" Initial loss_sigma: {F.mse_loss(pred_sigma, sigma_t).item():.6f}") print() self.train() def forward(self, cat_vars_dict: dict): """ - Predict normalization parameters for a batch of categorical contexts. - - Args: - cat_vars_dict: Mapping of context variable names to label tensors. + Predict normalization parameters. + Applies UNSCALING logic to convert network outputs back to real-world range. Returns: - Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped). + Tuple of (pred_mu_real, pred_sigma_real, pred_z_min, pred_z_max, pred_log_sigma_raw). 
""" - return self.normalizer_model(cat_vars_dict) + # Get raw network outputs (scaled space) + pred_mu_raw, pred_sigma, pred_zmin, pred_zmax, pred_log_sigma_raw = self.normalizer_model(cat_vars_dict) - def _compute_loss_mse(self, pred_mu, pred_sigma, pred_log_sigma_unclamped, mu_t, sigma_t): - loss_mu = F.mse_loss(pred_mu, mu_t) - - # FIX: Clamp the target log sigma to a reasonable range - # (e.g., nothing smaller than e^-5 approx 0.006) - target_log_sigma = torch.log(sigma_t + 1e-8) - target_log_sigma = torch.clamp(target_log_sigma, min=-5.0, max=10.0) - - # FIX: Use Huber Loss (SmoothL1Loss) instead of MSE for stability - loss_sigma = F.smooth_l1_loss(pred_log_sigma_unclamped, target_log_sigma) - - return loss_mu, loss_sigma + # 1. Unscale Mu: (NetworkOutput * GlobalStd) + GlobalMean + pred_mu_real = (pred_mu_raw * self.global_mu_std) + self.global_mu_mean - - def _compute_loss_gaussian_nll(self, pred_mu, pred_sigma, mu_t, sigma_t): + # 2. Unscale Sigma: exp(NetworkLogOutput + GlobalLogMean) + # Note: We reconstruct log_sigma first to ensure numerical stability + pred_log_sigma_real = pred_log_sigma_raw + self.global_log_sigma_mean + pred_sigma_real = torch.exp(pred_log_sigma_real) + + return pred_mu_real, pred_sigma_real, pred_zmin, pred_zmax, pred_log_sigma_raw + + def _compute_loss_mse(self, pred_mu_raw, pred_log_sigma_raw, mu_t_scaled, target_log_sigma_centered): """ - Compute Gaussian Negative Log Likelihood loss. - - Treats target mu_t as observations from N(pred_mu, pred_sigma^2). - For sigma, still uses log-space MSE since it's a scale parameter. - - Args: - pred_mu: Predicted means - pred_sigma: Predicted standard deviations - mu_t: Target means (treated as observations) - sigma_t: Target standard deviations - - Returns: - loss_mu, loss_sigma + Compute MSE loss in the SCALED space. 
""" - # Use Gaussian NLL for mu: treat mu_t as observations from N(pred_mu, pred_sigma^2) - # GaussianNLLLoss expects: input (mean), target (observations), var (variance) - # We need to ensure variance is positive and not too small - pred_var = torch.clamp(pred_sigma ** 2, min=1e-6) - gaussian_nll = nn.GaussianNLLLoss(reduction='mean') - loss_mu = gaussian_nll(pred_mu, mu_t, pred_var) + loss_mu = F.mse_loss(pred_mu_raw, mu_t_scaled) - # For sigma, still use log-space MSE (sigma is a scale parameter, not a location) - pred_log_sigma = torch.log(pred_sigma + 1e-8) - target_log_sigma = torch.log(sigma_t + 1e-8) - loss_sigma = F.mse_loss(pred_log_sigma, target_log_sigma) + # Use Huber Loss for Log Sigma to be robust against outliers + # Clamp targets slightly to prevent infinite loss if data is broken + target_log_sigma_centered = torch.clamp(target_log_sigma_centered, min=-10.0, max=10.0) + loss_sigma = F.smooth_l1_loss(pred_log_sigma_raw, target_log_sigma_centered, beta=1.0) return loss_mu, loss_sigma def training_step(self, batch, batch_idx: int): - """ - Training step: regress predicted stats against true group stats. - - Args: - batch: Tuple of (cat_vars_dict, mu, sigma, zmin, zmax). - batch_idx: Batch index. - - Returns: - loss tensor. - """ context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped = self(context_vars_dict) + + # 1. 
Get RAW network outputs (scaled space) from internal model + # We call self.normalizer_model directly to avoid the unscaling logic in self.forward + # This gives us values that match the standardized targets + pred_mu_raw, _, pred_z_min, pred_z_max, pred_log_sigma_raw = self.normalizer_model(context_vars_dict) - # Compute loss based on loss_type - if self.loss_type == "mse": - loss_mu, loss_sigma = self._compute_loss_mse( - pred_mu, pred_sigma, pred_log_sigma_unclamped, mu_t, sigma_t - ) - elif self.loss_type == "gaussian_nll": - loss_mu, loss_sigma = self._compute_loss_gaussian_nll( - pred_mu, pred_sigma, mu_t, sigma_t - ) - else: - raise ValueError( - f"Unknown loss_type: {self.loss_type}. " - f"Supported types: 'mse', 'gaussian_nll'" - ) + # 2. Scale Targets to match Network Space + + # Scale Mu: Z-score + mu_t_scaled = (mu_t - self.global_mu_mean) / self.global_mu_std + + # Scale Sigma: Log-Space Centering + target_log_sigma_centered = torch.log(sigma_t + 1e-8) - self.global_log_sigma_mean + # 3. Compute Loss + loss_mu, loss_sigma = self._compute_loss_mse( + pred_mu_raw, pred_log_sigma_raw, mu_t_scaled, target_log_sigma_centered + ) total_loss = loss_mu + loss_sigma + # 4. 
Scaling parameters (z_min/z_max) - These are already naturally roughly scaled (-2 to 2) + # We use SmoothL1Loss to be robust to outliers if self.do_scale: if torch.isnan(pred_z_min).any() or torch.isnan(pred_z_max).any(): - raise ValueError( - f"NaN detected in scale predictions at batch {batch_idx}" - ) - loss_zmin = F.mse_loss(pred_z_min, zmin_t) - loss_zmax = F.mse_loss(pred_z_max, zmax_t) + raise ValueError(f"NaN detected in scale predictions at batch {batch_idx}") + loss_zmin = F.smooth_l1_loss(pred_z_min, zmin_t, beta=1.0) + loss_zmax = F.smooth_l1_loss(pred_z_max, zmax_t, beta=1.0) total_loss += loss_zmin + loss_zmax else: loss_zmin = torch.tensor(0.0, device=total_loss.device) loss_zmax = torch.tensor(0.0, device=total_loss.device) - # Check for NaN in loss if torch.isnan(total_loss) or torch.isinf(total_loss): - raise ValueError( - f"NaN/Inf loss detected at batch {batch_idx}. " - f"loss_mu: {loss_mu.item():.6f}, loss_sigma: {loss_sigma.item():.6f}" - ) + raise ValueError(f"NaN/Inf loss detected.") - # Log individual components to understand what's happening + # Log metrics self.log("train_loss", total_loss, prog_bar=True, sync_dist=True) self.log("loss_mu", loss_mu, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) self.log("loss_sigma", loss_sigma, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) - if self.do_scale: - self.log("loss_zmin", loss_zmin, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) - self.log("loss_zmax", loss_zmax, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) - # Log prediction statistics to monitor if model is learning - if batch_idx % 100 == 0: # Log every 100 batches to avoid spam + if batch_idx % 100 == 0: with torch.no_grad(): - # Log shapes (as number of elements for logging purposes) - self.log("pred_mu_num_elements", pred_mu.numel(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_t_num_elements", mu_t.numel(), on_step=True, on_epoch=False, sync_dist=True) - 
self.log("pred_mu_batch_size", pred_mu.shape[0] if len(pred_mu.shape) > 0 else 1, on_step=True, on_epoch=False) - self.log("pred_mu_dims", pred_mu.shape[1] if len(pred_mu.shape) > 1 else 1, on_step=True, on_epoch=False, sync_dist=True) - - # Log ranges - self.log("pred_mu_min", pred_mu.min(), on_step=True, on_epoch=False, sync_dist=True) - self.log("pred_mu_max", pred_mu.max(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_t_min", mu_t.min(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_t_max", mu_t.max(), on_step=True, on_epoch=False, sync_dist=True) - - # Log error statistics - mu_errors = (pred_mu - mu_t).abs() - mu_squared_errors = (pred_mu - mu_t) ** 2 - self.log("mu_error_mean", mu_errors.mean(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_error_max", mu_errors.max(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_error_min", mu_errors.min(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_squared_error_mean", mu_squared_errors.mean(), on_step=True, on_epoch=False, sync_dist=True) - self.log("mu_squared_error_max", mu_squared_errors.max(), on_step=True, on_epoch=False, sync_dist=True) - - # Log existing statistics - self.log("pred_mu_mean", pred_mu.mean(), on_step=True, on_epoch=True, sync_dist=True) - self.log("pred_mu_std", pred_mu.std(), on_step=True, on_epoch=True, sync_dist=True) - self.log("pred_sigma_mean", pred_sigma.mean(), on_step=True, on_epoch=True, sync_dist=True) - self.log("pred_sigma_std", pred_sigma.std(), on_step=True, on_epoch=True, sync_dist=True) - self.log("target_mu_mean", mu_t.mean(), on_step=True, on_epoch=True, sync_dist=True) - self.log("target_sigma_mean", sigma_t.mean(), on_step=True, on_epoch=True, sync_dist=True) + # Reconstruct real values for logging intelligibility + pred_mu_real = (pred_mu_raw * self.global_mu_std) + self.global_mu_mean + self.log("pred_mu_mean_real", pred_mu_real.mean(), on_step=True, on_epoch=False) + self.log("target_mu_mean_real", 
mu_t.mean(), on_step=True, on_epoch=False) return total_loss def configure_optimizers(self): - """ - Configure optimizer for normalizer training. - - Returns: - Adam optimizer instance. - """ optimizer = torch.optim.Adam( self.parameters(), lr=self.normalizer_training_cfg.lr, @@ -587,154 +451,76 @@ def configure_optimizers(self): eps=1e-8, weight_decay=0.0 ) - return optimizer def on_train_batch_end(self, outputs, batch, batch_idx): - """ - Monitor gradients after each training step to diagnose training issues. - """ - if batch_idx % 100 == 0: # Check every 100 batches + if batch_idx % 100 == 0: total_norm = 0.0 - param_count = 0 - zero_grad_count = 0 - - for name, param in self.named_parameters(): - if param.grad is not None: - param_norm = param.grad.data.norm(2) + for p in self.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) total_norm += param_norm.item() ** 2 - param_count += 1 - if param_norm.item() < 1e-8: - zero_grad_count += 1 - print(name, "HAS ZERO GRAD") - else: - # Parameter has no gradient - this might indicate a problem - if 'cond_module' in name or 'stats_head' in name: - # Only warn about important parameters - pass - total_norm = total_norm ** (1. / 2) - self.log("grad_norm", total_norm, on_step=True, on_epoch=False) - self.log("params_with_grad", param_count, on_step=True, on_epoch=False) - - if total_norm < 1e-6: - print(f"[Warning] Very small gradient norm at batch {batch_idx}: {total_norm:.2e}") - if zero_grad_count > 0: - print(f"[Warning] {zero_grad_count} parameters have near-zero gradients at batch {batch_idx}") def train_dataloader(self): - """ - Returns a DataLoader over per-group statistics samples. 
- """ ds = self._create_training_dataset() return DataLoader( ds, batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, - num_workers=4, # Use fewer workers to reduce overhead - persistent_workers=True, # Disable to avoid multiprocessing cleanup issues - pin_memory=torch.cuda.is_available(), # Helps with GPU transfer - prefetch_factor=2, # Reduce prefetch to avoid memory issues + num_workers=4, + persistent_workers=True, + pin_memory=torch.cuda.is_available(), + prefetch_factor=2, ) def _compute_per_sample_stats(self) -> list: - """ - Compute statistics for each individual sample. - This allows the model to learn context → normalization_params for all context types - (categorical, continuous, and dynamic) without requiring grouping. - - Returns: - List of tuples: (context_vars_dict, mu_array, std_array, zmin_array, zmax_array) - """ + # Same implementation as before df = self.dataset.data.copy() sample_stats = [] continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] for idx, row in df.iterrows(): context_vars_dict = {} - - # Process static context variables (categorical + continuous) for var_name in self.static_context_vars: if var_name in row: if var_name in continuous_vars: + # Normalize inputs using min-max if available, otherwise just cast context_vars_dict[var_name] = torch.tensor(row[var_name], dtype=torch.float32) else: context_vars_dict[var_name] = torch.tensor(row[var_name], dtype=torch.long) - # Process dynamic context variables (time series) dynamic_ctx_dict = {} for var_name in self.dynamic_context_vars: - # Check for both the original name and context_ prefix (vehicle dataset uses context_ prefix) - ts_data = None - if var_name in row: - ts_data = row[var_name] - + ts_data = row.get(var_name) if ts_data is not None: - if isinstance(ts_data, np.ndarray): - dynamic_ctx_dict[var_name] = ts_data - elif isinstance(ts_data, list): + if isinstance(ts_data, (np.ndarray, list)): dynamic_ctx_dict[var_name] = 
np.array(ts_data) else: - # If it's a scalar, repeat it to match the appropriate length - # Use num_ts_steps if available (for dynamic context), otherwise seq_len context_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len dynamic_ctx_dict[var_name] = np.full(context_len, ts_data) - # Compute statistics from this sample's target time series dimension_points = [] - for d, col_name in enumerate(self.time_series_cols): + for col_name in self.time_series_cols: arr = np.array(row[col_name], dtype=np.float32).flatten() dimension_points.append(arr) - mu_array = np.array( - [pts.mean() for pts in dimension_points], dtype=np.float32 - ) - std_array = np.array( - [pts.std() + 1e-8 for pts in dimension_points], dtype=np.float32 - ) + mu_array = np.array([pts.mean() for pts in dimension_points], dtype=np.float32) + std_array = np.array([pts.std() + 1e-8 for pts in dimension_points], dtype=np.float32) if self.do_scale: - z_min_array = np.array( - [ - (pts - mu).min() / std - for pts, mu, std in zip(dimension_points, mu_array, std_array) - ], - dtype=np.float32, - ) - z_max_array = np.array( - [ - (pts - mu).max() / std - for pts, mu, std in zip(dimension_points, mu_array, std_array) - ], - dtype=np.float32, - ) + z_min_array = np.array([(pts - mu).min() / std for pts, mu, std in zip(dimension_points, mu_array, std_array)], dtype=np.float32) + z_max_array = np.array([(pts - mu).max() / std for pts, mu, std in zip(dimension_points, mu_array, std_array)], dtype=np.float32) else: z_min_array = z_max_array = None - sample_stats.append(( - context_vars_dict, - dynamic_ctx_dict, - mu_array, - std_array, - z_min_array, - z_max_array, - )) + sample_stats.append((context_vars_dict, dynamic_ctx_dict, mu_array, std_array, z_min_array, z_max_array)) return sample_stats def _create_training_dataset(self) -> Dataset: - """ - Build an internal Dataset yielding per-sample statistics. - - Returns: - PyTorch Dataset of samples (context_vars_dict, mu, sigma, zmin, zmax). 
- """ class _TrainSet(Dataset): - """ - Adapter Dataset to wrap per-sample statistics for DataLoader. - """ - def __init__(self, samples, dynamic_context_vars, do_scale, dataset_cfg): self.samples = samples self.dynamic_context_vars = dynamic_context_vars @@ -745,43 +531,15 @@ def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int): - """ - Returns one training sample. - - Args: - idx: Index of the sample. - - Returns: - context_vars_dict: Tensor dict of context labels (static + dynamic). - mu_t: True mean tensor. - sigma_t: True std tensor. - zmin_t: True min z-score tensor or None. - zmax_t: True max z-score tensor or None. - """ context_vars_dict, dynamic_ctx_dict, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] - - # Process dynamic context variables (time series) - convert to tensors for var_name in self.dynamic_context_vars: if var_name in dynamic_ctx_dict: - ts_data = dynamic_ctx_dict[var_name] - # Convert to tensor - if isinstance(ts_data, np.ndarray): - # Check if it's categorical (integer) or numeric (float) - var_info = self.dataset_cfg.context_vars.get(var_name, None) - if var_info and var_info[1] is not None: - # Categorical time series - context_vars_dict[var_name] = torch.from_numpy(ts_data).long() - else: - # Numeric time series - context_vars_dict[var_name] = torch.from_numpy(ts_data).float() + ts_data = np.array(dynamic_ctx_dict[var_name]) + var_info = self.dataset_cfg.context_vars.get(var_name, None) + if var_info and var_info[1] is not None: + context_vars_dict[var_name] = torch.from_numpy(ts_data).long() else: - # Fallback: convert to array first - ts_array = np.array(ts_data) - var_info = self.dataset_cfg.context_vars.get(var_name, None) - if var_info and var_info[1] is not None: - context_vars_dict[var_name] = torch.from_numpy(ts_array).long() - else: - context_vars_dict[var_name] = torch.from_numpy(ts_array).float() + context_vars_dict[var_name] = torch.from_numpy(ts_data).float() mu_t = 
torch.from_numpy(mu_arr).float() sigma_t = torch.from_numpy(sigma_arr).float() @@ -792,32 +550,13 @@ def __getitem__(self, idx: int): return _TrainSet(self.sample_stats, self.dynamic_context_vars, self.do_scale, self.dataset_cfg) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Normalize a DataFrame of time series using learned parameters. - - Pads or splits if needed, then applies z-score and min-max scaling. - - Args: - df: Input DataFrame with raw time series columns. - - Returns: - DataFrame with normalized series in same columns. - """ missing = [c for c in self.time_series_cols if c not in df.columns] - if missing: df = split_timeseries(df, self.time_series_cols) - missing = [c for c in self.time_series_cols if c not in df.columns] - - assert not missing, ( - "Normalizer.transform expects data in split format with columns " - f"{self.time_series_cols}." - ) + df_out = df.copy() self.eval() continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] - - # Get categorical time series from dataset if available categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) with torch.no_grad(): @@ -827,138 +566,109 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: if v in continuous_vars: ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) elif v in self.dynamic_context_vars: - # Dynamic (time series) variable - if v in categorical_ts: - # Categorical time series - keep as long - ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - else: - # Numeric time series - convert to float32 - ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + dtype = torch.long if v in categorical_ts else torch.float32 + ctx[v] = torch.tensor(row[v], dtype=dtype).unsqueeze(0) else: - # Static categorical variable ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - mu, sigma, zmin, zmax, _ = self(ctx) - mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() + + # self(ctx) calls forward, 
which automatically UNSCALES predictions + pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(ctx) + mu, sigma = pred_mu[0].cpu().numpy(), pred_sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): - arr = np.asarray(row[col], dtype=np.float32) - - # Skip normalization for categorical time series if col in categorical_ts: - # Keep as integers, just ensure proper dtype - df_out.at[i, col] = arr.astype(np.int32) + df_out.at[i, col] = np.asarray(row[col]).astype(np.int32) continue - # Normalize numeric time series + arr = np.asarray(row[col], dtype=np.float32) + + if sigma[d] < 1e-4: + print(f"[EXPLOSION ALERT] Row {i}, Col '{col}'") + print(f" -> Sigma is tiny: {sigma[d]:.8f}") + print(f" -> Raw Data Range: {arr.min()} to {arr.max()}") + print(f" -> This will multiply your data by {1/(sigma[d]+1e-8):.0f}x!") + z = (arr - mu[d]) / (sigma[d] + 1e-8) if self.do_scale: - zmin_, zmax_ = zmin[0, d].item(), zmax[0, d].item() + zmin_, zmax_ = pred_zmin[0, d].item(), pred_zmax[0, d].item() rng = (zmax_ - zmin_) + 1e-8 z = (z - zmin_) / rng + + if rng < 1e-4: + print(f"[EXPLOSION ALERT] Row {i}, Col '{col}'") + print(f" -> Range Collapsed: z_min={zmin_:.4f}, z_max={zmax_:.4f}") + print(f" -> Range delta: {rng:.8f}") + print(f" -> This will multiply your data by {1/(rng+1e-8):.0f}x!") + + df_out.at[i, col] = z return df_out def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Denormalize a DataFrame of z-scored series back to original scale. - - Args: - df: DataFrame with normalized series columns. - - Returns: - DataFrame with denormalized series. - """ missing = [c for c in self.time_series_cols if c not in df.columns] - if missing: df = split_timeseries(df, self.time_series_cols) - missing = [c for c in self.time_series_cols if c not in df.columns] - - assert not missing, ( - "Normalizer.inverse_transform expects split format with columns " - f"{self.time_series_cols}." 
- ) df_out = df.copy() self.eval() continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] - # Get categorical time series from dataset if available categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) + with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): ctx = {} for v in self.context_vars: if v in continuous_vars: - # Static continuous variable ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) elif v in self.dynamic_context_vars: - # Dynamic (time series) variable - if v in categorical_ts: - # Categorical time series - keep as long - ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - else: - # Numeric time series - convert to float32 - ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + dtype = torch.long if v in categorical_ts else torch.float32 + ctx[v] = torch.tensor(row[v], dtype=dtype).unsqueeze(0) else: ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - mu, sigma, zmin, zmax, _ = self(ctx) - mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() + + # self(ctx) calls forward, which automatically UNSCALES predictions + pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(ctx) + mu, sigma = pred_mu[0].cpu().numpy(), pred_sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): z = np.asarray(row[col], dtype=np.float32) if self.do_scale: - zmin_, zmax_ = zmin[0, d].item(), zmax[0, d].item() + zmin_, zmax_ = pred_zmin[0, d].item(), pred_zmax[0, d].item() rng = (zmax_ - zmin_) + 1e-8 z = z * rng + zmin_ + arr = z * (sigma[d] + 1e-8) + mu[d] df_out.at[i, col] = arr return df_out def _build_training_samples( self, - mode: str = "sample", # "sample" or "group" + mode: str = "sample", group_vars: Optional[list[str]] = None, - use_quantile_scale: bool = False, # if True: use q01/q99 instead of min/max for zlow/zhigh - q_low: float = 0.01, - q_high: float = 0.99, + use_quantile_scale: bool = False, + 
q_low: float = 0.02, + q_high: float = 0.98, ) -> list: """ - Build training samples for the normalizer. - - Returns a list of tuples: - (context_vars_dict, dynamic_ctx_dict, mu_array, std_array, zlow_array, zhigh_array) - - - mode="sample": one tuple per row - - mode="group": one tuple per group (grouped by group_vars) - - Notes: - - group mode is only well-defined for *static* variables. For dynamic context, - this function will raise unless you explicitly exclude them from grouping. - - continuous vars: if you keep them continuous (float), grouping by them is usually pointless. - In group mode, we therefore ignore continuous vars by default unless you explicitly put them in group_vars. + Build training samples. + Note: Quantile scaling (q_low/q_high) only affects z_min/z_max calculation, + not the mean/std targets. """ assert mode in {"sample", "group"}, f"mode must be 'sample' or 'group', got {mode}" df = self.dataset.data.copy() - # Identify context types continuous_vars = set(getattr(self.dataset_cfg, "continuous_context_vars", None) or []) - dynamic_vars = set(self.dynamic_context_vars) # time_series context vars - static_vars = [v for v in self.static_context_vars] # categorical + continuous + dynamic_vars = set(self.dynamic_context_vars) + static_vars = [v for v in self.static_context_vars] - # Default grouping vars: static categorical only (exclude continuous + dynamic) if group_vars is None: group_vars = [v for v in static_vars if (v not in continuous_vars and v not in dynamic_vars)] - # Sanity: grouping by dynamic vars is almost always wrong (huge keys, high cardinality) bad = [v for v in group_vars if v in dynamic_vars] if bad: - raise ValueError( - f"group_vars contains dynamic(time_series) vars {bad}. " - f"Remove them or use mode='sample'." 
- ) + raise ValueError(f"group_vars contains dynamic vars {bad}") - # Helper: compute stats from a single row's time series columns def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]: dim_points = [] for d, col_name in enumerate(self.time_series_cols): @@ -990,59 +700,53 @@ def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optio if mode == "sample": for _, row in df.iterrows(): context_vars_dict = {} - - # static vars for v in static_vars: - if v not in row: - continue + if v not in row: continue if v in continuous_vars: + # Ensure we store float for continuous, applying simple normalization if needed in dataset context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32) else: context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long) - # dynamic vars (store separately; TrainSet will tensorize them) dynamic_ctx_dict = {} for v in self.dynamic_context_vars: - if v not in row: - continue + if v not in row: continue ts_data = row[v] - if isinstance(ts_data, np.ndarray): - dynamic_ctx_dict[v] = ts_data - elif isinstance(ts_data, list): + if isinstance(ts_data, (np.ndarray, list)): dynamic_ctx_dict[v] = np.array(ts_data) else: - # scalar -> repeat L = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len dynamic_ctx_dict[v] = np.full(L, ts_data) mu, std, zlow, zhigh = _row_stats(row) - samples.append((context_vars_dict, dynamic_ctx_dict, mu, std, zlow, zhigh)) return samples # mode == "group" - # Build grouped stats by aggregating all points from rows in each group. - # Context dict comes from group key values. 
+ # Pre-process continuous vars for grouping by binning + for v in group_vars: + if v in continuous_vars: + n_bins = getattr(self, "numeric_context_bins", 5) + df[v] = pd.cut(df[v], bins=n_bins, labels=False, include_lowest=True) + grouped = df.groupby(group_vars, dropna=False) for group_key, gdf in grouped: - # group_key can be scalar or tuple depending on #group_vars if len(group_vars) == 1: group_key = (group_key,) context_vars_dict = {} for i, v in enumerate(group_vars): - # group vars should be categorical by default; cast to long. - # if user explicitly included a continuous var in group_vars, keep float. + val = group_key[i] if v in continuous_vars: - context_vars_dict[v] = torch.tensor(group_key[i], dtype=torch.float32) + if pd.isna(val): val = 0 + context_vars_dict[v] = torch.tensor(int(val), dtype=torch.long) else: - context_vars_dict[v] = torch.tensor(group_key[i], dtype=torch.long) + context_vars_dict[v] = torch.tensor(val, dtype=torch.long) - dynamic_ctx_dict = {} # undefined for grouping; keep empty + dynamic_ctx_dict = {} - # Aggregate raw points per dim dim_points = [[] for _ in range(self.time_series_dims)] for _, row in gdf.iterrows(): for d, col_name in enumerate(self.time_series_cols): @@ -1074,5 +778,4 @@ def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optio samples.append((context_vars_dict, dynamic_ctx_dict, mu, std, zlow, zhigh)) - return samples - + return samples \ No newline at end of file diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 06dba15..f1357d3 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -22,7 +22,7 @@ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) -DATASET_OVERRIDES = ["max_samples=10000"] +DATASET_OVERRIDES = ["max_samples=10000", "normalize=False"] PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" @@ -236,8 
+236,8 @@ def main() -> None: ) cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) - print("EVAL CONFIG:") - print(cfg) + # print("EVAL CONFIG:") + # print(cfg) # When loading from a local checkpoint, infer seq_len and time_series_dims from the # checkpoint so the model is built with the same architecture (avoids shape mismatch). From ae48d130f6872c9c75bc5132d94b68fc9ae778a3 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 17 Feb 2026 11:43:27 -0500 Subject: [PATCH 36/50] intermitent checkpointing, better result display --- cents/config/context/default.yaml | 2 +- cents/config/model/diffusion_ts.yaml | 4 ++-- cents/config/trainer/diffusion_ts.yaml | 4 ++-- cents/datasets/timeseries_dataset.py | 11 +++++------ cents/trainer.py | 13 +++++++++++-- scripts/eval_pretrained.py | 27 ++++++++++++++++++++++---- scripts/train.py | 2 +- 7 files changed, 45 insertions(+), 18 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index a147043..4a413ca 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -12,7 +12,7 @@ static_context: normalizer: stats_head_type: mlp # Stats head type (e.g., "mlp") # Future parameters can be added here: - n_layers: 3 + n_layers: 5 # hidden_dim: 512 # Dynamic context: context module used by the normalizer for time series context variables diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 4ed648f..dba50da 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -11,8 +11,8 @@ n_steps: 1000 sampling_timesteps: 1000 sampling_batch_size: 4096 loss_type: l1 #l2 -training_objective: x0 -loss_weighting: uniform +training_objective: v +loss_weighting: snr min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 
cfc1ec2..970321a 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -11,9 +11,9 @@ eval_after_training: False checkpoint: save_last: True # Save final model - save_top_k: 3 # Save top 3 best models + save_top_k: 0 # 0 = only periodic saves; use >0 to also save top-k by metric every_n_train_steps: null - every_n_epochs: 20 # Save every 500 epochs + every_n_epochs: 250 # Save a distinct checkpoint every 250 epochs (250, 500, 750, ...) lr_scheduler_params: factor: 0.5 diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 16bbd5a..912930a 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -121,6 +121,11 @@ def __init__( self._save_context_var_codes() + self.context_cfg = get_context_config() + self.dynamic_module_type = self.context_cfg.dynamic_context.type + self.static_module_type = self.context_cfg.static_context.type + self.stats_head_type = self.context_cfg.normalizer.stats_head_type + is_ddp_subprocess = self._is_ddp_subprocess() if self.normalize: self._init_normalizer() @@ -142,12 +147,6 @@ def __init__( print(f"[Main Process] Cached normalized data for subprocesses") self.data = self.merge_timeseries_columns(self.data) self.data = self.data.reset_index() - - self.context_cfg = get_context_config() - self.dynamic_module_type = self.context_cfg.dynamic_context.type - self.static_module_type = self.context_cfg.static_context.type - self.stats_head_type = self.context_cfg.normalizer.stats_head_type - # Check if we should skip heavy processing for DDP if is_ddp_subprocess and skip_heavy_processing: diff --git a/cents/trainer.py b/cents/trainer.py index 5fab43b..96ce5b2 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -228,12 +228,21 @@ def _instantiate_trainer(self) -> pl.Trainer: checkpoint_dir = getattr(self.cfg, "checkpoint_dir", None) or str(Path(self.cfg.run_dir) / "checkpoints") Path(checkpoint_dir).mkdir(parents=True, 
exist_ok=True) + base_name = "_".join(filename_parts) + # Include epoch in filename when saving every N epochs so each checkpoint is distinct + every_n = getattr(tc.checkpoint, "every_n_epochs", None) + if every_n is not None and every_n > 0: + filename = f"{base_name}_epoch={{epoch:04d}}" + else: + filename = base_name callbacks.append( ModelCheckpoint( dirpath=checkpoint_dir, - filename="_".join(filename_parts), + filename=filename, save_last=tc.checkpoint.save_last, - save_on_train_epoch_end=True, ### Perhaps excessive + save_on_train_epoch_end=True, + every_n_epochs=every_n if (every_n is not None and every_n > 0) else None, + save_top_k=getattr(tc.checkpoint, "save_top_k", 1), ) ) callbacks.append(EvalAfterTraining(self.cfg, self.dataset)) diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index f1357d3..9ee2e2f 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -2,6 +2,7 @@ import os from pathlib import Path from typing import Tuple +import json import torch import torch.nn.functional as F @@ -230,7 +231,7 @@ def main() -> None: cfg = OmegaConf.create({}) cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) - cfg.device = "cuda:0" + cfg.device = "cuda:1" cfg.model = OmegaConf.create( OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) ) @@ -307,10 +308,28 @@ def main() -> None: # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset - logging.info("Checkpoint loaded. 
Starting evaluation...") + print("\n" + "=" * 60) + print("EVALUATION RESULTS") + print("=" * 60 + "\n") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) - logging.info("Evaluation complete!") - print(results) + + print("\n📊 METRICS:") + print("-" * 60) + metrics = results.get("metrics", {}) + for key, value in metrics.items(): + if isinstance(value, dict): + print(f"\n{key}:") + for subkey, subval in value.items(): + print(f" {subkey}: {subval:.6f}" if isinstance(subval, (int, float)) else f" {subkey}: {subval}") + else: + print(f"{key}: {value:.6f}" if isinstance(value, (int, float)) else f"{key}: {value}") + + # Results are automatically saved if save_results=True + if args.save_dir: + with open(Path(args.save_dir) / "metrics.json", "w") as f: + json.dump(metrics, f, indent=4) + print(f"\n✅ Results saved to: {Path(args.save_dir) / 'metrics.json'}") + print("\n" + "=" * 60) if __name__ == "__main__": diff --git a/scripts/train.py b/scripts/train.py index 2f82cc4..06005b3 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -138,7 +138,7 @@ def main(args) -> None: parser.add_argument("--cr_loss_weight", type=float, default=0.1) parser.add_argument("--tc_loss_weight", type=float, default=0.1) parser.add_argument("--dataset", type=str, default="pecanstreet") - parser.add_argument("--epochs", type=int, default=5000) + parser.add_argument("--epochs", type=int, default=2500) parser.add_argument("--batch_size", type=int, default=None) parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights and Biases logging") From 3653211572f49a698d43ef41132328cbdaa712d0 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 18 Feb 2026 13:12:19 -0500 Subject: [PATCH 37/50] Fixed checkpointing, use ff loss, optional conditional guidance --- cents/config/model/diffusion_ts.yaml | 11 +- cents/config/trainer/normalizer.yaml | 2 +- cents/models/diffusion_ts.py | 242 +++++++++++++++++++++++++-- cents/trainer.py | 67 ++++----
scripts/eval_pretrained.py | 4 +- 5 files changed, 283 insertions(+), 43 deletions(-) diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index dba50da..8b57121 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -28,4 +28,13 @@ gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 use_ema_sampling: False -k_bins: 20 \ No newline at end of file +k_bins: 20 +# Reconstruction-guided sampling (Algorithms 1 & 2) +recon_guide_eta: 0.1 # gradient scale for guidance +recon_guide_gamma: 1.0 # trade-off L1 vs L2 (L1 + gamma*L2) +recon_guide_algorithm: none # none | alg1 | alg2 +recon_guide_K: 3 # inner steps per t for alg2 (int or list for K[t]) +# Optional: dual head for x̂_a / x̂_b (set to cond_len used in recon-guided sampling) +recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs rest of sequence +# Context embedding dropout (training only) for more robust recon-guided sampling +context_embed_dropout: 0 # 0 = disabled \ No newline at end of file diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index a77f62f..d5b3603 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,6 +1,6 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 2, +devices: 1, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 37447e2..25677b0 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -31,7 +31,9 @@ class Diffusion_TS(GenerativeModel): Uses a Transformer backbone to predict and denoise time series over discrete diffusion timesteps. Supports EMA smoothing and configurable - beta schedules. + beta schedules. 
Optional reconstruction-guided sampling (Algorithms 1 & 2) + via sample_reconstruction_guided(shape, context_vars, x_a, algorithm="alg1"|"alg2") + when conditional observed data x_a is provided. Training objective (config model.training_objective): x0, epsilon, or v. - x0: predict clean sample; loss = L1/L2(model_out, x_clean). @@ -67,14 +69,33 @@ def __init__(self, cfg: DictConfig): self.context_reconstruction_loss_weight = ( cfg.model.context_reconstruction_loss_weight ) + # Reconstruction-guided sampling (Algorithms 1 & 2) + self.recon_guide_eta = getattr(cfg.model, "recon_guide_eta", 0.1) + self.recon_guide_gamma = getattr(cfg.model, "recon_guide_gamma", 1.0) + self.recon_guide_algorithm = getattr(cfg.model, "recon_guide_algorithm", "none") + self.recon_guide_K = getattr(cfg.model, "recon_guide_K", 3) # Verify context modules are initialized (static, dynamic, or both) if not hasattr(self, 'static_context_module') and not hasattr(self, 'dynamic_context_module'): raise ValueError("At least one context module (static or dynamic) must be initialized") - # linear layer for denoised output (no longer includes embedding_dim) - self.fc = nn.Linear( - self.time_series_dims, self.time_series_dims - ) + # Context embedding dropout (training only) for robust reconstruction-guided sampling + context_embed_dropout = getattr(cfg.model, "context_embed_dropout", 0.0) + self.context_embed_dropout = nn.Dropout(p=context_embed_dropout) + + # Optional dual head for x̂_a / x̂_b: separate output heads for conditional vs rest of sequence + self.recon_cond_len = getattr(cfg.model, "recon_cond_len", None) + if self.recon_cond_len is not None: + cond_len = int(self.recon_cond_len) + assert 0 < cond_len < self.seq_len, "recon_cond_len must be in (0, seq_len)" + self.fc_a = nn.Linear(self.time_series_dims, self.time_series_dims) + self.fc_b = nn.Linear(self.time_series_dims, self.time_series_dims) + self.fc = None + else: + self.fc_a = None + self.fc_b = None + self.fc = nn.Linear( + 
self.time_series_dims, self.time_series_dims + ) # Transformer backbone (now uses AdaLN conditioning instead of input concatenation) self.model = Transformer( n_feat=self.time_series_dims, @@ -239,9 +260,24 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict elif len(embeddings) == 1: embedding = embeddings[0] else: - raise ValueError("No context variables provided") + raise ValueError("No context variables provided") + if self.training and self.context_embed_dropout.p > 0: + embedding = self.context_embed_dropout(embedding) return embedding, all_logits + def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: + """ + Map backbone output (trend+season) to x0 prediction. Uses single fc or dual fc_a/fc_b when recon_cond_len is set. + backbone: (B, L, time_series_dims). + """ + if self.fc is not None: + return self.fc(backbone) + cond_len = self.recon_cond_len + return torch.cat([ + self.fc_a(backbone[:, :cond_len]), + self.fc_b(backbone[:, cond_len:]), + ], dim=1) + def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor ) -> torch.Tensor: @@ -410,7 +446,7 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di ) # Pass embedding as cond parameter instead of concatenating to input trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding) - x_start_pred = self.fc((trend + season).contiguous()) + x_start_pred = self._decode_to_x0((trend + season).contiguous()) # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) if self.training_objective == "x0": loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") @@ -431,7 +467,31 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di rec_loss = ( self.loss_weight[t].view(-1, 1, 1) * loss_per_elem ).mean() - return rec_loss, cond_classification_logits + + fourier_loss = torch.tensor(0.0, device=self.device) + if 
self.use_ff: + # FFT is not generally supported in fp16 for non power-of-2 sizes on cuFFT. + # Run FFT in fp32 outside autocast. + with torch.autocast(device_type=x.device.type, enabled=False): + x1 = x_start_pred.transpose(1, 2).float() + x2 = x.transpose(1, 2).float() + + fft1 = torch.fft.fft(x1, norm="forward") + fft2 = torch.fft.fft(x2, norm="forward") + + # (optional) keep loss in fp32 for stability; no need to cast back to fp16 + fft1 = fft1.transpose(1, 2) + fft2 = fft2.transpose(1, 2) + + fourier_loss = ( + self.recon_loss_fn(fft1.real, fft2.real, reduction="none") + + self.recon_loss_fn(fft1.imag, fft2.imag, reduction="none") + ) + fourier_loss = ( + self.loss_weight[t].view(-1, 1, 1) * fourier_loss + ).mean() + + return rec_loss, cond_classification_logits, fourier_loss.mean() def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """ @@ -445,7 +505,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: total_loss: Scalar training loss. """ ts_batch, cond_batch = batch - rec_loss, cond_class_logits = self(ts_batch, cond_batch) + rec_loss, cond_class_logits, fourier_loss = self(ts_batch, cond_batch) cond_loss = 0.0 @@ -475,7 +535,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: ) total_loss = ( - rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term + rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term + fourier_loss * self.ff_weight ) # Check for NaN in total loss @@ -491,6 +551,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: "rec_loss": rec_loss.item(), "cond_loss": cond_loss.item(), "tc_loss": tc_term, + "fourier_loss": fourier_loss.item(), }, prog_bar=True, sync_dist=True, @@ -582,6 +643,17 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: # else: # raise ValueError("No EMA keys found in checkpoint") + def _predict_x0_from_xt_with_grad( + self, x_t: torch.Tensor, t: torch.Tensor, embedding: 
torch.Tensor + ) -> torch.Tensor: + """ + Predict x0 from x_t with gradients enabled (for reconstruction-guided sampling). + Returns x_start of shape (B, L, C). Call with x_t.requires_grad_(True). + """ + trend, season = self.model(x_t, t, padding_masks=None, cond=embedding) + x_start = self._decode_to_x0((trend + season).contiguous()) + return x_start + @torch.no_grad() def model_predictions( self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor @@ -594,10 +666,22 @@ def model_predictions( x_start: predicted clean sample tensor. """ trend, season = self.model(x, t, padding_masks=None, cond=embedding) - x_start = self.fc((trend + season).contiguous()) + x_start = self._decode_to_x0((trend + season).contiguous()) pred_noise = self.predict_noise_from_start(x, t, x_start) return pred_noise, x_start + @staticmethod + def _replace_conditional( + x_a: torch.Tensor, x_prev: torch.Tensor, cond_len: int + ) -> torch.Tensor: + """ + Replace the first cond_len time steps of x_prev with conditional data x_a. + x_a: (B, cond_len, C), x_prev: (B, L, C). Returns (B, L, C). + """ + out = x_prev.clone() + out[:, :cond_len] = x_a + return out + @torch.no_grad() def p_mean_variance( self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor @@ -624,6 +708,142 @@ def p_sample( noise = torch.randn_like(x) if t > 0 else 0 return pm + (0.5 * plv).exp() * noise + def _reconstruction_guided_step_alg1( + self, + x_t: torch.Tensor, + t: int, + embedding: torch.Tensor, + x_a: torch.Tensor, + cond_len: int, + eta: float, + gamma: float, + ) -> torch.Tensor: + """ + One step of Algorithm 1: predict x̂_0, compute L_1 + γ*L_2, then + x̃_0 = x̂_0 + η ∇_{x_t}(L_1 + γ*L_2); sample x_{t-1} ~ N(μ(x̃_0, x_t), Σ) and Replace(x_a, x_{t-1}). 
+ """ + bt = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long) + x_t = x_t.detach().requires_grad_(True) + + x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding) + x_hat_a = x_start[:, :cond_len] + L_1 = (x_a - x_hat_a).pow(2).mean() + + pm, pv, plv = self.q_posterior(x_start, x_t, bt) + noise = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() + L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() + + loss = L_1 + gamma * L_2 + loss.backward() + with torch.no_grad(): + # x̃_0 = x̂_0 + η ∇_{x_t}(L_1 + γ*L_2); gradient has same shape as x_t (B,L,C) = x̂_0 + x_tilde_0 = x_start.detach() + eta * x_t.grad + pm_final, pv_final, plv_final = self.q_posterior(x_tilde_0, x_t.detach(), bt) + noise_final = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + x_prev = pm_final + (0.5 * plv_final).exp() * noise_final + x_prev = self._replace_conditional(x_a, x_prev, cond_len) + return x_prev + + def _reconstruction_guided_step_alg2( + self, + x_t: torch.Tensor, + t: int, + embedding: torch.Tensor, + x_a: torch.Tensor, + cond_len: int, + eta: float, + gamma: float, + K: int, + ) -> torch.Tensor: + """ + One step of Algorithm 2: K inner gradient updates on x_t, then one final sample and Replace. 
+ """ + bt = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long) + embedding_detach = embedding.detach() + x_t = x_t.detach().clone() + + for _ in range(K): + x_t = x_t.requires_grad_(True) + x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach) + x_hat_a = x_start[:, :cond_len] + L_1 = (x_a - x_hat_a).pow(2).mean() + + pm, pv, plv = self.q_posterior(x_start, x_t, bt) + noise = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() + L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() + + loss = L_1 + gamma * L_2 + loss.backward() + with torch.no_grad(): + x_t = x_t + eta * x_t.grad + x_t = x_t.detach() + + with torch.no_grad(): + x_start_final = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach) + pm_final, pv_final, plv_final = self.q_posterior(x_start_final, x_t, bt) + noise_final = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + x_prev = pm_final + (0.5 * plv_final).exp() * noise_final + x_prev = self._replace_conditional(x_a, x_prev, cond_len) + return x_prev + + def sample_reconstruction_guided( + self, + shape: Tuple[int, int, int], + context_vars: dict, + x_a: torch.Tensor, + algorithm: str = "alg1", + ) -> torch.Tensor: + """ + Full reverse-pass sampling with reconstruction guidance (Algorithm 1 or 2). + + Args: + shape: (batch_size, seq_len, time_series_dims). + context_vars: context conditioning dict. + x_a: Conditional (observed) data, shape (B, cond_len, C). First cond_len + time steps to reconstruct; model output is split as x̂_0 = [x̂_a, x̂_b]. + algorithm: "alg1" (one gradient step per t) or "alg2" (K inner steps per t). + + Returns: + Generated samples (B, L, C) with first cond_len steps equal to x_a. 
+ + Architecture / config notes (optional improvements): + - x̂_0 split: Currently x̂_0 is the single model output (B,L,C); we split by time + so x̂_a = x̂_0[:, :cond_len], x̂_b = x̂_0[:, cond_len:]. Optionally use two output + heads (one for x̂_a, one for x̂_b) if you want different capacities. + - Context dropout: Consider adding dropout on the context embedding (or on + context encoder outputs) during training to improve robustness of + reconstruction-guided sampling at test time. + - Config: recon_guide_eta (gradient scale), recon_guide_gamma (L1 vs L2 trade-off), + recon_guide_K (int or list of length num_timesteps for per-t inner steps in alg2). + """ + cond_len = x_a.shape[1] + assert cond_len < shape[1], "x_a length must be < seq_len" + assert x_a.shape[0] == shape[0] and x_a.shape[2] == shape[2] + eta = self.recon_guide_eta + gamma = self.recon_guide_gamma + + x = torch.randn(shape, device=self.device) + embedding, _ = self._get_context_embedding(context_vars) + x_a = x_a.to(self.device) + + for t in reversed(range(self.num_timesteps)): + K_t = ( + (self.recon_guide_K[t] if t < len(self.recon_guide_K) else self.recon_guide_K[-1]) + if isinstance(self.recon_guide_K, (list, tuple)) + else self.recon_guide_K + ) + if algorithm == "alg1": + x = self._reconstruction_guided_step_alg1( + x, t, embedding, x_a, cond_len, eta, gamma + ) + else: + x = self._reconstruction_guided_step_alg2( + x, t, embedding, x_a, cond_len, eta, gamma, K_t + ) + return x + @torch.no_grad() def sample(self, shape: Tuple[int, int, int], context_vars: dict) -> torch.Tensor: """ diff --git a/cents/trainer.py b/cents/trainer.py index 96ce5b2..c5bfcd6 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -196,58 +196,68 @@ def _instantiate_trainer(self) -> pl.Trainer: """ Build a PyTorch Lightning Trainer with ModelCheckpoint and loggers. - - Returns: - Configured pl.Trainer instance.
+ Saves checkpoints every N epochs (if configured) and always keeps last.ckpt. """ tc = self.cfg.trainer callbacks = [] - # Build filename with optional context_module_type + + # ---- Build descriptive base filename ---- filename_parts = [ self.cfg.dataset.name, self.model_type, - f"dim{self.cfg.dataset.time_series_dims}" + f"dim{self.cfg.dataset.time_series_dims}", ] - - # Add context_module_type from context config + from cents.utils.utils import get_context_config context_cfg = get_context_config() - static_context_module_type = context_cfg.static_context.type - if static_context_module_type: - filename_parts.append(f"ctx{static_context_module_type}") - - dynamic_context_module_type = context_cfg.dynamic_context.type - if dynamic_context_module_type: - filename_parts.append(f"dyn{dynamic_context_module_type}") - - # Add stats_head_type from context config - stats_head_type = context_cfg.normalizer.stats_head_type - if stats_head_type: - filename_parts.append(f"stats{stats_head_type}") - - - checkpoint_dir = getattr(self.cfg, "checkpoint_dir", None) or str(Path(self.cfg.run_dir) / "checkpoints") - Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) + + if context_cfg.static_context.type: + filename_parts.append(f"ctx{context_cfg.static_context.type}") + + if context_cfg.dynamic_context.type: + filename_parts.append(f"dyn{context_cfg.dynamic_context.type}") + + if context_cfg.normalizer.stats_head_type: + filename_parts.append(f"stats{context_cfg.normalizer.stats_head_type}") + base_name = "_".join(filename_parts) - # Include epoch in filename when saving every N epochs so each checkpoint is distinct + + # ---- Checkpoint directory ---- + checkpoint_dir = (Path(self.cfg.run_dir) / "checkpoints") + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # ---- Periodic saving config ---- every_n = getattr(tc.checkpoint, "every_n_epochs", None) - if every_n is not None and every_n > 0: + + if every_n and every_n > 0: filename = f"{base_name}_epoch={{epoch:04d}}" + 
every_n_epochs = every_n + save_top_k = -1 # keep ALL periodic checkpoints else: filename = base_name + every_n_epochs = None + save_top_k = getattr(tc.checkpoint, "save_top_k", 1) + + print(f"Saving every {every_n_epochs} epochs (if configured)") + callbacks.append( ModelCheckpoint( dirpath=checkpoint_dir, filename=filename, - save_last=tc.checkpoint.save_last, + every_n_epochs=every_n_epochs, save_on_train_epoch_end=True, - every_n_epochs=every_n if (every_n is not None and every_n > 0) else None, - save_top_k=getattr(tc.checkpoint, "save_top_k", 1), + save_last=True, # always keep last.ckpt + save_top_k=save_top_k, + auto_insert_metric_name=False, ) ) + callbacks.append(EvalAfterTraining(self.cfg, self.dataset)) + if getattr(self.cfg, "run_dir", None): callbacks.append(LogLossToCsv(self.cfg.run_dir)) + + # ---- Logger ---- logger = False if getattr(self.cfg, "wandb", None) and self.cfg.wandb.enabled: logger = WandbLogger( @@ -271,6 +281,7 @@ def _instantiate_trainer(self) -> pl.Trainer: ) + class LogLossToCsv(Callback): """Append epoch loss values to runs//train_losses.csv.""" diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 9ee2e2f..bce61dd 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -23,7 +23,7 @@ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) -DATASET_OVERRIDES = ["max_samples=10000", "normalize=False"] +DATASET_OVERRIDES = ["normalize=False"] PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" @@ -231,7 +231,7 @@ def main() -> None: cfg = OmegaConf.create({}) cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) - cfg.device = "cuda:1" + cfg.device = "cuda:2" cfg.model = OmegaConf.create( OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) ) From e1ed0e8a6c1069590dc9cccceb9199192422904b Mon Sep 17 00:00:00 2001 From: 
Pieter Feenstra Date: Tue, 24 Feb 2026 14:46:14 -0500 Subject: [PATCH 38/50] New run tracking, eval in normalized and raw domain, stabilization for commm dataset --- cents/config/dataset/commercial.yaml | 7 +- cents/config/model/diffusion_ts.yaml | 7 +- cents/config/trainer/normalizer.yaml | 10 + cents/datasets/commercial.py | 4 + cents/datasets/timeseries_dataset.py | 19 +- cents/eval/eval.py | 35 ++- cents/models/diffusion_ts.py | 320 ++++++++++++++++++----- cents/models/model_utils.py | 62 ++++- cents/models/normalizer.py | 168 +++++++----- cents/utils/utils.py | 8 +- scripts/eval_pretrained.py | 365 ++++++++++++++++----------- scripts/train.py | 34 ++- 12 files changed, 743 insertions(+), 296 deletions(-) diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index 346549b..a8c2096 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -2,7 +2,7 @@ name: commercial geography: null user_group: all normalize: True -scale: True +scale: False use_learned_normalizer: True threshold: 8 seq_len: 24 @@ -13,7 +13,7 @@ max_samples: null # Limit dataset size (null = use all data) path: "./data/commercial/csv" time_series_columns: "energy_meter" data_columns: ["dataid","energy_meter","timestamp"] -metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearbuilt"] +metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearbuilt", "sub_primaryspaceusage"] numeric_context_bins: 5 reduce_cardinality: False normalizer_stats_mode: group @@ -25,4 +25,5 @@ context_vars: site_id: ["categorical", 19] primaryspaceusage: ["categorical", 16] sqft: ["categorical", null] - yearbuilt: ["categorical", null] \ No newline at end of file + yearbuilt: ["categorical", null] + sub_primaryspaceusage: ["categorical", 104] \ No newline at end of file diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 8b57121..67f9bfa 100644 --- 
a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -22,7 +22,7 @@ attn_pd: 0.0 resid_pd: 0.0 kernel_size: null padding_size: null -use_ff: True +use_ff: False reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.99 @@ -37,4 +37,7 @@ recon_guide_K: 3 # inner steps per t for alg2 (int or list for K[t]) # Optional: dual head for x̂_a / x̂_b (set to cond_len used in recon-guided sampling) recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs rest of sequence # Context embedding dropout (training only) for more robust recon-guided sampling -context_embed_dropout: 0 # 0 = disabled \ No newline at end of file +context_embed_dropout: 0 # 0 = disabled +blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise +# When true and time_series_dims > 1, noise is correlated across dimensions (same draw per timestep). +correlated_noise: True \ No newline at end of file diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index d5b3603..d5af861 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -11,6 +11,16 @@ save_cycle: 5000 eval_after_training: False loss_type: mse # Options: "mse" or "gaussian_nll" +# If true (default), targets are scaled by global mu/sigma stats and network predicts in scaled space (then unscaled in forward). +# If false, network predicts mu in asinh space and log(sigma) directly (no global preprocessing); can be more stable for some datasets (e.g. commercial). +use_global_stats_preprocessing: true +# When use_global_stats_preprocessing=false: mu is predicted as asinh(mu); clamp to [-max_asinh_mu, max_asinh_mu]. sinh(10) ~ 11013. 
+max_asinh_mu: 10 + +# Floor predicted sigma and scale range so (x - mu) / sigma and (z - zmin) / rng never explode +min_sigma: 0.1 # sigma_effective = max(pred_sigma, min_sigma) +min_scale_range: 0.25 # rng_effective = max(zmax - zmin, min_scale_range) + checkpoint: save_last: False save_top_k: 0 diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index 43b2d15..7bb7b66 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -137,6 +137,10 @@ def _preprocess_data(self, data): if merged[context_cols].isna().sum().sum() > 0: print(f"Warning: {merged[context_cols].isna().sum().sum()} NaN values remain after handling") + mask = merged["energy_meter"].apply(lambda x: x.std() < 0.01) + + merged = merged[~mask] + return merged def _handle_missing_data(self, merged): diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 912930a..254204c 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -11,7 +11,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from sklearn.cluster import KMeans from torch.utils.data import DataLoader, Dataset -from omegaconf import ListConfig +from omegaconf import ListConfig, OmegaConf import pickle from cents.datasets.utils import encode_context_variables @@ -673,20 +673,27 @@ def _init_normalizer(self) -> None: Path.home() / ".cache" / "cents" / "checkpoints" / self.name / "normalizer" ) normalizer_dir.mkdir(parents=True, exist_ok=True) - - # Get context_module_type and stats_head_type from context config + + ncfg = get_normalizer_training_config() + if hasattr(self.cfg, "normalizer_use_global_stats_preprocessing"): + ncfg = OmegaConf.merge( + ncfg, + OmegaConf.create({"use_global_stats_preprocessing": self.cfg.normalizer_use_global_stats_preprocessing}), + ) + use_global = ncfg.get("use_global_stats_preprocessing", True) + cache_path = normalizer_dir / _ckpt_name( - self.name, - "normalizer", + self.name, + 
"normalizer", self.time_series_dims, static_module_type=self.static_module_type, stats_head_type=self.stats_head_type, dynamic_module_type=self.dynamic_module_type, + use_global_stats_preprocessing=use_global, ) print(f"[Cents] cache_path: {cache_path}") - ncfg = get_normalizer_training_config() self._normalizer = Normalizer( dataset_cfg=self.cfg, normalizer_training_cfg=ncfg, diff --git a/cents/eval/eval.py b/cents/eval/eval.py index 9e6d6c0..14c0e26 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -187,19 +187,23 @@ def compute_quality_metrics( syn_data: np.ndarray, real_data_frame: pd.DataFrame, mask: Optional[np.ndarray] = None, + target: Optional[Dict] = None, + log_prefix: str = "", ) -> Dict: """ - Compute evaluation metrics and store them in current_results. + Compute evaluation metrics and store them in current_results (or in target if provided). Args: real_data (np.ndarray): Real data array (shape: [N, seq_len, dims]) syn_data (np.ndarray): Synthetic data array (shape: [N, seq_len, dims]) real_data_frame (pd.DataFrame): Real data subset (inverse-transformed) mask (Optional[np.ndarray]): Boolean array indicating which rows are "rare" + target (Optional[Dict]): If set, write metrics into this dict instead of current_results (for normalized_domain). + log_prefix (str): Prefix for log messages (e.g. "[normalized]"). 
""" - logger.info(f"[Cents] --- Starting Full-Subset Metrics ---") + logger.info(f"[Cents] --- {log_prefix}Full-Subset Metrics ---") - metrics = {} + metrics = target if target is not None else {} # Compute and store metrics dtw_mean, dtw_std = dynamic_time_warping_dist(real_data, syn_data) @@ -269,7 +273,9 @@ def compute_quality_metrics( logger.info("[Cents] Done computing Rare-Subset Metrics.") metrics["rare_subset"] = rare_metrics - self.current_results["metrics"] = metrics + if target is None: + self.current_results["metrics"] = metrics + return metrics def compute_disentanglement_metrics( self, @@ -382,9 +388,28 @@ def evaluate_subset( ): rare_mask = real_data_subset["is_rare"].values + # Metrics in raw (un-normalized) domain self.compute_quality_metrics( - real_data_array, syn_data_array, real_data_inv, rare_mask + real_data_array, syn_data_array, real_data_inv, rare_mask, + log_prefix="[raw] ", ) + # Metrics in normalized (z) domain for cross-domain comparability (only when dataset is normalized) + if ( + getattr(dataset, "normalize", False) + and getattr(dataset, "_normalizer", None) is not None + and "timeseries" in real_data_subset.columns + ): + real_data_norm = np.stack(real_data_subset["timeseries"].values) + syn_data_norm = generated_ts + logger.info("[Cents] Computing metrics in normalized domain (z-space) for cross-domain comparison.") + normalized_metrics = {} + self.compute_quality_metrics( + real_data_norm, syn_data_norm, real_data_inv, rare_mask, + target=normalized_metrics, + log_prefix="[normalized] ", + ) + self.current_results["metrics"]["normalized_domain"] = normalized_metrics + if self.cfg.evaluator.eval_disentanglement: self.compute_disentanglement_metrics(context_vars, model) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 25677b0..79ef313 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -13,6 +13,24 @@ from contextlib import contextmanager +def _nan_check(t: 
Optional[torch.Tensor], name: str, extra: str = "") -> None: + """Print location and stats when tensor contains NaN or Inf (for debugging).""" + if t is None or not isinstance(t, torch.Tensor): + return + if not (torch.isnan(t).any() or torch.isinf(t).any()): + return + nan_c = torch.isnan(t).sum().item() + inf_c = torch.isinf(t).sum().item() + finite = t[~(torch.isnan(t) | torch.isinf(t))] + min_s = finite.min().item() if finite.numel() > 0 else float("nan") + max_s = finite.max().item() if finite.numel() > 0 else float("nan") + mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") + print( + f"[NaN/Inf] {name}: shape={tuple(t.shape)}, nan_count={nan_c}, inf_count={inf_c}, " + f"finite_min={min_s:.6g}, finite_max={max_s:.6g}, finite_mean={mean_s:.6g} {extra}".strip() + ) + + from cents.models.base import GenerativeModel from cents.models.model_utils import ( Transformer, @@ -23,6 +41,68 @@ ) from cents.models.registry import register_model +def _randn_like_correlated( + x: torch.Tensor, correlated: bool +) -> torch.Tensor: + """White noise with same shape as x. If correlated and C>1, same noise broadcast across last dim.""" + if not correlated or x.dim() < 3 or x.shape[-1] == 1: + return torch.randn_like(x) + return torch.randn(*x.shape[:-1], 1, device=x.device, dtype=x.dtype).expand_as(x).clone() + + +def _randn_shape_correlated( + shape: tuple, device: torch.device, dtype: torch.dtype, correlated: bool +) -> torch.Tensor: + """Randn with given shape. 
If correlated and shape[-1]>1, same noise broadcast across last dim.""" + if not correlated or len(shape) < 3 or shape[-1] == 1: + return torch.randn(shape, device=device, dtype=dtype) + B, L, C = shape[0], shape[1], shape[2] + return torch.randn(B, L, 1, device=device, dtype=dtype).expand(shape).clone() + + +def blueish_noise_like( + x: torch.Tensor, power: float = 1.0, eps: float = 1e-6, correlated: bool = False +) -> torch.Tensor: + """ + Generate 'blue-ish' noise: more energy at higher frequencies. + - power = 0.0 -> white noise + - power > 0.0 -> increasingly high-frequency-heavy (blue/violet-ish) + Returns noise with ~unit std per sample/channel so diffusion scaling stays consistent. + - correlated: if True and x has multiple channels (C>1), same noise is used for all channels. + + x: (B, L, C) where L is time dimension. + """ + B, L, C = x.shape + + if power == 0.0: + return _randn_like_correlated(x, correlated) + + # When correlated and C>1, generate (B, L, 1) then expand after FFT shaping + if correlated and C > 1: + n = torch.randn(B, L, 1, device=x.device, dtype=torch.float32) + else: + n = torch.randn(B, L, C, device=x.device, dtype=torch.float32) + + # real FFT over time + N = torch.fft.rfft(n, dim=1) # (B, F, C) or (B, F, 1) + freqs = torch.fft.rfftfreq(L, d=1.0).to(x.device) # (F,) + + # Amplitude shaping: + amp = (freqs.clamp_min(eps) ** (power / 2.0)).view(1, -1, 1) + N = N * amp + + n_blue = torch.fft.irfft(N, n=L, dim=1) # (B, L, C) or (B, L, 1) + + # Re-normalize per (B,C) to unit std across time + n_blue = n_blue / n_blue.std(dim=1, keepdim=True).clamp_min(1e-6) + + if correlated and C > 1: + n_blue = n_blue.expand(B, L, C).clone() + + out = n_blue.to(dtype=x.dtype) + _nan_check(out, "blueish_noise_like output") + return out + @register_model("diffusion_ts", "Watts_2_1D", "Watts_2_2D") class Diffusion_TS(GenerativeModel): @@ -112,6 +192,9 @@ def __init__(self, cfg: DictConfig): cond_dim=self.embedding_dim, ) + self.blue_noise_power = 
cfg.model.blue_noise_power + self.correlated_noise = bool(getattr(cfg.model, "correlated_noise", False)) + # EMA helper will be initialized on train start self._ema: Optional[EMA] = None @@ -183,6 +266,9 @@ def __init__(self, cfg: DictConfig): gamma=min_snr_gamma, ) self.register_buffer("loss_weight", lw) + _nan_check(self.loss_weight, "init loss_weight") + _nan_check(self.betas, "init betas") + _nan_check(self.sqrt_alphas_cumprod, "init sqrt_alphas_cumprod") # choose reconstruction loss if self.loss_type == "l1": @@ -211,7 +297,10 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict """ embeddings = [] all_logits = {} - + for k, v in context_vars.items(): + if isinstance(v, torch.Tensor): + _nan_check(v, f"_get_context_embedding context_vars[{k}]") + # Process static context variables if self.static_context_module is not None: # Filter static context variables @@ -225,7 +314,31 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v for k, v in static_vars.items() } + # Debug: print which static input has NaN (only when we see one) + for k, v in static_vars.items(): + if isinstance(v, torch.Tensor) and (torch.isnan(v).any() or torch.isinf(v).any()): + nan_c = torch.isnan(v).sum().item() + inf_c = torch.isinf(v).sum().item() + finite = v[~(torch.isnan(v) | torch.isinf(v))] + min_s = finite.min().item() if finite.numel() > 0 else float("nan") + max_s = finite.max().item() if finite.numel() > 0 else float("nan") + mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") + print( + f"[NaN/Inf] static_var '{k}': shape={tuple(v.shape)}, dtype={v.dtype}, " + f"nan_count={nan_c}, inf_count={inf_c}, finite_min={min_s:.6g}, finite_max={max_s:.6g}, finite_mean={mean_s:.6g}" + ) + # Replace NaN/Inf in static inputs so the context module does not produce NaN embeddings + # def _sanitize(t: torch.Tensor) -> torch.Tensor: + # 
if not isinstance(t, torch.Tensor): + # return t + # if not (torch.isnan(t).any() or torch.isinf(t).any()): + # return t + # if t.is_floating_point(): + # return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0) + # return t + # static_vars = {k: _sanitize(v) for k, v in static_vars.items()} static_embedding, static_logits = self.static_context_module(static_vars) + _nan_check(static_embedding, "_get_context_embedding static_embedding") embeddings.append(static_embedding) all_logits.update(static_logits) @@ -244,25 +357,23 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict for k, v in dynamic_vars.items() } dynamic_embedding, dynamic_logits = self.dynamic_context_module(dynamic_vars) - # Check for NaN in dynamic embedding - # if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): - # raise ValueError( - # f"NaN/Inf detected in dynamic embedding. " - # f"Dynamic vars: {list(dynamic_vars.keys())}" - # ) + _nan_check(dynamic_embedding, "_get_context_embedding dynamic_embedding") embeddings.append(dynamic_embedding) all_logits.update(dynamic_logits) # Combine embeddings if both exist if len(embeddings) == 2: combined = torch.cat(embeddings, dim=1) + _nan_check(combined, "_get_context_embedding combined") embedding = self.combine_mlp(combined) elif len(embeddings) == 1: embedding = embeddings[0] else: raise ValueError("No context variables provided") + _nan_check(embedding, "_get_context_embedding embedding (before dropout)") if self.training and self.context_embed_dropout.p > 0: embedding = self.context_embed_dropout(embedding) + _nan_check(embedding, "_get_context_embedding embedding (final)") return embedding, all_logits def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: @@ -270,13 +381,17 @@ def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: Map backbone output (trend+season) to x0 prediction. Uses single fc or dual fc_a/fc_b when recon_cond_len is set. 
backbone: (B, L, time_series_dims). """ + _nan_check(backbone, "_decode_to_x0 backbone") if self.fc is not None: - return self.fc(backbone) - cond_len = self.recon_cond_len - return torch.cat([ - self.fc_a(backbone[:, :cond_len]), - self.fc_b(backbone[:, cond_len:]), - ], dim=1) + out = self.fc(backbone) + else: + cond_len = self.recon_cond_len + out = torch.cat([ + self.fc_a(backbone[:, :cond_len]), + self.fc_b(backbone[:, cond_len:]), + ], dim=1) + _nan_check(out, "_decode_to_x0 output") + return out def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor @@ -292,9 +407,11 @@ def predict_noise_from_start( Returns: Noise prediction tensor same shape as x_t. """ - return ( + out = ( self.sqrt_recip_alphas_cumprod[t].view(-1, 1, 1) * x_t - x0 ) / self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) + _nan_check(out, "predict_noise_from_start output") + return out def predict_start_from_noise( self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor @@ -310,10 +427,12 @@ def predict_start_from_noise( Returns: Reconstructed x0 tensor same shape as x_t. """ - return ( + out = ( self.sqrt_recip_alphas_cumprod[t].view(-1, 1, 1) * x_t - self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) * noise ) + _nan_check(out, "predict_start_from_noise output") + return out def predict_start_from_v( self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor @@ -322,10 +441,12 @@ def predict_start_from_v( Reconstruct x0 from x_t and v-parameterization. v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x0 => x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v """ - return ( + out = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x_t - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * v ) + _nan_check(out, "predict_start_from_v output") + return out def predict_noise_from_v( self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor @@ -334,10 +455,12 @@ def predict_noise_from_v( Reconstruct epsilon from x_t and v-parameterization. 
v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x0 => epsilon = sqrt(1 - alpha_bar_t) * x_t + sqrt(alpha_bar_t) * v """ - return ( + out = ( self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_t + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * v ) + _nan_check(out, "predict_noise_from_v output") + return out def compute_snr_weights( @@ -411,6 +534,9 @@ def q_posterior( ) pv = self.posterior_variance[t].view(-1, 1, 1) plv = self.posterior_log_variance_clipped[t].view(-1, 1, 1) + _nan_check(pm, "q_posterior pm") + _nan_check(pv, "q_posterior pv") + _nan_check(plv, "q_posterior plv") return pm, pv, plv def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, dict]: @@ -425,36 +551,44 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di rec_loss: Reconstruction loss tensor. cond_logits: Classification logits dict from context module. """ - # Check input x for extreme values - # if x.abs().max() > 100.0: - # print(f"[Warning] Input x has extreme values: min={x.min():.4f}, max={x.max():.4f}, " - # f"mean={x.mean():.4f}, std={x.std():.4f}, shape={x.shape}") - # if torch.isnan(x).any() or torch.isinf(x).any(): - # raise ValueError(f"NaN/Inf detected in input x. 
Shape: {x.shape}, " - # f"NaN count: {torch.isnan(x).sum()}, Inf count: {torch.isinf(x).sum()}") - + _nan_check(x, "forward input x") + # Log when x is in reasonable range but we still see NaN later (helps distinguish bad input vs numerical instability) + # if isinstance(x, torch.Tensor): + # x_abs_max = x.abs().max().item() + # if x_abs_max > 50.0: + # print( + # f"[forward] input x has large values: min={x.min().item():.6g}, max={x.max().item():.6g}, abs_max={x_abs_max:.6g}" + # ) + b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - # t = self.stratified_timesteps(b, self.num_timesteps, self.cfg.model.k_bins, device=self.device) embedding, cond_classification_logits = self._get_context_embedding(context_vars) - - # Check diffusion schedule parameters - noise = torch.randn_like(x) + _nan_check(embedding, "forward embedding") + + noise = blueish_noise_like( + x, power=self.blue_noise_power, correlated=self.correlated_noise + ) + _nan_check(noise, "forward noise") x_noisy = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - # Pass embedding as cond parameter instead of concatenating to input + _nan_check(x_noisy, "forward x_noisy") trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding) + _nan_check(trend, "forward trend") + _nan_check(season, "forward season") x_start_pred = self._decode_to_x0((trend + season).contiguous()) + _nan_check(x_start_pred, "forward x_start_pred") # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) if self.training_objective == "x0": loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") elif self.training_objective == "eps": pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) + _nan_check(pred_noise, "forward pred_noise (eps)") loss_per_elem = self.recon_loss_fn(pred_noise, noise, reduction="none") else: # v pred_noise = 
self.predict_noise_from_start(x_noisy, t, x_start_pred) + _nan_check(pred_noise, "forward pred_noise (v)") pred_v = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * pred_noise - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_start_pred @@ -463,10 +597,19 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * noise - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x ) + _nan_check(pred_v, "forward pred_v") + _nan_check(true_v, "forward true_v") loss_per_elem = self.recon_loss_fn(pred_v, true_v, reduction="none") + _nan_check(loss_per_elem, "forward loss_per_elem") rec_loss = ( self.loss_weight[t].view(-1, 1, 1) * loss_per_elem ).mean() + _nan_check(rec_loss, "forward rec_loss") + # When loss is NaN but input x was in reasonable range, point to numerical instability downstream + if (torch.isnan(rec_loss) | torch.isinf(rec_loss)).any(): + print( + f"[forward] rec_loss is NaN/Inf while input x had min={x.min().item():.6g}, max={x.max().item():.6g}, abs_max={x.abs().max().item():.6g}" + ) fourier_loss = torch.tensor(0.0, device=self.device) if self.use_ff: @@ -479,17 +622,28 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di fft1 = torch.fft.fft(x1, norm="forward") fft2 = torch.fft.fft(x2, norm="forward") - # (optional) keep loss in fp32 for stability; no need to cast back to fp16 - fft1 = fft1.transpose(1, 2) - fft2 = fft2.transpose(1, 2) + mag1 = torch.abs(fft1) + mag2 = torch.abs(fft2) + _nan_check(mag1, "forward fourier mag1") + _nan_check(mag2, "forward fourier mag2") fourier_loss = ( - self.recon_loss_fn(fft1.real, fft2.real, reduction="none") - + self.recon_loss_fn(fft1.imag, fft2.imag, reduction="none") + self.recon_loss_fn(mag1, mag2, reduction="none") ) + _nan_check(fourier_loss, "forward fourier_loss (per-elem)") fourier_loss = ( self.loss_weight[t].view(-1, 1, 1) * fourier_loss ).mean() + _nan_check(fourier_loss, "forward fourier_loss 
(scalar)") + + + # fourier_loss = ( + # self.recon_loss_fn(fft1.real, fft2.real, reduction="none") + # + self.recon_loss_fn(fft1.imag, fft2.imag, reduction="none") + # ) + # fourier_loss = ( + # self.loss_weight[t].view(-1, 1, 1) * fourier_loss + # ).mean() return rec_loss, cond_classification_logits, fourier_loss.mean() @@ -505,18 +659,23 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: total_loss: Scalar training loss. """ ts_batch, cond_batch = batch + _nan_check(ts_batch, "training_step ts_batch") rec_loss, cond_class_logits, fourier_loss = self(ts_batch, cond_batch) - - cond_loss = 0.0 - + _nan_check(rec_loss, "training_step rec_loss") + _nan_check(fourier_loss, "training_step fourier_loss") + cond_loss = 0.0 for var_name, outputs in cond_class_logits.items(): labels = cond_batch[var_name] + if isinstance(outputs, torch.Tensor): + _nan_check(outputs, f"training_step cond_logits[{var_name}]") + if isinstance(labels, torch.Tensor): + _nan_check(labels, f"training_step cond_labels[{var_name}]") if var_name in self.continuous_context_vars: loss = F.mse_loss(outputs, labels.float()) elif var_name in self.categorical_context_vars: loss = self.auxiliary_loss(outputs, labels) - + _nan_check(loss, f"training_step cond_loss[{var_name}]") cond_loss += loss.mean() # # if var_name in self.continuous_context_vars: @@ -528,23 +687,19 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # cond_loss /= len(cond_class_logits) h, _ = self._get_context_embedding(cond_batch) + _nan_check(h, "training_step h (for tc)") tc_term = ( self.cfg.model.tc_loss_weight * total_correlation(h) if self.cfg.model.tc_loss_weight > 0.0 else torch.tensor(0.0, device=self.device) ) + _nan_check(tc_term, "training_step tc_term") total_loss = ( rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term + fourier_loss * self.ff_weight ) - - # Check for NaN in total loss - # if torch.isnan(total_loss) or torch.isinf(total_loss): - # raise 
ValueError( - # f"NaN/Inf detected in total_loss at batch {batch_idx}. " - # f"rec_loss: {rec_loss.item():.6f}, cond_loss: {cond_loss:.6f}, tc_term: {tc_term.item():.6f}" - # ) - + _nan_check(total_loss, f"training_step total_loss batch_idx={batch_idx}") + self.log_dict( { "train_loss": total_loss.item(), @@ -652,6 +807,7 @@ def _predict_x0_from_xt_with_grad( """ trend, season = self.model(x_t, t, padding_masks=None, cond=embedding) x_start = self._decode_to_x0((trend + season).contiguous()) + _nan_check(x_start, "_predict_x0_from_xt_with_grad x_start") return x_start @torch.no_grad() @@ -668,6 +824,8 @@ def model_predictions( trend, season = self.model(x, t, padding_masks=None, cond=embedding) x_start = self._decode_to_x0((trend + season).contiguous()) pred_noise = self.predict_noise_from_start(x, t, x_start) + _nan_check(x_start, "model_predictions x_start") + _nan_check(pred_noise, "model_predictions pred_noise") return pred_noise, x_start @staticmethod @@ -694,6 +852,7 @@ def p_mean_variance( """ pred_noise, x_start = self.model_predictions(x, t, embedding) pm, pv, plv = self.q_posterior(x_start, x, t) + _nan_check(x_start, "p_mean_variance x_start") return pm, pv, plv, x_start @torch.no_grad() @@ -705,8 +864,14 @@ def p_sample( """ bt = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long) pm, pv, plv, _ = self.p_mean_variance(x, bt, embedding) - noise = torch.randn_like(x) if t > 0 else 0 - return pm + (0.5 * plv).exp() * noise + noise = ( + blueish_noise_like(x, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) + out = pm + (0.5 * plv).exp() * noise + _nan_check(out, "p_sample output") + return out def _reconstruction_guided_step_alg1( self, @@ -726,23 +891,37 @@ def _reconstruction_guided_step_alg1( x_t = x_t.detach().requires_grad_(True) x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding) + _nan_check(x_start, "_reconstruction_guided_step_alg1 x_start") x_hat_a = x_start[:, :cond_len] L_1 = 
(x_a - x_hat_a).pow(2).mean() + _nan_check(L_1, "_reconstruction_guided_step_alg1 L_1") pm, pv, plv = self.q_posterior(x_start, x_t, bt) - noise = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + noise = ( + blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() + _nan_check(L_2, "_reconstruction_guided_step_alg1 L_2") loss = L_1 + gamma * L_2 + _nan_check(loss, "_reconstruction_guided_step_alg1 loss") loss.backward() with torch.no_grad(): - # x̃_0 = x̂_0 + η ∇_{x_t}(L_1 + γ*L_2); gradient has same shape as x_t (B,L,C) = x̂_0 x_tilde_0 = x_start.detach() + eta * x_t.grad + _nan_check(x_t.grad, "_reconstruction_guided_step_alg1 x_t.grad") + _nan_check(x_tilde_0, "_reconstruction_guided_step_alg1 x_tilde_0") pm_final, pv_final, plv_final = self.q_posterior(x_tilde_0, x_t.detach(), bt) - noise_final = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + noise_final = ( + _randn_like_correlated(x_t, self.correlated_noise) + if t > 0 + else torch.zeros_like(x_t, device=x_t.device) + ) x_prev = pm_final + (0.5 * plv_final).exp() * noise_final x_prev = self._replace_conditional(x_a, x_prev, cond_len) + _nan_check(x_prev, "_reconstruction_guided_step_alg1 x_prev") return x_prev def _reconstruction_guided_step_alg2( @@ -766,26 +945,37 @@ def _reconstruction_guided_step_alg2( for _ in range(K): x_t = x_t.requires_grad_(True) x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach) + _nan_check(x_start, "_reconstruction_guided_step_alg2 x_start (inner)") x_hat_a = x_start[:, :cond_len] L_1 = (x_a - x_hat_a).pow(2).mean() - pm, pv, plv = self.q_posterior(x_start, x_t, bt) - noise = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + noise = ( + 
blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() - loss = L_1 + gamma * L_2 + _nan_check(loss, "_reconstruction_guided_step_alg2 loss (inner)") loss.backward() with torch.no_grad(): + _nan_check(x_t.grad, "_reconstruction_guided_step_alg2 x_t.grad") x_t = x_t + eta * x_t.grad x_t = x_t.detach() with torch.no_grad(): x_start_final = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach) + _nan_check(x_start_final, "_reconstruction_guided_step_alg2 x_start_final") pm_final, pv_final, plv_final = self.q_posterior(x_start_final, x_t, bt) - noise_final = torch.randn_like(x_t, device=x_t.device) if t > 0 else torch.zeros_like(x_t, device=x_t.device) + noise_final = ( + blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) x_prev = pm_final + (0.5 * plv_final).exp() * noise_final x_prev = self._replace_conditional(x_a, x_prev, cond_len) + _nan_check(x_prev, "_reconstruction_guided_step_alg2 x_prev") return x_prev def sample_reconstruction_guided( @@ -824,7 +1014,9 @@ def sample_reconstruction_guided( eta = self.recon_guide_eta gamma = self.recon_guide_gamma - x = torch.randn(shape, device=self.device) + x = _randn_shape_correlated( + shape, self.device, torch.get_default_dtype(), self.correlated_noise + ) embedding, _ = self._get_context_embedding(context_vars) x_a = x_a.to(self.device) @@ -856,10 +1048,13 @@ def sample(self, shape: Tuple[int, int, int], context_vars: dict) -> torch.Tenso Returns: Generated samples tensor. 
""" - x = torch.randn(shape, device=self.device) + x = _randn_shape_correlated( + shape, self.device, torch.get_default_dtype(), self.correlated_noise + ) embedding, _ = self._get_context_embedding(context_vars) for t in reversed(range(self.num_timesteps)): x = self.p_sample(x, t, embedding) + _nan_check(x, "sample() output") return x @torch.no_grad() @@ -869,7 +1064,9 @@ def fast_sample( """ Faster sampling using a reduced number of timesteps. """ - x = torch.randn(shape, device=self.device) + x = _randn_shape_correlated( + shape, self.device, torch.get_default_dtype(), self.correlated_noise + ) embedding, _ = self._get_context_embedding(context_vars) times = torch.linspace( -1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1 @@ -881,6 +1078,7 @@ def fast_sample( pred_noise, x_start = self.model_predictions(x, bt, embedding) if time_next < 0: x = x_start + _nan_check(x, "fast_sample x (final step)") continue alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] @@ -889,8 +1087,10 @@ def fast_sample( * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() ) c = (1 - alpha_next - sigma**2).sqrt() - noise = torch.randn_like(x) + noise = _randn_like_correlated(x, self.correlated_noise) x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise + _nan_check(x, "fast_sample x (mid)") + _nan_check(x, "fast_sample x (final)") return x @contextmanager diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index d3fe320..bbd015e 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -9,6 +9,7 @@ """ import math +from typing import Optional import numpy as np import torch @@ -17,6 +18,24 @@ from torch import nn +def _nan_check(t: Optional[torch.Tensor], name: str, extra: str = "") -> None: + """Print location and stats when tensor contains NaN or Inf (for debugging).""" + if t is None or not isinstance(t, torch.Tensor): + return + if not (torch.isnan(t).any() or torch.isinf(t).any()): 
+ return + nan_c = torch.isnan(t).sum().item() + inf_c = torch.isinf(t).sum().item() + finite = t[~(torch.isnan(t) | torch.isinf(t))] + min_s = finite.min().item() if finite.numel() > 0 else float("nan") + max_s = finite.max().item() if finite.numel() > 0 else float("nan") + mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") + print( + f"[NaN/Inf] Transformer {name}: shape={tuple(t.shape)}, nan_count={nan_c}, inf_count={inf_c}, " + f"finite_min={min_s:.6g}, finite_max={max_s:.6g}, finite_mean={mean_s:.6g} {extra}".strip() + ) + + def linear_beta_schedule(timesteps: int) -> torch.Tensor: """ Create a linear schedule of betas for diffusion noise levels. @@ -219,10 +238,24 @@ def __init__(self, in_dim, out_dim, resid_pdrop=0.0): def forward(self, x): # x: (B, T, C) + _nan_check(x, "Conv_MLP forward x (initial)") + # Print when values are extreme (even if not NaN) to debug downstream NaN + # if isinstance(x, torch.Tensor): + # x_abs_max = x.abs().max().item() + # if x_abs_max > 50.0 or torch.isnan(x).any() or torch.isinf(x).any(): + # print( + # f"[Conv_MLP] x (initial): shape={tuple(x.shape)}, min={x.min().item():.6g}, max={x.max().item():.6g}, " + # f"abs_max={x_abs_max:.6g}, has_nan={torch.isnan(x).any().item()}, has_inf={torch.isinf(x).any().item()}" + # ) x = x.transpose(1, 2).contiguous() # (B, C, T) contiguous + _nan_check(x, "Conv_MLP forward x (transposed)") x = self.conv(x) + _nan_check(x, "Conv_MLP forward x (conv)") x = self.drop(x) - return x.transpose(1, 2).contiguous() # back to (B, T, C), contiguous + _nan_check(x, "Conv_MLP forward x (drop)") + out = x.transpose(1, 2).contiguous() # back to (B, T, C), contiguous + _nan_check(out, "Conv_MLP forward out (transposed)") + return out @@ -842,36 +875,55 @@ def __init__( def forward(self, input, t, padding_masks=None, return_res=False, cond=None): # cond: (B, cond_dim) or None + _nan_check(input, "forward input") t_emb = self.time_emb(t) + _nan_check(t_emb, "forward t_emb") 
label_emb = None if (cond is not None) and (self.cond_proj is not None): - label_emb = self.cond_proj(cond) # (B, n_embd) - # Add them up here to pass a single vector down + label_emb = self.cond_proj(cond) # (B, n_embd) + _nan_check(label_emb, "forward label_emb") total_cond_emb = self.cond_mix_mlp(torch.concat([t_emb, label_emb], dim=1)) else: total_cond_emb = t_emb + _nan_check(total_cond_emb, "forward total_cond_emb") emb = self.emb(input) + _nan_check(emb, "forward emb") inp_enc = self.pos_enc(emb) + _nan_check(inp_enc, "forward inp_enc") enc_cond = self.encoder(inp_enc, total_cond_emb, padding_masks=padding_masks) + _nan_check(enc_cond, "forward enc_cond") inp_dec = self.pos_dec(emb) + _nan_check(inp_dec, "forward inp_dec") output, mean, trend, season = self.decoder( inp_dec, total_cond_emb, enc_cond, padding_masks=padding_masks ) + _nan_check(output, "forward decoder output") + _nan_check(mean, "forward decoder mean") + _nan_check(trend, "forward decoder trend") + _nan_check(season, "forward decoder season") res = self.inverse(output) - + _nan_check(res, "forward res (inverse output)") + res_m = torch.mean(res, dim=1, keepdim=True).contiguous() + _nan_check(res_m, "forward res_m") combine_m_out = self.combine_m(mean).contiguous() + _nan_check(combine_m_out, "forward combine_m_out") combine_s_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2).contiguous() + _nan_check(combine_s_out, "forward combine_s_out") season_error = (combine_s_out + res - res_m).contiguous() + _nan_check(season_error, "forward season_error") trend = (combine_m_out + res_m + trend).contiguous() + _nan_check(trend, "forward trend (final)") if return_res: - return trend, combine_s_out, res - res_m + out_res = res - res_m + _nan_check(out_res, "forward return res - res_m") + return trend, combine_s_out, out_res return trend, season_error diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index a698e22..2883906 100644 --- a/cents/models/normalizer.py +++ 
b/cents/models/normalizer.py @@ -238,11 +238,23 @@ def __init__( self.dynamic_module_type = self.dataset.dynamic_module_type self.stats_head_type = self.dataset.stats_head_type self.loss_type = getattr(self.normalizer_training_cfg, "loss_type", "mse") + self.use_global_stats_preprocessing = bool( + getattr(self.normalizer_training_cfg, "use_global_stats_preprocessing", True) + ) + if hasattr(self.dataset_cfg, "normalizer_use_global_stats_preprocessing"): + self.use_global_stats_preprocessing = bool( + self.dataset_cfg.normalizer_use_global_stats_preprocessing + ) + # When not using global preprocessing: mu is predicted in asinh space; clamp to avoid sinh explosion + self.max_asinh_mu = float(getattr(self.normalizer_training_cfg, "max_asinh_mu", 10.0)) + # Floor sigma and scale range so normalized values (z) cannot explode + self.min_sigma = float(getattr(self.normalizer_training_cfg, "min_sigma", 1e-3)) + self.min_scale_range = float(getattr(self.normalizer_training_cfg, "min_scale_range", 0.25)) self.register_buffer("global_mu_mean", torch.tensor(0.0)) self.register_buffer("global_mu_std", torch.tensor(1.0)) self.register_buffer("global_log_sigma_mean", torch.tensor(0.0)) - + # Create static context module self.static_context_module = None if self.static_context_vars: @@ -310,24 +322,39 @@ def setup(self, stage: Optional[str] = None): mode = getattr(self.dataset_cfg, "normalizer_stats_mode", "sample") self.sample_stats = self._build_training_samples(mode, use_quantile_scale=True) - # --- COMPUTE GLOBAL TARGET STATS FOR SCALING --- - # 1. Global Mu Stats (for Z-score scaling) - all_mus = np.concatenate([s[2] for s in self.sample_stats]) - self.target_mu_mean = torch.tensor(all_mus.mean(), dtype=torch.float32) - self.target_mu_std = torch.tensor(all_mus.std() + 1e-8, dtype=torch.float32) + if self.use_global_stats_preprocessing: + # --- COMPUTE GLOBAL TARGET STATS FOR SCALING --- + # 1. 
Global Mu Stats (for Z-score scaling) + all_mus = np.concatenate([s[2] for s in self.sample_stats]) + self.target_mu_mean = torch.tensor(all_mus.mean(), dtype=torch.float32) + self.target_mu_std = torch.tensor(all_mus.std() + 1e-8, dtype=torch.float32) + + # 2. Global Sigma Stats (for Log-Space Centering) + all_sigmas_concat = np.concatenate([s[3] for s in self.sample_stats]) + self.target_log_sigma_mean = torch.tensor( + np.log(all_sigmas_concat + 1e-8).mean(), dtype=torch.float32 + ) - # 2. Global Sigma Stats (for Log-Space Centering) - all_sigmas_concat = np.concatenate([s[3] for s in self.sample_stats]) - # Calculate the mean of the logs (Geometric mean center) - self.target_log_sigma_mean = torch.tensor(np.log(all_sigmas_concat + 1e-8).mean(), dtype=torch.float32) + self.global_mu_mean.fill_(self.target_mu_mean) + self.global_mu_std.fill_(self.target_mu_std) + self.global_log_sigma_mean.fill_(self.target_log_sigma_mean) - # Register buffers so they persist with model - self.global_mu_mean.fill_(self.target_mu_mean) - self.global_mu_std.fill_(self.target_mu_std) - self.global_log_sigma_mean.fill_(self.target_log_sigma_mean) + print(f"Global Target Stats: Mu Mean={self.target_mu_mean:.4f}, Mu Std={self.target_mu_std:.4f}") + print(f"Global Target Log Sigma Mean: {self.target_log_sigma_mean:.4f}") + else: + # No global preprocessing: mu in asinh space, log(sigma) direct; identity for sigma centering + self.global_mu_mean.zero_() + self.global_mu_std.fill_(1.0) + self.global_log_sigma_mean.zero_() + print(f"Normalizer: use_global_stats_preprocessing=False — predicting asinh(mu) (clamp ±{self.max_asinh_mu}) and log(sigma) directly.") - print(f"Global Target Stats: Mu Mean={self.target_mu_mean:.4f}, Mu Std={self.target_mu_std:.4f}") - print(f"Global Target Log Sigma Mean: {self.target_log_sigma_mean:.4f}") + # Global range mean for safeguard: floor rng to 0.01 * global_rng_mean when do_scale + if self.do_scale: + all_rngs = [] + for s in self.sample_stats: + zlow, 
zhigh = s[4], s[5] + if zlow is not None and zhigh is not None: + all_rngs.extend((np.asarray(zhigh) - np.asarray(zlow)).flatten().tolist()) # Log initial predictions if stage == "fit" or stage is None: @@ -335,27 +362,34 @@ def setup(self, stage: Optional[str] = None): def _log_initial_predictions(self): """Log initial model predictions to diagnose initialization issues.""" - self.eval() - with torch.no_grad(): - dataloader = self.train_dataloader() - batch = next(iter(dataloader)) - cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch + # self.eval() + # with torch.no_grad(): + # dataloader = self.train_dataloader() + # batch = next(iter(dataloader)) + # cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - device = next(self.parameters()).device - cat_vars_dict = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in cat_vars_dict.items()} - mu_t = mu_t.to(device) - sigma_t = sigma_t.to(device) + # device = next(self.parameters()).device + # cat_vars_dict = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in cat_vars_dict.items()} + # mu_t = mu_t.to(device) + # sigma_t = sigma_t.to(device) - # Predict (Returns Real Unscaled values via Forward) - pred_mu, pred_sigma, pred_z_min, pred_z_max, _ = self(cat_vars_dict) + # # Predict (Returns Real Unscaled values via Forward) + # pred_mu, pred_sigma, pred_z_min, pred_z_max, _ = self(cat_vars_dict) - print(f"\n[Initial Predictions]") - print(f" Target mu: mean={mu_t.mean().item():.4f}, std={mu_t.std().item():.4f}") - print(f" Predicted mu: mean={pred_mu.mean().item():.4f}, std={pred_mu.std().item():.4f}") - print(f" Initial loss_mu: {F.mse_loss(pred_mu, mu_t).item():.6f}") - print() + # print(f"\n[Initial Predictions]") + # print(f" Target mu: mean={mu_t.mean().item():.4f}, std={mu_t.std().item():.4f}") + # print(f" Predicted mu: mean={pred_mu.mean().item():.4f}, std={pred_mu.std().item():.4f}") + # print(f" Initial loss_mu: {F.mse_loss(pred_mu, mu_t).item():.6f}") + # print() self.train() + 
+ def _raw_mu_to_real(self, pred_mu_raw: torch.Tensor) -> torch.Tensor: + """Convert network mu output to real-world mu (handles both global and direct/asinh paths).""" + if self.use_global_stats_preprocessing: + return (pred_mu_raw * self.global_mu_std) + self.global_mu_mean + # Direct path: network predicts asinh(mu); invert with sinh and clamp for stability + return torch.sinh(torch.clamp(pred_mu_raw, -self.max_asinh_mu, self.max_asinh_mu)) def forward(self, cat_vars_dict: dict): """ @@ -365,16 +399,17 @@ def forward(self, cat_vars_dict: dict): Returns: Tuple of (pred_mu_real, pred_sigma_real, pred_z_min, pred_z_max, pred_log_sigma_raw). """ - # Get raw network outputs (scaled space) pred_mu_raw, pred_sigma, pred_zmin, pred_zmax, pred_log_sigma_raw = self.normalizer_model(cat_vars_dict) - # 1. Unscale Mu: (NetworkOutput * GlobalStd) + GlobalMean - pred_mu_real = (pred_mu_raw * self.global_mu_std) + self.global_mu_mean + pred_mu_real = self._raw_mu_to_real(pred_mu_raw) - # 2. Unscale Sigma: exp(NetworkLogOutput + GlobalLogMean) - # Note: We reconstruct log_sigma first to ensure numerical stability + # Unscale Sigma: exp(NetworkLogOutput + GlobalLogMean) or exp(NetworkLogOutput) when no global pred_log_sigma_real = pred_log_sigma_raw + self.global_log_sigma_mean - pred_sigma_real = torch.exp(pred_log_sigma_real) + pred_sigma_real = torch.exp(pred_log_sigma_real).clamp(min=self.min_sigma) + # Safeguard: sigma must be at least 0.01 * global sigma (exp(global_log_sigma_mean)) + sigma_global = torch.exp(self.global_log_sigma_mean) + sigma_floor = max(self.min_sigma, (0.01 * sigma_global).item()) + pred_sigma_real = torch.clamp(pred_sigma_real, min=sigma_floor) return pred_mu_real, pred_sigma_real, pred_zmin, pred_zmax, pred_log_sigma_raw @@ -399,12 +434,16 @@ def training_step(self, batch, batch_idx: int): # This gives us values that match the standardized targets pred_mu_raw, _, pred_z_min, pred_z_max, pred_log_sigma_raw = self.normalizer_model(context_vars_dict) - # 
2. Scale Targets to match Network Space - - # Scale Mu: Z-score - mu_t_scaled = (mu_t - self.global_mu_mean) / self.global_mu_std - - # Scale Sigma: Log-Space Centering + # 2. Targets: scaled (global) or asinh(mu) + log(sigma) (direct) + if self.use_global_stats_preprocessing: + mu_t_scaled = (mu_t - self.global_mu_mean) / self.global_mu_std + else: + # asinh(mu) compresses full range to ~[-8, 8]; clamp so target matches forward clamp + mu_t_scaled = torch.clamp( + torch.asinh(mu_t), + -self.max_asinh_mu, + self.max_asinh_mu, + ) target_log_sigma_centered = torch.log(sigma_t + 1e-8) - self.global_log_sigma_mean # 3. Compute Loss @@ -436,8 +475,7 @@ def training_step(self, batch, batch_idx: int): if batch_idx % 100 == 0: with torch.no_grad(): - # Reconstruct real values for logging intelligibility - pred_mu_real = (pred_mu_raw * self.global_mu_std) + self.global_mu_mean + pred_mu_real = self._raw_mu_to_real(pred_mu_raw) self.log("pred_mu_mean_real", pred_mu_real.mean(), on_step=True, on_epoch=False) self.log("target_mu_mean_real", mu_t.mean(), on_step=True, on_epoch=False) @@ -543,8 +581,13 @@ def __getitem__(self, idx: int): mu_t = torch.from_numpy(mu_arr).float() sigma_t = torch.from_numpy(sigma_arr).float() - zmin_t = torch.from_numpy(zmin_arr).float() if self.do_scale else None - zmax_t = torch.from_numpy(zmax_arr).float() if self.do_scale else None + if self.do_scale: + zmin_t = torch.from_numpy(zmin_arr).float() + zmax_t = torch.from_numpy(zmax_arr).float() + else: + # Return dummy tensors so DataLoader collate does not see None + zmin_t = torch.zeros_like(mu_t) + zmax_t = torch.zeros_like(mu_t) return context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t return _TrainSet(self.sample_stats, self.dynamic_context_vars, self.do_scale, self.dataset_cfg) @@ -582,25 +625,16 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: arr = np.asarray(row[col], dtype=np.float32) - if sigma[d] < 1e-4: - print(f"[EXPLOSION ALERT] Row {i}, Col '{col}'") - print(f" -> Sigma is 
tiny: {sigma[d]:.8f}") - print(f" -> Raw Data Range: {arr.min()} to {arr.max()}") - print(f" -> This will multiply your data by {1/(sigma[d]+1e-8):.0f}x!") - - z = (arr - mu[d]) / (sigma[d] + 1e-8) + sigma_floor = max(self.min_sigma, 0.01 * np.exp(self.global_log_sigma_mean.cpu().item())) + sigma_eff = max(float(sigma[d]), sigma_floor) + z = (arr - mu[d]) / sigma_eff if self.do_scale: zmin_, zmax_ = pred_zmin[0, d].item(), pred_zmax[0, d].item() rng = (zmax_ - zmin_) + 1e-8 - z = (z - zmin_) / rng + rng_floor = max(self.min_scale_range, .25) + rng_eff = max(rng, rng_floor) + z = (z - zmin_) / rng_eff - if rng < 1e-4: - print(f"[EXPLOSION ALERT] Row {i}, Col '{col}'") - print(f" -> Range Collapsed: z_min={zmin_:.4f}, z_max={zmax_:.4f}") - print(f" -> Range delta: {rng:.8f}") - print(f" -> This will multiply your data by {1/(rng+1e-8):.0f}x!") - - df_out.at[i, col] = z return df_out @@ -632,12 +666,14 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: for d, col in enumerate(self.time_series_cols): z = np.asarray(row[col], dtype=np.float32) + sigma_floor = max(self.min_sigma, 0.01 * np.exp(self.global_log_sigma_mean.cpu().item())) + sigma_eff = max(float(sigma[d]), sigma_floor) if self.do_scale: zmin_, zmax_ = pred_zmin[0, d].item(), pred_zmax[0, d].item() rng = (zmax_ - zmin_) + 1e-8 - z = z * rng + zmin_ - - arr = z * (sigma[d] + 1e-8) + mu[d] + rng_eff = max(rng, self.min_scale_range) + z = z * rng_eff + zmin_ + arr = z * sigma_eff + mu[d] df_out.at[i, col] = arr return df_out diff --git a/cents/utils/utils.py b/cents/utils/utils.py index fb108a5..5c4b3cf 100644 --- a/cents/utils/utils.py +++ b/cents/utils/utils.py @@ -16,6 +16,7 @@ def _ckpt_name( static_module_type: str = None, stats_head_type: str = None, dynamic_module_type: str = None, + use_global_stats_preprocessing: bool = True, ) -> str: """ Generate checkpoint filename with optional context_module_type and stats_head_type. 
@@ -25,8 +26,10 @@ def _ckpt_name( model: Model name dims: Number of dimensions ext: File extension (default: "ckpt") - context_module_type: Optional context module type (e.g., "mlp", "sep_mlp") + static_module_type: Optional context module type (e.g., "mlp", "sep_mlp") stats_head_type: Optional stats head type (e.g., "mlp") + dynamic_module_type: Optional dynamic module type + use_global_stats_preprocessing: If False, suffix "noglobal" so direct-prediction normalizer uses a separate cache. Returns: Formatted checkpoint filename @@ -42,6 +45,9 @@ def _ckpt_name( if dynamic_module_type: parts.append(f"dyn{dynamic_module_type}") + if not use_global_stats_preprocessing: + parts.append("noglobal") + return "_".join(parts) + f".{ext}" diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index bce61dd..67ec2a7 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,7 +1,6 @@ import logging import os from pathlib import Path -from typing import Tuple import json import torch @@ -23,8 +22,8 @@ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) -DATASET_OVERRIDES = ["normalize=False"] -PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] +DATASET_OVERRIDES = ["max_samples=10000", "normalize=False"] +PECAN_OVERRIDES = ["time_series_dims=2", "user_group=pv_users"] CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" @@ -43,52 +42,78 @@ def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: return cfg -def _infer_dataset_shape_from_ckpt( - ckpt_path: str, cond_emb_dim: int -) -> Tuple[int, int]: - """ - Infer seq_len and time_series_dims from a Diffusion_TS checkpoint state_dict - so the model can be built with the same architecture as when the checkpoint was saved. 
- """ - ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) - state_dict = ckpt.get("state_dict", ckpt) - # Keys may be "model.pos_enc.pe" (Lightning) or "pos_enc.pe" (raw) - for pe_key in ("model.pos_enc.pe", "pos_enc.pe"): - if pe_key in state_dict: - # shape (1, seq_len, d_model) - seq_len = int(state_dict[pe_key].shape[1]) - break - else: - raise ValueError( - "Could not infer seq_len from checkpoint (no pos_enc.pe key in state_dict)" - ) - # combine_s: Conv1d(n_embd, n_feat, ...) -> weight shape (n_feat, n_embd, k) - # n_feat = time_series_dims + cond_emb_dim - for cs_key in ("model.combine_s.weight", "combine_s.weight"): - if cs_key in state_dict: - n_feat = int(state_dict[cs_key].shape[0]) - time_series_dims = n_feat - cond_emb_dim - if time_series_dims < 1: - time_series_dims = 1 - break - else: - raise ValueError( - "Could not infer time_series_dims from checkpoint (no combine_s.weight in state_dict)" - ) - return seq_len, time_series_dims - - -def _load_dataset(name: str, dataset_cfg: OmegaConf): - """Load a dataset by name using dataset-specific config (from config/dataset/{name}.yaml).""" +def _load_dataset(name: str, dataset_cfg: OmegaConf, run_dir: str = None): + """Load a dataset by name using dataset-specific config. Optionally pass run_dir for normalizer/cache.""" + kwargs = {"cfg": dataset_cfg} + if run_dir is not None: + kwargs["run_dir"] = run_dir if name == "pecanstreet": - return PecanStreetDataset(cfg=dataset_cfg) + return PecanStreetDataset(**kwargs) if name == "commercial": - return CommercialDataset(cfg=dataset_cfg) + return CommercialDataset(**kwargs) if name == "airquality": - return AirQualityDataset(cfg=dataset_cfg) + return AirQualityDataset(**kwargs) raise ValueError(f"Dataset {name} not supported. Use: pecanstreet, commercial, airquality.") +def _find_checkpoint_by_epoch(checkpoint_dir: Path, epoch: int) -> Path: + """Return path to a checkpoint matching the given epoch (e.g. epoch=0699). 
Prefer 4-digit zero-padded.""" + checkpoint_dir = Path(checkpoint_dir) + if not checkpoint_dir.exists(): + raise FileNotFoundError(f"Checkpoint dir not found: {checkpoint_dir}") + # Lightning ModelCheckpoint with epoch in filename: ..._epoch=0699.ckpt or ..._epoch=699.ckpt + pattern_4 = f"*epoch={epoch:04d}*" + pattern_1 = f"*epoch={epoch}*" + for pattern in (pattern_4, pattern_1): + matches = list(checkpoint_dir.glob(pattern)) + if matches: + return matches[0] + raise FileNotFoundError(f"No checkpoint for epoch {epoch} in {checkpoint_dir}") + + +def _resolve_run_path_config(run_path: Path, epoch: int = None): + """ + Load configs from run_path/config/ (or run_path/summary.yaml for older runs) and resolve checkpoint and epoch. + Returns (dataset_cfg, model_cfg, context_path, model_ckpt_path, normalizer_dir, metrics_epoch). + metrics_epoch is the epoch number or 'last' for saving metrics_{epoch}.json. + """ + run_path = Path(run_path) + config_dir = run_path / "config" + if config_dir.exists(): + dataset_cfg = OmegaConf.load(str(config_dir / "dataset.yaml")) + model_cfg = OmegaConf.load(str(config_dir / "model.yaml")) + context_path = config_dir / "context.yaml" + if not context_path.exists(): + context_path = None + else: + summary_path = run_path / "summary.yaml" + if not summary_path.exists(): + raise FileNotFoundError( + f"Run has no config/ or summary.yaml at {run_path}. Train with current code to write configs." 
+ ) + summary = load_yaml(str(summary_path)) + dataset_cfg = OmegaConf.create(summary.get("dataset", {})) + model_cfg = OmegaConf.create(summary.get("model", {})) + config_dir.mkdir(parents=True, exist_ok=True) + if summary.get("context"): + OmegaConf.save(OmegaConf.create(summary["context"]), str(config_dir / "context.yaml")) + context_path = config_dir / "context.yaml" if summary.get("context") else None + checkpoint_dir = run_path / "checkpoints" + if epoch is not None: + model_ckpt_path = _find_checkpoint_by_epoch(checkpoint_dir, epoch) + metrics_epoch = epoch + else: + last_ckpt = checkpoint_dir / "last.ckpt" + if not last_ckpt.exists(): + raise FileNotFoundError( + f"No last.ckpt in {checkpoint_dir}. Specify --epoch or ensure training saved last.ckpt." + ) + model_ckpt_path = last_ckpt + metrics_epoch = "last" + normalizer_dir = run_path / "normalizer" + return dataset_cfg, model_cfg, context_path, model_ckpt_path, normalizer_dir, metrics_epoch + + def main() -> None: parser = argparse.ArgumentParser( description="Evaluate a trained model using comprehensive metrics.", @@ -130,7 +155,7 @@ def main() -> None: type=str, nargs="*", default=[], - help="Extra dataset overrides, e.g. time_series_dims=1.", + help="Extra dataset overrides, e.g. time_series_dims=2 for multivariate. Config after overrides sets model shape.", ) parser.add_argument( "--save-dir", @@ -192,113 +217,133 @@ def main() -> None: default=None, help="Path to custom context config YAML file (optional).", ) + parser.add_argument( + "--no-normalizer-global-preprocessing", + action="store_true", + help="Use normalizer without global-stats preprocessing (match training that used --no-normalizer-global-preprocessing).", + ) + parser.add_argument( + "--run-path", + type=str, + default=None, + help="Path to a run directory (e.g. runs/commercial_noglobal). Load config from run/config/, checkpoint from run/checkpoints/. 
Saves metrics to run/metrics/metrics_{epoch}.json.", + ) + parser.add_argument( + "--epoch", + type=int, + default=None, + help="Epoch number to evaluate when using --run-path (e.g. 699). If omitted, use last.ckpt and save as metrics_last.json.", + ) args = parser.parse_args() - if not args.model_ckpt and not args.model_key: - parser.error("One of --model-ckpt or --model-key is required.") - if args.model_ckpt and args.model_key: - parser.error("Use only one of --model-ckpt or --model-key.") + use_run_path = args.run_path is not None + if use_run_path and args.model_key: + parser.error("Do not use --model-key with --run-path.") + if not use_run_path and not args.model_ckpt and not args.model_key: + parser.error("One of --model-ckpt, --model-key, or --run-path is required.") + if use_run_path and args.model_ckpt: + parser.error("Do not use --model-ckpt with --run-path; checkpoint is resolved from run-path and --epoch.") # Set custom context config path if provided if args.context_config_path: set_context_config_path(args.context_config_path) - logging.info("Loading dataset %s...", args.dataset) - overrides = list(DATASET_OVERRIDES) - if args.dataset == "pecanstreet": - overrides = overrides + PECAN_OVERRIDES - # Use pretrained normalizer from checkpoint/HF: skip dataset normalizer init so it doesn't train - if args.model_key: - overrides = overrides + ["normalize=False"] - # Watts (and most pretrained) normalizers use scale=True (do_scale); match so stats_head shape loads - if args.model_key: - overrides = overrides + ["scale=True"] - if args.dataset_overrides: - overrides = overrides + list(args.dataset_overrides) - dataset_cfg = _load_dataset_config(args.dataset, overrides) - dataset = _load_dataset(args.dataset, dataset_cfg) - - # Resolve model type (from key when loading from HF, else from args) - if args.model_key: - model_type = get_model_type_from_hf_name(args.model_key) + if use_run_path: + run_path = Path(args.run_path).resolve() + logging.info("Using 
run-path: %s (epoch=%s)", run_path, args.epoch) + dataset_cfg, model_cfg, context_path, model_ckpt_path, normalizer_dir, metrics_epoch = _resolve_run_path_config( + run_path, args.epoch + ) + if context_path is not None: + set_context_config_path(str(context_path)) + dataset_name = dataset_cfg.get("name", "pecanstreet") + dataset = _load_dataset(dataset_name, dataset_cfg, run_dir=str(run_path)) + model_type = model_cfg.get("name", "diffusion_ts") + eval_cfg = load_yaml(args.evaluator_config) + top_cfg = load_yaml(args.config) + cfg = OmegaConf.create({}) + cfg.evaluator = eval_cfg + cfg.wandb = top_cfg.get("wandb", {}) + cfg.device = "cuda:0" + cfg.model = model_cfg + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + cfg.model.use_ema_sampling = args.ema + cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) + cfg.eval_metrics = False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) + cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) + cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) + cfg.save_results = False if args.no_save_results else True + cfg.job_name = args.job_name or eval_cfg.get("job_name", "default_job") + cfg.save_dir = run_path / "metrics" + cfg.save_dir.mkdir(parents=True, exist_ok=True) + logging.info("Loading dataset from run config (run_dir=%s)...", run_path) + logging.info("Model checkpoint: %s", model_ckpt_path) + gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) + gen.load_from_checkpoint(str(model_ckpt_path), normalizer_ckpt=None) + args._metrics_epoch = metrics_epoch else: - model_type = args.model_type or "diffusion_ts" - - # Load configs - eval_cfg = load_yaml(args.evaluator_config) - top_cfg = load_yaml(args.config) - - cfg = OmegaConf.create({}) - cfg.evaluator = eval_cfg - cfg.wandb = 
top_cfg.get("wandb", {}) - cfg.device = "cuda:2" - cfg.model = OmegaConf.create( - OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) - ) - cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) - - # print("EVAL CONFIG:") - # print(cfg) - - # When loading from a local checkpoint, infer seq_len and time_series_dims from the - # checkpoint so the model is built with the same architecture (avoids shape mismatch). - if args.model_ckpt and Path(args.model_ckpt).suffix == ".ckpt": - try: - ckpt_seq_len, ckpt_time_series_dims = _infer_dataset_shape_from_ckpt( - args.model_ckpt, cond_emb_dim=int(cfg.model.cond_emb_dim) - ) - if ckpt_seq_len != cfg.dataset.seq_len or ckpt_time_series_dims != cfg.dataset.time_series_dims: - logging.info( - "Checkpoint has seq_len=%s, time_series_dims=%s; overriding dataset config to match.", - ckpt_seq_len, ckpt_time_series_dims, - ) - cfg.dataset.seq_len = ckpt_seq_len - cfg.dataset.time_series_dims = ckpt_time_series_dims - dataset.cfg.seq_len = ckpt_seq_len - dataset.cfg.time_series_dims = ckpt_time_series_dims - dataset.seq_len = ckpt_seq_len - dataset.time_series_dims = ckpt_time_series_dims - except (KeyError, ValueError) as e: - logging.warning( - "Could not infer dataset shape from checkpoint (%s). 
Using eval dataset config; shape mismatch may occur.", - e, - ) - - # Set EMA sampling - cfg.model.use_ema_sampling = args.ema - - # Set evaluation flags (use config defaults if not overridden) - cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) - cfg.eval_metrics = False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) - cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) - cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) - cfg.save_results = False if args.no_save_results else True + args._metrics_epoch = None + logging.info("Loading dataset %s...", args.dataset) + overrides = list(DATASET_OVERRIDES) + if args.dataset == "pecanstreet": + overrides = overrides + PECAN_OVERRIDES + if args.model_key: + overrides = overrides + ["normalize=False"] + if args.model_key: + overrides = overrides + ["scale=True"] + if args.dataset_overrides: + overrides = overrides + list(args.dataset_overrides) + dataset_cfg = _load_dataset_config(args.dataset, overrides) + dataset = _load_dataset(args.dataset, dataset_cfg) + + if args.model_key: + model_type = get_model_type_from_hf_name(args.model_key) + else: + model_type = args.model_type or "diffusion_ts" - # Set job name - cfg.job_name = args.job_name if args.job_name else eval_cfg.get("job_name", "default_job") + eval_cfg = load_yaml(args.evaluator_config) + top_cfg = load_yaml(args.config) - # Set save directory - if args.save_dir: - cfg.save_dir = Path(args.save_dir) - elif args.model_key: - cfg.save_dir = Path("outputs/eval") / args.model_key - else: - model_ckpt_path = Path(args.model_ckpt) - cfg.save_dir = model_ckpt_path.parent / "eval" + cfg = OmegaConf.create({}) + cfg.evaluator = eval_cfg + cfg.wandb = top_cfg.get("wandb", {}) + cfg.device = "cuda:0" + cfg.model = OmegaConf.create( + 
OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) + ) + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + if args.no_normalizer_global_preprocessing: + cfg.dataset.normalizer_use_global_stats_preprocessing = False + + cfg.model.use_ema_sampling = args.ema + cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) + cfg.eval_metrics = False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) + cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) + cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) + cfg.save_results = False if args.no_save_results else True + cfg.job_name = args.job_name if args.job_name else eval_cfg.get("job_name", "default_job") + + if args.save_dir: + cfg.save_dir = Path(args.save_dir) + elif args.model_key: + cfg.save_dir = Path("outputs/eval") / args.model_key + else: + model_ckpt_path = Path(args.model_ckpt) + cfg.save_dir = model_ckpt_path.parent / "eval" - if not os.path.exists(cfg.save_dir): - os.makedirs(cfg.save_dir, exist_ok=True) - logging.info("Created evaluation directory: %s", cfg.save_dir) + if not os.path.exists(cfg.save_dir): + os.makedirs(cfg.save_dir, exist_ok=True) + logging.info("Created evaluation directory: %s", cfg.save_dir) - use_hf = args.model_key is not None - if use_hf: - logging.info("Setting up DataGenerator from HuggingFace (model_key=%s)...", args.model_key) - gen = DataGenerator(model_name=args.model_key, dataset=dataset, cfg=cfg) - else: - logging.info("Setting up DataGenerator (model_type=%s)...", model_type) - gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) - logging.info("Loading checkpoint... 
EMA sampling %s", "enabled" if cfg.model.use_ema_sampling else "disabled") - gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) + use_hf = args.model_key is not None + if use_hf: + logging.info("Setting up DataGenerator from HuggingFace (model_key=%s)...", args.model_key) + gen = DataGenerator(model_name=args.model_key, dataset=dataset, cfg=cfg) + else: + logging.info("Setting up DataGenerator (model_type=%s)...", model_type) + gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) + logging.info("Loading checkpoint... EMA sampling %s", "enabled" if cfg.model.use_ema_sampling else "disabled") + gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) # Ensure EMA setting is applied to the config used by the model at generate time target = getattr(gen.model, "cfg", None) or gen.cfg @@ -313,22 +358,52 @@ def main() -> None: print("=" * 60 + "\n") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) - print("\n📊 METRICS:") + print("\n📊 METRICS (raw domain):") print("-" * 60) metrics = results.get("metrics", {}) + normalized = metrics.pop("normalized_domain", None) + + def _print_metrics(m, prefix=" "): + for key, value in m.items(): + if key == "rare_subset": + print(f"\n{prefix}rare_subset:") + _print_metrics(value, prefix=prefix + " ") + elif isinstance(value, dict) and "mean" in value and "std" in value: + print(f"{prefix}{key}: mean={value['mean']:.6f}, std={value['std']:.6f}") + elif isinstance(value, dict): + print(f"\n{prefix}{key}:") + _print_metrics(value, prefix=prefix + " ") + elif isinstance(value, (int, float)): + print(f"{prefix}{key}: {value:.6f}") + else: + print(f"{prefix}{key}: {value}") + for key, value in metrics.items(): - if isinstance(value, dict): + if isinstance(value, dict) and "rare_subset" not in value: print(f"\n{key}:") - for subkey, subval in value.items(): - print(f" {subkey}: {subval:.6f}" if isinstance(subval, (int, float)) else f" {subkey}: {subval}") + _print_metrics(value) + elif 
isinstance(value, dict): + print(f"\n{key}:") + _print_metrics(value) else: print(f"{key}: {value:.6f}" if isinstance(value, (int, float)) else f"{key}: {value}") + + if normalized is not None: + print("\n📊 METRICS (normalized domain, z-space — comparable across domains):") + print("-" * 60) + _print_metrics(normalized, prefix="") + metrics["normalized_domain"] = normalized # restore for save # Results are automatically saved if save_results=True - if args.save_dir: + if use_run_path and getattr(args, "_metrics_epoch", None) is not None: + metrics_file = cfg.save_dir / f"metrics_{args._metrics_epoch}.json" + with open(metrics_file, "w") as f: + json.dump(metrics, f, indent=4) + print(f"\n✅ Results saved to {metrics_file}") + elif args.save_dir: with open(Path(args.save_dir) / "metrics.json", "w") as f: json.dump(metrics, f, indent=4) - print(f"\n✅ Results saved to: {Path(args.save_dir) / "metrics.json"}") + print(f"\n✅ Results saved to {Path(args.save_dir) / "metrics.json"}") print("\n" + "=" * 60) diff --git a/scripts/train.py b/scripts/train.py index 06005b3..80a1553 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -51,6 +51,27 @@ def _write_run_summary(run_dir: Path, run_name: str, trainer: Trainer) -> None: print(f"[Cents] Wrote run summary to {path}") +def _write_run_configs(run_dir: Path, trainer: Trainer) -> None: + """Write dataset, model, context, and trainer configs to run_dir/config/ for eval and reproducibility.""" + cfg = trainer.cfg + context_cfg = get_context_config() + config_dir = run_dir / "config" + config_dir.mkdir(parents=True, exist_ok=True) + if hasattr(cfg, "dataset") and cfg.dataset: + with open(config_dir / "dataset.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(cfg.dataset, resolve=True), f, default_flow_style=False, sort_keys=False) + if hasattr(cfg, "model") and cfg.model: + with open(config_dir / "model.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(cfg.model, resolve=True), f, default_flow_style=False, 
sort_keys=False) + if hasattr(cfg, "trainer") and cfg.trainer: + with open(config_dir / "trainer.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(cfg.trainer, resolve=True), f, default_flow_style=False, sort_keys=False) + if context_cfg: + with open(config_dir / "context.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(context_cfg, resolve=True), f, default_flow_style=False, sort_keys=False) + print(f"[Cents] Wrote run configs to {config_dir}") + + def main(args) -> None: MODEL_NAME = args.model_name CR_LOSS_WEIGHT = args.cr_loss_weight @@ -73,6 +94,8 @@ def main(args) -> None: # Build dataset-specific overrides (key=value list; config is loaded from config/dataset/{dataset}.yaml) dataset_overrides = [f"skip_heavy_processing={args.skip_heavy_processing}"] + if getattr(args, "no_normalizer_global_preprocessing", False): + dataset_overrides.append("normalizer_use_global_stats_preprocessing=false") if args.dataset == "pecanstreet": dataset_overrides.extend(["time_series_dims=1", "user_group=all"]) dataset_cfg = _load_dataset_config(args.dataset, dataset_overrides) @@ -103,6 +126,7 @@ def main(args) -> None: trainer_overrides = [ f"run_dir={run_dir}", f"trainer.max_epochs={args.epochs}", + f"trainer.checkpoint.every_n_epochs={args.every_n_epochs}", f"trainer.strategy={args.ddp_strategy}", f"trainer.devices={args.devices}", f"trainer.eval_after_training={args.eval_after_training}", @@ -127,11 +151,13 @@ def main(args) -> None: ) _write_run_summary(run_dir, run_name, trainer) + _write_run_configs(run_dir, trainer) trainer.fit(ckpt_path=args.resume_from_checkpoint) if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--every_n_epochs", type=int, default=250) parser.add_argument("--devices", type=str, default="auto") parser.add_argument("--accelerator", type=str, default="gpu") parser.add_argument("--model_name", type=str, default="diffusion_ts") @@ -144,12 +170,12 @@ def main(args) -> None: help="Enable Weights and Biases logging") 
parser.add_argument("--wandb-project", type=str, default="cents") parser.add_argument("--wandb-entity", type=str, default=None) - parser.add_argument("--eval_after_training", action="store_true", + parser.add_argument("--eval-after-training", action="store_true", help="Evaluate after training") - parser.add_argument("--skip_heavy_processing", action="store_true", + parser.add_argument("--skip-heavy-processing", action="store_true", help="Skip heavy processing of dataset") parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_true") - parser.add_argument("--enable_checkpointing", action="store_true", + parser.add_argument("--enable-checkpointing", action="store_true", help="Enable checkpointing") parser.add_argument("--context-config-path", type=str, default=None, help="Path to custom context config YAML file (optional)") @@ -157,6 +183,8 @@ def main(args) -> None: help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn')") parser.add_argument("--force-retrain-normalizer", action="store_true", help="Force retraining of normalizer even if cached version exists") + parser.add_argument("--no-normalizer-global-preprocessing", action="store_true", + help="Predict normalizer mu/sigma directly (no global-stats scaling). 
Use if commercial runs had NaNs with global preprocessing.") parser.add_argument("--resume-from-checkpoint", type=str, default=None, help="Path to checkpoint file (.ckpt) to resume training from", ) From 8fd58cf79874170f48ce0767f3c7b10cd46dde73 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Sat, 28 Feb 2026 14:07:52 -0500 Subject: [PATCH 39/50] Eval in raw, normalized domain, select vars for normalization, dynamic ctx, aqr dataset --- cents/config/dataset/airquality.yaml | 51 ++++-- cents/config/dataset/commercial.yaml | 1 + cents/config/trainer/diffusion_ts.yaml | 4 +- cents/config/trainer/normalizer.yaml | 2 +- cents/datasets/airquality.py | 171 ++++++++++++++++--- cents/datasets/timeseries_dataset.py | 44 ++--- cents/eval/eval.py | 43 +++-- cents/eval/eval_metrics.py | 45 ++++- cents/eval/t2vec/t2vec.py | 16 +- cents/models/context.py | 56 ++++--- cents/models/context_registry.py | 2 +- cents/models/diffusion_ts.py | 121 ++++++++------ cents/models/model_utils.py | 35 +++- cents/models/normalizer.py | 219 ++++++++++++++++--------- scripts/eval_pretrained.py | 65 +++++++- scripts/train.py | 6 +- 16 files changed, 638 insertions(+), 243 deletions(-) diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml index af299be..2aeb931 100644 --- a/cents/config/dataset/airquality.yaml +++ b/cents/config/dataset/airquality.yaml @@ -1,28 +1,59 @@ name: airquality geography: null normalize: True -scale: True +scale: False use_learned_normalizer: True threshold: 8 seq_len: 24 -time_series_dims: 1 shuffle: True -skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) -max_samples: null # Limit dataset size (null = use all data) +skip_heavy_processing: False +max_samples: null path: "./data/airquality" -numeric_context_bins: 5 -time_series_columns: "PM2.5" -data_columns: ["No", "PM2.5", "year", "month", "day", "hour", "TEMP", "DEWP", "PRES", "RAIN", "WSPM", "wd", "station"] +numeric_context_bins: 1 
reduce_cardinality: False +time_series_dims: 1 +normalizer_stats_mode: group +# Normalizer conditions only on these (e.g. per-station); diffusion still gets full context_vars +normalizer_group_vars: ["station", "year", "month"] + +# Targets (what becomes the merged "timeseries" dims) +# NOTE: use PMcoarse instead of PM10 +time_series_columns: ["PM2.5"] + +# Raw CSV columns to load +# Keep wd/WSPM because we need them to engineer wind_u/wind_v +# Keep PM10 because we need it to engineer PMcoarse +data_columns: + - "No" + - "year" + - "month" + - "day" + - "hour" + - "PM2.5" + - "PM10" + - "SO2" + - "NO2" + - "CO" + - "TEMP" + - "DEWP" + - "PRES" + - "RAIN" + - "WSPM" + - "wd" + - "station" context_vars: + # static categorical year: ["categorical", 5] month: ["categorical", 12] weekday: ["categorical", 7] + station: ["categorical", 12] + + # dynamic time-series context TEMP: ["time_series", null] DEWP: ["time_series", null] PRES: ["time_series", null] RAIN: ["time_series", null] - WSPM: ["time_series", null] - wd: ["time_series", 17] - station: ["categorical", 12] \ No newline at end of file + wind_u: ["time_series", null] + wind_v: ["time_series", null] + wd_valid: ["time_series", null] \ No newline at end of file diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml index a8c2096..81fd583 100644 --- a/cents/config/dataset/commercial.yaml +++ b/cents/config/dataset/commercial.yaml @@ -17,6 +17,7 @@ metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearb numeric_context_bins: 5 reduce_cardinality: False normalizer_stats_mode: group +normalizer_group_vars: null context_vars: year: ["categorical", 2] diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 970321a..f401396 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -5,7 +5,7 @@ strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 
log_every_n_steps: 1 batch_size: 512 -max_epochs: 2500 +max_epochs: 1000 base_lr: 1e-4 eval_after_training: False @@ -13,7 +13,7 @@ checkpoint: save_last: True # Save final model save_top_k: 0 # 0 = only periodic saves; use >0 to also save top-k by metric every_n_train_steps: null - every_n_epochs: 250 # Save a distinct checkpoint every 250 epochs (250, 500, 750, ...) + every_n_epochs: 100 # Save a distinct checkpoint every 100 epochs (100, 200, 300, ...) lr_scheduler_params: factor: 0.5 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index d5af861..f34db7f 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -4,7 +4,7 @@ devices: 1, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 2000 +n_epochs: 500 batch_size: 4096 lr: 3e-4 save_cycle: 5000 diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py index b0ca2cd..7b81f85 100644 --- a/cents/datasets/airquality.py +++ b/cents/datasets/airquality.py @@ -94,17 +94,96 @@ def _load_data(self): def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: data = data.copy() - data['timestamp'] = pd.to_datetime(data[["year", "month", "day", "hour"]]) - data['weekday'] = data['timestamp'].dt.day_name() - ts_cols = self.context_series_names + self.target_time_series_columns - - data = data.sort_values(['station', 'year', 'month', 'day', 'hour']) + # Timestamp + weekday + data["timestamp"] = pd.to_datetime(data[["year", "month", "day", "hour"]]) + data["weekday"] = data["timestamp"].dt.day_name() + + # ------------------------- + # Engineer wind_u / wind_v + # ------------------------- + wd_deg_map = { + "N": 0.0, "NNE": 22.5, "NE": 45.0, "ENE": 67.5, + "E": 90.0, "ESE": 112.5, "SE": 135.0, "SSE": 157.5, + "S": 180.0, "SSW": 202.5, "SW": 225.0, "WSW": 247.5, + "W": 270.0, "WNW": 292.5, "NW": 315.0, "NNW": 337.5, + } - # Map month integer to month name string as quickly as possible - months = ["January", 
"February", "March", "April", "May", "June", - "July", "August", "September", "October", "November", "December"] - data['month'] = data['month'].map(lambda x: months[x-1]) + has_wd = "wd" in data.columns + has_wspm = "WSPM" in data.columns + if has_wd and has_wspm: + wd_clean = data["wd"].astype(str).str.strip().str.upper() + wd_deg = wd_clean.map(wd_deg_map) + + # indicator can help if any weird labels slip in + data["wd_valid"] = wd_deg.notna().astype(np.int8) + + theta = np.deg2rad(wd_deg.fillna(0.0).to_numpy(dtype=float)) + wspm = pd.to_numeric(data["WSPM"], errors="coerce").fillna(0.0) + + # u = speed * cos(theta), v = speed * sin(theta) + # (note: choice of axes is arbitrary here; consistency matters more than convention) + data["wind_u"] = wspm * np.cos(theta) + data["wind_v"] = wspm * np.sin(theta) + + # Drop raw wind columns after engineering + data.drop(columns=["wd", "WSPM"], inplace=True) + else: + # If one is missing, don't silently create nonsense + # You can choose to raise instead if this should never happen. 
+ if "wd" in data.columns: + data.drop(columns=["wd"], inplace=True) + if "WSPM" in data.columns: + data.drop(columns=["WSPM"], inplace=True) + + # ------------------------- + # Engineer PMcoarse + # ------------------------- + if "PM10" in data.columns and "PM2.5" in data.columns: + pm10 = pd.to_numeric(data["PM10"], errors="coerce") + pm25 = pd.to_numeric(data["PM2.5"], errors="coerce") + data["PMcoarse"] = (pm10 - pm25).clip(lower=0.0) + + # ------------------------- + # Choose time-series columns + # ------------------------- + # Context TS columns come from cfg; targets come from cfg.time_series_columns + ctx_ts = list(self.context_series_names) + tgt_ts = list(self.target_time_series_columns) + + # Replace context wind variables: remove wd/WSPM if present, add wind_u/wind_v (+wd_valid if you want) + ctx_ts = [c for c in ctx_ts if c not in ("wd", "WSPM")] + for c in ("wind_u", "wind_v"): + if c in data.columns and c not in ctx_ts: + ctx_ts.append(c) + # optional + if "wd_valid" in data.columns and "wd_valid" not in ctx_ts: + ctx_ts.append("wd_valid") + + # Replace PM10 target with PMcoarse if PM10 is in targets + if "PMcoarse" in data.columns: + tgt_ts = ["PMcoarse" if c == "PM10" else c for c in tgt_ts] + # (optional) if you *also* want to drop PM10 if it still exists + tgt_ts = [c for c in tgt_ts if c != "PM10"] + + + + ts_cols = ctx_ts + tgt_ts + + # Ensure all ts_cols exist + missing = [c for c in ts_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing required time-series columns after preprocessing: {missing}") + + # Sort + data = data.sort_values(["station", "year", "month", "day", "hour"]) + + # Month name mapping (keeps your categorical month encoding behavior) + months = [ + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December", + ] + data["month"] = data["month"].map(lambda x: months[x - 1]) group_keys = ["station", "year", "month", "day", "weekday"] @@ -113,17 
+192,48 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: .agg({c: list for c in ts_cols}) ) - # Convert lists -> numpy arrays (fast + deterministic) + # lists -> numpy arrays for c in ts_cols: grouped[c] = grouped[c].map(np.asarray) + - grouped = grouped[grouped["PM2.5"].apply(len) == self.cfg.seq_len].reset_index( - drop=True - ) + # Keep only full-length sequences + # Use the first target if possible, else fall back to first ts col + len_col = tgt_ts[0] if len(tgt_ts) > 0 else ts_cols[0] + grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) grouped = self._handle_missing_data(grouped) - - # Convert all lists in time series columns into tuples to make them hashable + + ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] + # Optional: handle heavy-tailed / zero-inflated channels + log1p_channels = {"RAIN"} # add more if needed + + clip_bound = 5.0 + eps = 1e-8 + # Compute global mean/std per channel over all rows and timesteps + ctx_stats = {} + for c in ctx_numeric: + # stacked shape: (N, L) + X = np.stack(grouped[c].values).astype(np.float32) + + if c in log1p_channels: + X = np.log1p(np.clip(X, a_min=0.0, a_max=None)) + + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 # avoid divide-by-zero; effectively makes it "center only" + ctx_stats[c] = (mu, sd) + + Xn = (X - mu) / (sd + eps) + Xn = np.clip(Xn, -clip_bound, clip_bound).astype(np.float32) + + grouped[c] = list(Xn) + + # (Optional) store for later inverse-transform / debugging + self.context_ts_stats_ = ctx_stats + + # arrays -> tuples (hashable) for c in ts_cols: grouped[c] = grouped[c].map(tuple) @@ -131,22 +241,39 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: def _handle_missing_data(self, data): - # Only handle missing data for numeric time series numeric_series = [c for c in self.context_series_names if c not in self.categorical_time_series] - + mask = 
data[numeric_series].applymap(is_all_nan).any(axis=1) if numeric_series else pd.Series([False] * len(data)) data = data[~mask] for col in numeric_series: data[col] = data[col].apply(fill_with_row_mean) - data[list(self.categorical_time_series.keys())] + # categorical time series must have no NaNs + cat_cols = list(self.categorical_time_series.keys()) + if cat_cols: + mask = data[cat_cols].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + # ensure no NaNs in target series columns + for tcol in self.target_time_series_columns: + # If you replaced PM10->PMcoarse in cfg, this remains correct + if tcol in data.columns: + data = data.loc[data[tcol].apply(lambda x: not np.isnan(np.asarray(x, dtype=float)).any())] + + def row_has_low_std(row, cols, thresh=0.01): + for c in cols: + arr = np.asarray(row[c], dtype=np.float32) + if arr.std() < thresh: + return True + return False + + mask = data.apply( + lambda row: row_has_low_std(row, self.target_time_series_columns, thresh=0.01), + axis=1 + ) - mask = data[list(self.categorical_time_series.keys())].applymap(is_any_nan).any(axis=1) data = data[~mask] - - data = data.loc[data["PM2.5"].apply(lambda x: not np.isnan(x).any())] - return data diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 254204c..b2bdd59 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -93,8 +93,11 @@ def __init__( if not hasattr(self, "name"): self.name = "custom" - # Add continuous variables to context_vars if specified + # Split context vars into static and dynamic once (no future re-splits) self.continuous_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "continuous"] + categorical_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "categorical"] + self.dynamic_context_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "time_series"] + self.static_context_vars = categorical_vars + self.continuous_vars self.normalize = 
normalize self.scale = scale @@ -193,31 +196,36 @@ def __getitem__(self, idx: int): idx (int): Sample index. Returns: - Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: - timeseries: Tensor of shape (seq_len, dims). - - context_vars: Dict of context variable tensors. - Categorical variables are long tensors, continuous variables are float tensors. + - static_context_vars: Dict of static context tensors (categorical long, continuous float). + - dynamic_context_vars: Dict of dynamic (time_series) context tensors. """ sample = self.data.iloc[idx] timeseries = torch.tensor(sample["timeseries"], dtype=torch.float32) - - continuous_vars = getattr(self.cfg, 'continuous_context_vars', None) or [] - context_vars_dict = {} - for var in self.context_vars: - if var in continuous_vars: - # Continuous variables: keep as float + + static_context_vars_dict = {} + for var in self.static_context_vars: + if var in self.continuous_vars: val = sample[var] - # Check for NaN/Inf in the data itself if isinstance(val, (float, int)) and (not isinstance(val, bool) and (np.isnan(val) or np.isinf(val))): raise ValueError( f"NaN/Inf detected in continuous variable '{var}' in dataset at index {idx}. " f"Value: {val}. This should not happen if normalization was done correctly." 
) - context_vars_dict[var] = torch.tensor(val, dtype=torch.float32) + static_context_vars_dict[var] = torch.tensor(val, dtype=torch.float32) else: - # Categorical variables: use long - context_vars_dict[var] = torch.tensor(sample[var], dtype=torch.long) - return timeseries, context_vars_dict + static_context_vars_dict[var] = torch.tensor(sample[var], dtype=torch.long) + + dynamic_context_vars_dict = {} + for var in self.dynamic_context_vars: + arr = np.asarray(sample[var]) + if var in self.categorical_time_series: + dynamic_context_vars_dict[var] = torch.from_numpy(arr).long() + else: + dynamic_context_vars_dict[var] = torch.from_numpy(arr).float() + + return timeseries, static_context_vars_dict, dynamic_context_vars_dict def __getstate__(self): """ @@ -361,15 +369,13 @@ def _encode_context_vars( Returns: Tuple of encoded DataFrame and mapping codes. """ - continuous_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "continuous"] - time_series_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "time_series"] encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, bins=self.numeric_context_bins, numeric_cols=self.numeric_cols, - continuous_vars=continuous_vars, - time_series_cols=time_series_cols, + continuous_vars=self.continuous_vars, + time_series_cols=self.dynamic_context_vars, categorical_time_series=self.categorical_time_series, ) diff --git a/cents/eval/eval.py b/cents/eval/eval.py index 14c0e26..cd30b02 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -358,13 +358,24 @@ def evaluate_subset( dataset.data = dataset.get_combined_rarity() real_data_subset = dataset.data.iloc[indices].reset_index(drop=True) continuous_vars = getattr(dataset, "continuous_vars", []) - context_vars = {} - for name in dataset.context_vars: + static_context_vars = {} + for name in dataset.static_context_vars: vals = real_data_subset[name].values dtype = torch.float32 if name in continuous_vars else 
torch.long - context_vars[name] = torch.tensor(vals, dtype=dtype, device=self.device) + static_context_vars[name] = torch.tensor(vals, dtype=dtype, device=self.device) + dynamic_context_vars = {} + categorical_ts = getattr(dataset, "categorical_time_series", {}) + for name in dataset.dynamic_context_vars: + vals = real_data_subset[name].values + # Dynamic module expects tensors (training path uses torch.from_numpy in dataset __getitem__) + if len(vals) and hasattr(vals[0], "__len__") and not isinstance(vals[0], (str, bytes)): + arr = np.stack([np.asarray(v, dtype=np.float32 if name not in categorical_ts else np.int64) for v in vals]) + else: + arr = np.asarray(vals, dtype=np.float32 if name not in categorical_ts else np.int64) + dtype = torch.long if name in categorical_ts else torch.float32 + dynamic_context_vars[name] = torch.tensor(arr, dtype=dtype, device=self.device) - generated_ts = model.generate(context_vars).cpu().numpy() + generated_ts = model.generate(static_context_vars, dynamic_context_vars).cpu().numpy() if generated_ts.ndim == 2: generated_ts = generated_ts.reshape( generated_ts.shape[0], -1, generated_ts.shape[1] @@ -373,8 +384,19 @@ def evaluate_subset( syn_data_subset = real_data_subset.copy() syn_data_subset["timeseries"] = list(generated_ts) - real_data_inv = dataset.inverse_transform(real_data_subset) - syn_data_inv = dataset.inverse_transform(syn_data_subset) + # When normalize=False but a pretrained normalizer was applied via apply_pretrained_normalizer, + # dataset.inverse_transform() is a no-op (it checks self.normalize). Do the inverse manually. 
+ normalizer = getattr(dataset, "_normalizer", None) + if not getattr(dataset, "normalize", True) and normalizer is not None: + def _inv(df): + split = dataset.split_timeseries(df.copy()) + split = normalizer.inverse_transform(split) + return dataset.merge_timeseries_columns(split) + real_data_inv = _inv(real_data_subset) + syn_data_inv = _inv(syn_data_subset) + else: + real_data_inv = dataset.inverse_transform(real_data_subset) + syn_data_inv = dataset.inverse_transform(syn_data_subset) real_data_array = np.stack(real_data_inv["timeseries"]) syn_data_array = np.stack(syn_data_inv["timeseries"]) @@ -394,10 +416,11 @@ def evaluate_subset( log_prefix="[raw] ", ) - # Metrics in normalized (z) domain for cross-domain comparability (only when dataset is normalized) + # Metrics in normalized (z) domain for cross-domain comparability. + # Fires whenever a normalizer is available — whether data was pre-normalized by dataset + # init (normalize=True) or normalized in-place via apply_pretrained_normalizer (normalize=False). 
if ( - getattr(dataset, "normalize", False) - and getattr(dataset, "_normalizer", None) is not None + getattr(dataset, "_normalizer", None) is not None and "timeseries" in real_data_subset.columns ): real_data_norm = np.stack(real_data_subset["timeseries"].values) @@ -412,4 +435,4 @@ def evaluate_subset( self.current_results["metrics"]["normalized_domain"] = normalized_metrics if self.cfg.evaluator.eval_disentanglement: - self.compute_disentanglement_metrics(context_vars, model) + self.compute_disentanglement_metrics(static_context_vars, model) diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index c3d18a5..ccb0917 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Dict, Tuple @@ -147,12 +148,37 @@ def Context_FID(ori_data: np.ndarray, generated_data: np.ndarray) -> float: Calculate the FID score between original and generated data representations using TS2Vec embeddings. Args: - ori_data: Original time series data. - generated_data: Generated time series data. + ori_data: Original time series data (N, seq_len) or (N, seq_len, dims). + generated_data: Generated time series data, same shape convention. Returns: float: FID score between the original and generated data representations. """ + ori_data = np.asarray(ori_data, dtype=np.float32) + generated_data = np.asarray(generated_data, dtype=np.float32) + # TS2Vec expects (n_instance, n_timestamps, n_features); ensure 3D + if ori_data.ndim == 2: + ori_data = ori_data[:, :, np.newaxis] + if generated_data.ndim == 2: + generated_data = generated_data[:, :, np.newaxis] + if ori_data.ndim != 3 or generated_data.ndim != 3: + warnings.warn( + f"Context_FID: expected 2D or 3D arrays, got ori_data.ndim={ori_data.ndim}, generated_data.ndim={generated_data.ndim}; returning nan." 
+ ) + return float("nan") + # Require at least one non–all-NaN row so TS2Vec.fit() does not infinite-loop + ori_valid = ~np.isnan(ori_data).all(axis=2).all(axis=1) + n_valid = int(ori_valid.sum()) + if n_valid == 0: + warnings.warn( + "Context_FID: ori_data has no valid (non–all-NaN) rows; returning nan." + ) + return float("nan") + # Allow single-sample (TS2Vec will get 1 batch); only reject when 0 valid + if np.isnan(ori_data).any() or np.isnan(generated_data).any(): + warnings.warn( + "Context_FID: ori_data or generated_data contain NaN; FID may be unreliable." + ) model = TS2Vec( input_dims=ori_data.shape[-1], device=0, @@ -161,7 +187,22 @@ def Context_FID(ori_data: np.ndarray, generated_data: np.ndarray) -> float: output_dims=320, max_train_length=50000, ) + + fit_log = model.fit(ori_data, verbose=False) model.fit(ori_data, verbose=False) + + ori_rep = model.encode(ori_data, encoding_window="full_series") + gen_rep = model.encode(generated_data, encoding_window="full_series") + + idx = np.random.permutation(ori_data.shape[0]) + ori_rep = ori_rep[idx] + gen_rep = gen_rep[idx] + + if not np.isfinite(ori_rep).all() or not np.isfinite(gen_rep).all(): + return float("nan") + + return calculate_fid(ori_rep, gen_rep) + ori_represenation = model.encode(ori_data, encoding_window="full_series") gen_represenation = model.encode(generated_data, encoding_window="full_series") idx = np.random.permutation(ori_data.shape[0]) diff --git a/cents/eval/t2vec/t2vec.py b/cents/eval/t2vec/t2vec.py index 4c1aafa..b3f4b2e 100644 --- a/cents/eval/t2vec/t2vec.py +++ b/cents/eval/t2vec/t2vec.py @@ -11,6 +11,8 @@ Note: Please ensure compliance with the repository's license and credit the original authors when using or distributing this code. 
""" +import warnings + import numpy as np import torch import torch.nn.functional as F @@ -113,13 +115,25 @@ def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)] + if len(train_data) == 0: + warnings.warn( + "TS2Vec.fit: no valid samples after dropping all-NaN rows; returning empty loss log." + ) + return [] + train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float)) + batch_size = min(self.batch_size, len(train_dataset)) train_loader = DataLoader( train_dataset, - batch_size=min(self.batch_size, len(train_dataset)), + batch_size=batch_size, shuffle=True, drop_last=True, ) + if len(train_loader) == 0: + warnings.warn( + "TS2Vec.fit: DataLoader has 0 batches (e.g. drop_last=True with too few samples); returning empty loss log." + ) + return [] optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr) diff --git a/cents/models/context.py b/cents/models/context.py index b6fee2b..dcf7f72 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -74,9 +74,12 @@ def forward( classification_logits (Dict[str, Tensor]): Logits per variable, each of shape (batch_size, num_categories). 
""" - embeddings = [ - layer(context_vars[name]) for name, layer in self.context_embeddings.items() - ] + embeddings = [] + for name, layer in self.context_embeddings.items(): + idx = context_vars[name] + if idx.dtype in (torch.long, torch.int, torch.int32, torch.int64): + idx = idx.clamp(0, layer.num_embeddings - 1) + embeddings.append(layer(idx)) context_matrix = torch.cat(embeddings, dim=1) embedding = self.mlp(context_matrix) @@ -182,7 +185,10 @@ def forward(self, context_vars): # Process categorical variables (only those present in context_vars) for name, layer in self.context_embeddings.items(): if name in context_vars: - encodings[name] = layer(context_vars[name]) + idx = context_vars[name] + if idx.dtype in (torch.long, torch.int, torch.int32, torch.int64): + idx = idx.clamp(0, layer.num_embeddings - 1) + encodings[name] = layer(idx) # Process continuous variables (only those present in context_vars) for name, layer in self.continuous_projections.items(): @@ -351,27 +357,27 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, """ embeddings = [] - # Process categorical time series - for name in self.categorical_ts_vars.keys(): - if name in context_vars: - # Input: (batch, seq_len) with integer indices - ts_data = context_vars[name] # (batch, seq_len) - # Check for NaN/Inf in input - if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): - raise ValueError(f"NaN/Inf detected in categorical time series input '{name}'") - # Embed: (batch, seq_len) -> (batch, seq_len, embedding_dim) - embedded = self.ts_embeddings[name](ts_data) - # Transpose for CNN: (batch, embedding_dim, seq_len) - embedded = embedded.transpose(1, 2) - # Check for NaN after embedding - if torch.isnan(embedded).any() or torch.isinf(embedded).any(): - raise ValueError(f"NaN/Inf detected after embedding for '{name}'") - # Encode: (batch, embedding_dim, seq_len) -> (batch, embedding_dim) - encoded = self.ts_encoders[name](embedded) - # Check for NaN after 
encoding - if torch.isnan(encoded).any() or torch.isinf(encoded).any(): - raise ValueError(f"NaN/Inf detected after encoding for '{name}'") - embeddings.append(encoded) + # # Process categorical time series + # for name in self.categorical_ts_vars.keys(): + # if name in context_vars: + # # Input: (batch, seq_len) with integer indices + # ts_data = context_vars[name] # (batch, seq_len) + # # Check for NaN/Inf in input + # if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): + # raise ValueError(f"NaN/Inf detected in categorical time series input '{name}'") + # # Embed: (batch, seq_len) -> (batch, seq_len, embedding_dim) + # embedded = self.ts_embeddings[name](ts_data) + # # Transpose for CNN: (batch, embedding_dim, seq_len) + # embedded = embedded.transpose(1, 2) + # # Check for NaN after embedding + # if torch.isnan(embedded).any() or torch.isinf(embedded).any(): + # raise ValueError(f"NaN/Inf detected after embedding for '{name}'") + # # Encode: (batch, embedding_dim, seq_len) -> (batch, embedding_dim) + # encoded = self.ts_encoders[name](embedded) + # # Check for NaN after encoding + # if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + # raise ValueError(f"NaN/Inf detected after encoding for '{name}'") + # embeddings.append(encoded) # Process numeric time series for name in self.numeric_ts_vars: diff --git a/cents/models/context_registry.py b/cents/models/context_registry.py index f497341..3743768 100644 --- a/cents/models/context_registry.py +++ b/cents/models/context_registry.py @@ -29,7 +29,7 @@ def get_context_module_cls(key: str, subkey: str = None) -> type: Fetch the context module class for `key` (and optionally `subkey`). Raises if not found. Args: - key: The name of the context module to retrieve (e.g., "default", "dynamic"). + key: The name of the context module to retrieve (e.g., "default", "dynamic"). subkey: Optional subkey for two-part registration (e.g., "mlp", "cnn"). 
Returns: diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 79ef313..54aaa40 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -38,6 +38,7 @@ def _nan_check(t: Optional[torch.Tensor], name: str, extra: str = "") -> None: default, linear_beta_schedule, total_correlation, + cosine_beta_schedule_logsnr, ) from cents.models.registry import register_model @@ -112,7 +113,7 @@ class Diffusion_TS(GenerativeModel): Uses a Transformer backbone to predict and denoise time series over discrete diffusion timesteps. Supports EMA smoothing and configurable beta schedules. Optional reconstruction-guided sampling (Algorithms 1 & 2) - via sample_reconstruction_guided(shape, context_vars, x_a, algorithm="alg1"|"alg2") + via sample_reconstruction_guided(shape, static_context_vars, dynamic_context_vars, algorithm="alg1"|"alg2") when conditional observed data x_a is provided. Training objective (config model.training_objective): x0, epsilon, or v. @@ -203,12 +204,17 @@ def __init__(self, cfg: DictConfig): betas = linear_beta_schedule(cfg.model.n_steps) elif cfg.model.beta_schedule == "cosine": betas = cosine_beta_schedule(cfg.model.n_steps) + elif cfg.model.beta_schedule == "cosine_logsnr": + betas = cosine_beta_schedule_logsnr(cfg.model.n_steps) else: raise ValueError("Unknown beta schedule") - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + eps = 1e-5 + alphas = (1.0 - betas).double() + alphas_cumprod = torch.cumprod(alphas, dim=0).float() + alphas_cumprod = alphas_cumprod.clamp(min=eps, max=1.0 - eps) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0 - eps) + self.num_timesteps = betas.shape[0] self.sampling_timesteps = default( cfg.model.sampling_timesteps, self.num_timesteps @@ -284,7 +290,7 @@ def __init__(self, cfg: DictConfig): self.continuous_context_vars = [k for k, v in cfg.dataset.context_vars.items() if 
v[0] == "continuous"] self.categorical_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] - def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict]: + def _get_context_embedding(self, static_context_vars: dict, dynamic_context_vars: dict = None) -> Tuple[torch.Tensor, dict]: """ Get combined context embedding from static and/or dynamic context modules. @@ -297,25 +303,25 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict """ embeddings = [] all_logits = {} - for k, v in context_vars.items(): - if isinstance(v, torch.Tensor): - _nan_check(v, f"_get_context_embedding context_vars[{k}]") + # for k, v in context_vars.items(): + # if isinstance(v, torch.Tensor): + # _nan_check(v, f"_get_context_embedding context_vars[{k}]") # Process static context variables if self.static_context_module is not None: # Filter static context variables - static_vars = { - k: v for k, v in context_vars.items() - if k not in getattr(self, 'dynamic_context_vars', []) - } - if static_vars: + # static_vars = { + # k: v for k, v in context_vars.items() + # if k not in getattr(self, 'dynamic_context_vars', []) + # } + if static_context_vars: device = next(self.static_context_module.parameters()).device static_vars = { k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in static_vars.items() + for k, v in static_context_vars.items() } # Debug: print which static input has NaN (only when we see one) - for k, v in static_vars.items(): + for k, v in static_context_vars.items(): if isinstance(v, torch.Tensor) and (torch.isnan(v).any() or torch.isinf(v).any()): nan_c = torch.isnan(v).sum().item() inf_c = torch.isinf(v).sum().item() @@ -337,7 +343,7 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict # return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0) # return t # static_vars = {k: _sanitize(v) for k, v in static_vars.items()} - 
static_embedding, static_logits = self.static_context_module(static_vars) + static_embedding, static_logits = self.static_context_module(static_context_vars) _nan_check(static_embedding, "_get_context_embedding static_embedding") embeddings.append(static_embedding) all_logits.update(static_logits) @@ -345,18 +351,18 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict # Process dynamic context variables if self.dynamic_context_module is not None: # Filter dynamic context variables - dynamic_var_names = getattr(self, 'dynamic_context_vars', []) - dynamic_vars = { - k: v for k, v in context_vars.items() - if k in dynamic_var_names - } - if dynamic_vars: + # dynamic_var_names = getattr(self, 'dynamic_context_vars', []) + # dynamic_vars = { + # k: v for k, v in context_vars.items() + # if k in dynamic_var_names + # } + if dynamic_context_vars: device = next(self.dynamic_context_module.parameters()).device - dynamic_vars = { + dynamic_context_vars = { k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in dynamic_vars.items() + for k, v in dynamic_context_vars.items() } - dynamic_embedding, dynamic_logits = self.dynamic_context_module(dynamic_vars) + dynamic_embedding, dynamic_logits = self.dynamic_context_module(dynamic_context_vars) _nan_check(dynamic_embedding, "_get_context_embedding dynamic_embedding") embeddings.append(dynamic_embedding) all_logits.update(dynamic_logits) @@ -370,6 +376,8 @@ def _get_context_embedding(self, context_vars: dict) -> Tuple[torch.Tensor, dict embedding = embeddings[0] else: raise ValueError("No context variables provided") + if embedding.is_floating_point(): + embedding = embedding.float() _nan_check(embedding, "_get_context_embedding embedding (before dropout)") if self.training and self.context_embed_dropout.p > 0: embedding = self.context_embed_dropout(embedding) @@ -539,7 +547,7 @@ def q_posterior( _nan_check(plv, "q_posterior plv") return pm, pv, plv - def 
forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, dict]: + def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_vars: dict = None) -> Tuple[torch.Tensor, dict]: """ Single forward pass: add noise, predict denoised output and compute reconstruction loss. @@ -562,7 +570,7 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - embedding, cond_classification_logits = self._get_context_embedding(context_vars) + embedding, cond_classification_logits = self._get_context_embedding(static_context_vars, dynamic_context_vars) _nan_check(embedding, "forward embedding") noise = blueish_noise_like( @@ -658,15 +666,15 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: Returns: total_loss: Scalar training loss. """ - ts_batch, cond_batch = batch + ts_batch, static_context_batch, dynamic_context_batch = batch _nan_check(ts_batch, "training_step ts_batch") - rec_loss, cond_class_logits, fourier_loss = self(ts_batch, cond_batch) + rec_loss, cond_class_logits, fourier_loss = self(ts_batch, static_context_batch, dynamic_context_batch) _nan_check(rec_loss, "training_step rec_loss") _nan_check(fourier_loss, "training_step fourier_loss") cond_loss = 0.0 for var_name, outputs in cond_class_logits.items(): - labels = cond_batch[var_name] + labels = static_context_batch[var_name] if isinstance(outputs, torch.Tensor): _nan_check(outputs, f"training_step cond_logits[{var_name}]") if isinstance(labels, torch.Tensor): @@ -686,7 +694,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # cond_loss /= len(cond_class_logits) - h, _ = self._get_context_embedding(cond_batch) + h, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch) _nan_check(h, "training_step h (for tc)") tc_term = ( self.cfg.model.tc_loss_weight * total_correlation(h) @@ -980,9 +988,10 @@ def 
_reconstruction_guided_step_alg2( def sample_reconstruction_guided( self, - shape: Tuple[int, int, int], - context_vars: dict, x_a: torch.Tensor, + shape: Tuple[int, int, int], + static_context_vars: dict, + dynamic_context_vars: dict = None, algorithm: str = "alg1", ) -> torch.Tensor: """ @@ -990,7 +999,8 @@ def sample_reconstruction_guided( Args: shape: (batch_size, seq_len, time_series_dims). - context_vars: context conditioning dict. + static_context_vars: static context conditioning dict. + dynamic_context_vars: dynamic context conditioning dict. x_a: Conditional (observed) data, shape (B, cond_len, C). First cond_len time steps to reconstruct; model output is split as x̂_0 = [x̂_a, x̂_b]. algorithm: "alg1" (one gradient step per t) or "alg2" (K inner steps per t). @@ -1013,11 +1023,10 @@ def sample_reconstruction_guided( assert x_a.shape[0] == shape[0] and x_a.shape[2] == shape[2] eta = self.recon_guide_eta gamma = self.recon_guide_gamma - x = _randn_shape_correlated( - shape, self.device, torch.get_default_dtype(), self.correlated_noise + shape, self.device, torch.float32, self.correlated_noise ) - embedding, _ = self._get_context_embedding(context_vars) + embedding, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) x_a = x_a.to(self.device) for t in reversed(range(self.num_timesteps)): @@ -1037,21 +1046,22 @@ def sample_reconstruction_guided( return x @torch.no_grad() - def sample(self, shape: Tuple[int, int, int], context_vars: dict) -> torch.Tensor: + def sample(self, shape: Tuple[int, int, int], static_context_vars: dict, dynamic_context_vars: dict = None) -> torch.Tensor: """ Full reverse-pass sampling over all timesteps. Args: shape: (batch_size, seq_len, dims) - context_vars: context conditioning dict + static_context_vars: static context conditioning dict + dynamic_context_vars: dynamic context conditioning dict Returns: Generated samples tensor. 
""" x = _randn_shape_correlated( - shape, self.device, torch.get_default_dtype(), self.correlated_noise + shape, self.device, torch.float32, self.correlated_noise ) - embedding, _ = self._get_context_embedding(context_vars) + embedding, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) for t in reversed(range(self.num_timesteps)): x = self.p_sample(x, t, embedding) _nan_check(x, "sample() output") @@ -1059,15 +1069,16 @@ def sample(self, shape: Tuple[int, int, int], context_vars: dict) -> torch.Tenso @torch.no_grad() def fast_sample( - self, shape: Tuple[int, int, int], context_vars: dict + self, shape: Tuple[int, int, int], static_context_vars: dict, + dynamic_context_vars: dict = None ) -> torch.Tensor: """ Faster sampling using a reduced number of timesteps. """ x = _randn_shape_correlated( - shape, self.device, torch.get_default_dtype(), self.correlated_noise + shape, self.device, torch.float32, self.correlated_noise ) - embedding, _ = self._get_context_embedding(context_vars) + embedding, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) times = torch.linspace( -1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1 ) @@ -1105,40 +1116,46 @@ def ema_scope(self): else: yield - def generate(self, context_vars: dict) -> torch.Tensor: + def generate(self, static_context_vars: dict, dynamic_context_vars: dict = None) -> torch.Tensor: """ Public entry to generate conditioned samples in batches. Args: - context_vars: dict of context tensors for each sample. + static_context_vars: dict of context tensors for each sample. + dynamic_context_vars: dict of dynamic context tensors for each sample. Returns: Complete generated tensor of shape (N, seq_len, dims). 
""" bs = self.cfg.model.sampling_batch_size - total = len(next(iter(context_vars.values()))) + total = len(next(iter(static_context_vars.values()))) generated_samples = [] with self.ema_scope(): for start_idx in tqdm( range(0, total, bs), unit="seq", - desc="[CENTS] Generating samples", + desc="[CENTS] Generating samples", leave=True, ): end_idx = min(start_idx + bs, total) - batch_context_vars = { + batch_static_context_vars = { var_name: var_tensor[start_idx:end_idx] - for var_name, var_tensor in context_vars.items() + for var_name, var_tensor in static_context_vars.items() } + batch_dynamic_context_vars = { + var_name: var_tensor[start_idx:end_idx] + for var_name, var_tensor in dynamic_context_vars.items() + } + current_bs = end_idx - start_idx shape = (current_bs, self.seq_len, self.time_series_dims) with torch.no_grad(): if self.fast_sampling: - samples = self.fast_sample(shape, batch_context_vars) + samples = self.fast_sample(shape, batch_static_context_vars, batch_dynamic_context_vars) else: - samples = self.sample(shape, batch_context_vars) + samples = self.sample(shape, batch_static_context_vars, batch_dynamic_context_vars) generated_samples.append(samples.cpu()) diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index bbd015e..dd95cb0 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -70,6 +70,23 @@ def cosine_beta_schedule(timesteps: int, s: float = 0.004) -> torch.Tensor: betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) +import math +import torch + + +def cosine_beta_schedule_logsnr( + timesteps: int, + logsnr_min: float = -8.0, + logsnr_max: float = 15.0, + eps: float = 1e-5, +) -> torch.Tensor: + t = torch.linspace(0, 1, timesteps + 1, dtype=torch.float32) + logsnr = logsnr_max + 0.5 * (logsnr_min - logsnr_max) * (1 - torch.cos(math.pi * t)) + + alpha_bar = torch.sigmoid(logsnr).clamp(eps, 1.0 - eps) + + betas = 1 - (alpha_bar[1:] / alpha_bar[:-1]) + return 
torch.clip(betas, 0, 0.999) def exists(x): return x is not None @@ -239,14 +256,11 @@ def __init__(self, in_dim, out_dim, resid_pdrop=0.0): def forward(self, x): # x: (B, T, C) _nan_check(x, "Conv_MLP forward x (initial)") - # Print when values are extreme (even if not NaN) to debug downstream NaN - # if isinstance(x, torch.Tensor): - # x_abs_max = x.abs().max().item() - # if x_abs_max > 50.0 or torch.isnan(x).any() or torch.isinf(x).any(): - # print( - # f"[Conv_MLP] x (initial): shape={tuple(x.shape)}, min={x.min().item():.6g}, max={x.max().item():.6g}, " - # f"abs_max={x_abs_max:.6g}, has_nan={torch.isnan(x).any().item()}, has_inf={torch.isinf(x).any().item()}" - # ) + # # Print when values are extreme (even if not NaN) to debug downstream NaN + # print( + # f"[Conv_MLP] x (initial): shape={tuple(x.shape)}, min={x.min().item():.6g}, max={x.max().item():.6g}, " + # f"abs_max={x.abs().max().item():.6g}, has_nan={torch.isnan(x).any().item()}, has_inf={torch.isinf(x).any().item()}" + # ) x = x.transpose(1, 2).contiguous() # (B, C, T) contiguous _nan_check(x, "Conv_MLP forward x (transposed)") x = self.conv(x) @@ -875,6 +889,11 @@ def __init__( def forward(self, input, t, padding_masks=None, return_res=False, cond=None): # cond: (B, cond_dim) or None + # Ensure float32 so conv/linear (float32 params) never see double input + if input.is_floating_point(): + input = input.float() + if cond is not None and cond.is_floating_point(): + cond = cond.float() _nan_check(input, "forward input") t_emb = self.time_emb(t) _nan_check(t_emb, "forward t_emb") diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 2883906..37a4d54 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -117,13 +117,11 @@ def __init__( time_series_dims: int = 2, do_scale: bool = True, stats_head_type: str = "mlp", - dynamic_var_names: list[str] = None, n_layers: int = 3, ): super().__init__() self.static_cond_module = static_cond_module self.dynamic_cond_module 
= dynamic_cond_module - self.dynamic_var_names = dynamic_var_names # Determine embedding dimension from available modules if static_cond_module is not None: @@ -152,38 +150,38 @@ def __init__( n_layers=n_layers, ) - def forward(self, context_vars_dict: dict): + def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_dict: dict = None): embeddings = [] # Process static context variables if self.static_cond_module is not None: - static_vars = { - k: v for k, v in context_vars_dict.items() - if k not in getattr(self, '_dynamic_var_names', []) - } - if static_vars: + # static_vars = { + # k: v for k, v in context_vars_dict.items() + # if k not in getattr(self, '_dynamic_var_names', []) + # } + if static_context_vars_dict: device = next(self.static_cond_module.parameters()).device - static_vars = { + static_context_vars_dict = { k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in static_vars.items() + for k, v in static_context_vars_dict.items() } - static_embedding, _ = self.static_cond_module(static_vars) + static_embedding, _ = self.static_cond_module(static_context_vars_dict) embeddings.append(static_embedding) # Process dynamic context variables if self.dynamic_cond_module is not None: - dynamic_var_names = getattr(self, '_dynamic_var_names', []) - dynamic_vars = { - k: v for k, v in context_vars_dict.items() - if k in dynamic_var_names - } - if dynamic_vars: + # dynamic_var_names = getattr(self, '_dynamic_var_names', []) + # dynamic_vars = { + # k: v for k, v in context_vars_dict.items() + # if k in dynamic_var_names + # } + if dynamic_context_vars_dict: device = next(self.dynamic_cond_module.parameters()).device - dynamic_vars = { + dynamic_context_vars_dict = { k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in dynamic_vars.items() + for k, v in dynamic_context_vars_dict.items() } - dynamic_embedding, _ = self.dynamic_cond_module(dynamic_vars) + dynamic_embedding, _ = 
self.dynamic_cond_module(dynamic_context_vars_dict) if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): raise ValueError(f"NaN/Inf detected in dynamic embedding.") embeddings.append(dynamic_embedding) @@ -227,6 +225,21 @@ def __init__( self.static_context_vars = self.categorical_vars + self.continuous_vars self.context_vars = self.static_context_vars + self.dynamic_context_vars + + # Normalizer-specific conditioning: subset of static vars for grouping / stats (e.g. per-station) + self.normalizer_group_vars = getattr(self.dataset_cfg, "normalizer_group_vars", None) + self.normalizer_static_vars = ( + list(self.normalizer_group_vars) if self.normalizer_group_vars is not None + else self.static_context_vars + ) + if self.normalizer_static_vars: + bad = [v for v in self.normalizer_static_vars if v not in self.static_context_vars] + if bad: + raise ValueError(f"normalizer_group_vars {self.normalizer_group_vars} contains vars not in static_context_vars: {bad}") + # When normalizer_group_vars is set it is always static-only (dynamic vars are rejected in + # _build_training_samples), so exclude dynamic vars from normalizer conditioning entirely. 
+ self.normalizer_dynamic_vars = [] if self.normalizer_group_vars is not None else self.dynamic_context_vars + self._group_bin_edges = {} # filled in setup() when grouping by continuous vars self.time_series_cols = dataset_cfg.time_series_columns[: dataset_cfg.time_series_dims] self.time_series_dims = dataset_cfg.time_series_dims @@ -255,27 +268,35 @@ def __init__( self.register_buffer("global_mu_std", torch.tensor(1.0)) self.register_buffer("global_log_sigma_mean", torch.tensor(0.0)) - # Create static context module + # Create static context module (only normalizer_static_vars so it matches group conditioning) self.static_context_module = None - if self.static_context_vars: + if self.normalizer_static_vars: StaticContextModuleCls = get_context_module_cls(self.static_module_type) - self.static_context_vars_dict = { - k: v for k, v in self.dataset.context_var_dict.items() - if k in self.static_context_vars - } + n_bins = getattr(self.dataset_cfg, "numeric_context_bins", 5) + self.static_context_vars_dict = {} + for k in self.normalizer_static_vars: + if k in self.continuous_vars: + # Binned continuous: treat as categorical with n_bins for normalizer conditioning + self.static_context_vars_dict[k] = ["categorical", n_bins] + else: + self.static_context_vars_dict[k] = self.dataset.context_var_dict[k] self.static_context_module = StaticContextModuleCls( self.static_context_vars_dict, 256, ) - # Create dynamic context module + # Create dynamic context module only for vars the normalizer should condition on. + # Filter to vars that are both in normalizer_dynamic_vars and are actual dynamic (time_series) vars. + # When normalizer_group_vars is set, normalizer_dynamic_vars is empty so the dict is empty + # and no dynamic module is created. 
self.dynamic_context_module = None - if self.dynamic_context_vars and self.dynamic_module_type is not None: + _normalizer_dynamic_vars_dict = { + k: v for k, v in self.dataset_cfg.context_vars.items() + if k in self.normalizer_dynamic_vars and k in self.dynamic_context_vars + } + if _normalizer_dynamic_vars_dict and self.dynamic_module_type is not None: DynamicContextModuleCls = get_context_module_cls("dynamic", self.dynamic_module_type) - dynamic_context_vars_dict = { - k: v for k, v in self.dataset_cfg.context_vars.items() - if k in self.dynamic_context_vars - } + dynamic_context_vars_dict = _normalizer_dynamic_vars_dict dynamic_seq_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len self.dynamic_context_module = DynamicContextModuleCls( dynamic_context_vars_dict, @@ -290,7 +311,6 @@ def __init__( time_series_dims=self.time_series_dims, do_scale=self.do_scale, stats_head_type=self.stats_head_type, - dynamic_var_names=self.dynamic_context_vars, n_layers=context_cfg.normalizer.n_layers, ) @@ -320,7 +340,8 @@ def setup(self, stage: Optional[str] = None): # Compute per-sample statistics # Note: Using robust quantile scaling for targets to avoid outlier instability mode = getattr(self.dataset_cfg, "normalizer_stats_mode", "sample") - self.sample_stats = self._build_training_samples(mode, use_quantile_scale=True) + group_vars = getattr(self.dataset_cfg, "normalizer_group_vars", None) + self.sample_stats = self._build_training_samples(mode, use_quantile_scale=True, group_vars=group_vars) if self.use_global_stats_preprocessing: # --- COMPUTE GLOBAL TARGET STATS FOR SCALING --- @@ -391,7 +412,7 @@ def _raw_mu_to_real(self, pred_mu_raw: torch.Tensor) -> torch.Tensor: # Direct path: network predicts asinh(mu); invert with sinh and clamp for stability return torch.sinh(torch.clamp(pred_mu_raw, -self.max_asinh_mu, self.max_asinh_mu)) - def forward(self, cat_vars_dict: dict): + def forward(self, static_context_vars_dict: dict = None, 
dynamic_context_vars_dict: dict = None): """ Predict normalization parameters. Applies UNSCALING logic to convert network outputs back to real-world range. @@ -399,7 +420,7 @@ def forward(self, cat_vars_dict: dict): Returns: Tuple of (pred_mu_real, pred_sigma_real, pred_z_min, pred_z_max, pred_log_sigma_raw). """ - pred_mu_raw, pred_sigma, pred_zmin, pred_zmax, pred_log_sigma_raw = self.normalizer_model(cat_vars_dict) + pred_mu_raw, pred_sigma, pred_zmin, pred_zmax, pred_log_sigma_raw = self.normalizer_model(static_context_vars_dict, dynamic_context_vars_dict) pred_mu_real = self._raw_mu_to_real(pred_mu_raw) @@ -427,12 +448,8 @@ def _compute_loss_mse(self, pred_mu_raw, pred_log_sigma_raw, mu_t_scaled, target return loss_mu, loss_sigma def training_step(self, batch, batch_idx: int): - context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - - # 1. Get RAW network outputs (scaled space) from internal model - # We call self.normalizer_model directly to avoid the unscaling logic in self.forward - # This gives us values that match the standardized targets - pred_mu_raw, _, pred_z_min, pred_z_max, pred_log_sigma_raw = self.normalizer_model(context_vars_dict) + static_context_vars_dict, dynamic_context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch + pred_mu_raw, _, pred_z_min, pred_z_max, pred_log_sigma_raw = self.normalizer_model(static_context_vars_dict, dynamic_context_vars_dict) # 2. 
Targets: scaled (global) or asinh(mu) + log(sigma) (direct) if self.use_global_stats_preprocessing: @@ -569,15 +586,15 @@ def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int): - context_vars_dict, dynamic_ctx_dict, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] + static_context_vars_dict, dynamic_context_vars_dict, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] for var_name in self.dynamic_context_vars: - if var_name in dynamic_ctx_dict: - ts_data = np.array(dynamic_ctx_dict[var_name]) + if var_name in dynamic_context_vars_dict: + ts_data = np.array(dynamic_context_vars_dict[var_name]) var_info = self.dataset_cfg.context_vars.get(var_name, None) if var_info and var_info[1] is not None: - context_vars_dict[var_name] = torch.from_numpy(ts_data).long() + dynamic_context_vars_dict[var_name] = torch.from_numpy(ts_data).long() else: - context_vars_dict[var_name] = torch.from_numpy(ts_data).float() + dynamic_context_vars_dict[var_name] = torch.from_numpy(ts_data).float() mu_t = torch.from_numpy(mu_arr).float() sigma_t = torch.from_numpy(sigma_arr).float() @@ -588,7 +605,7 @@ def __getitem__(self, idx: int): # Return dummy tensors so DataLoader collate does not see None zmin_t = torch.zeros_like(mu_t) zmax_t = torch.zeros_like(mu_t) - return context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t + return static_context_vars_dict, dynamic_context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t return _TrainSet(self.sample_stats, self.dynamic_context_vars, self.do_scale, self.dataset_cfg) @@ -599,30 +616,43 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() - continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + continuous_vars = set(self.continuous_vars) categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) - + group_edges = getattr(self, "_group_bin_edges", {}) + with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), 
desc="Normalizing"): - ctx = {} + static_context_vars_dict = {} + dynamic_context_vars_dict = {} for v in self.context_vars: - if v in continuous_vars: - ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + if v in self.normalizer_static_vars: + if v in continuous_vars and v in group_edges: + edges = group_edges[v] + bin_idx = np.digitize(np.asarray([float(row[v])]), edges[1:-1], right=False) + bin_idx = np.clip(bin_idx, 0, len(edges) - 2).item() + static_context_vars_dict[v] = torch.tensor(bin_idx, dtype=torch.long).unsqueeze(0) + elif v in continuous_vars: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + else: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + elif v in self.static_context_vars: + continue + elif v in self.normalizer_dynamic_vars: + dynamic_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) elif v in self.dynamic_context_vars: - dtype = torch.long if v in categorical_ts else torch.float32 - ctx[v] = torch.tensor(row[v], dtype=dtype).unsqueeze(0) + continue # dynamic var excluded from normalizer conditioning else: - ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - + raise ValueError(f"Variable {v} not found in context_vars") + # self(ctx) calls forward, which automatically UNSCALES predictions - pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(ctx) + pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(static_context_vars_dict, dynamic_context_vars_dict) mu, sigma = pred_mu[0].cpu().numpy(), pred_sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): if col in categorical_ts: df_out.at[i, col] = np.asarray(row[col]).astype(np.int32) continue - + arr = np.asarray(row[col], dtype=np.float32) sigma_floor = max(self.min_sigma, 0.01 * np.exp(self.global_log_sigma_mean.cpu().item())) @@ -645,23 +675,45 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: df_out = df.copy() self.eval() - 
continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] - categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) + # continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + continuous_vars = set(self.continuous_vars) + categorical_ts = getattr(self.dataset, "categorical_time_series", {}) + group_edges = getattr(self, "_group_bin_edges", {}) with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): - ctx = {} + static_context_vars_dict = {} + dynamic_context_vars_dict = {} for v in self.context_vars: - if v in continuous_vars: - ctx[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) - elif v in self.dynamic_context_vars: + if v in self.normalizer_static_vars: + # Normalizer only conditions on normalizer_static_vars (e.g. station) + if v in continuous_vars and v in group_edges: + edges = group_edges[v] + bin_idx = np.digitize(np.asarray([float(row[v])]), edges[1:-1], right=False) + bin_idx = np.clip(bin_idx, 0, len(edges) - 2).item() + static_context_vars_dict[v] = torch.tensor(bin_idx, dtype=torch.long).unsqueeze(0) + elif v in continuous_vars: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + else: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + elif v in self.static_context_vars: + continue # not used by normalizer + elif v in self.normalizer_dynamic_vars: + # Dynamic var included in normalizer conditioning + val = row[v] + dtype_np = np.int64 if v in categorical_ts else np.float32 + arr = np.asarray(val, dtype=dtype_np) + if arr.ndim == 0: + arr = arr[np.newaxis, np.newaxis] + elif arr.ndim == 1: + arr = arr[np.newaxis, :] dtype = torch.long if v in categorical_ts else torch.float32 - ctx[v] = torch.tensor(row[v], dtype=dtype).unsqueeze(0) - else: - ctx[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + dynamic_context_vars_dict[v] = torch.tensor(arr, 
dtype=dtype) + elif v in self.dynamic_context_vars: + continue # dynamic var excluded from normalizer conditioning # self(ctx) calls forward, which automatically UNSCALES predictions - pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(ctx) + pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(static_context_vars_dict, dynamic_context_vars_dict) mu, sigma = pred_mu[0].cpu().numpy(), pred_sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): @@ -694,12 +746,13 @@ def _build_training_samples( df = self.dataset.data.copy() - continuous_vars = set(getattr(self.dataset_cfg, "continuous_context_vars", None) or []) + continuous_vars = set(self.continuous_vars) dynamic_vars = set(self.dynamic_context_vars) static_vars = [v for v in self.static_context_vars] if group_vars is None: - group_vars = [v for v in static_vars if (v not in continuous_vars and v not in dynamic_vars)] + # Default: group by all static categorical (continuous in normalizer_static_vars get binned in group mode) + group_vars = [v for v in self.normalizer_static_vars if (v not in continuous_vars and v not in dynamic_vars)] bad = [v for v in group_vars if v in dynamic_vars] if bad: @@ -736,16 +789,16 @@ def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optio if mode == "sample": for _, row in df.iterrows(): context_vars_dict = {} - for v in static_vars: - if v not in row: continue + for v in self.normalizer_static_vars: + if v not in row: + continue if v in continuous_vars: - # Ensure we store float for continuous, applying simple normalization if needed in dataset context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32) else: context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long) dynamic_ctx_dict = {} - for v in self.dynamic_context_vars: + for v in self.normalizer_dynamic_vars: if v not in row: continue ts_data = row[v] if isinstance(ts_data, (np.ndarray, list)): @@ -760,13 +813,19 @@ def _row_stats(row) -> tuple[np.ndarray, np.ndarray, 
Optional[np.ndarray], Optio return samples # mode == "group" - # Pre-process continuous vars for grouping by binning + # Pre-process continuous vars for grouping by binning; store edges for inverse_transform + self._group_bin_edges = {} for v in group_vars: if v in continuous_vars: - n_bins = getattr(self, "numeric_context_bins", 5) - df[v] = pd.cut(df[v], bins=n_bins, labels=False, include_lowest=True) - - grouped = df.groupby(group_vars, dropna=False) + n_bins = getattr(self.dataset_cfg, "numeric_context_bins", 5) + binned, edges = pd.cut( + df[v], bins=n_bins, labels=False, include_lowest=True, + duplicates="drop", retbins=True + ) + self._group_bin_edges[v] = np.asarray(edges, dtype=np.float64) + df[v] = binned + print(group_vars, "group_vars") + grouped = df.groupby(list(group_vars), dropna=False) for group_key, gdf in grouped: if len(group_vars) == 1: diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 67ec2a7..9ab0951 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,4 +1,5 @@ import logging +import math import os from pathlib import Path import json @@ -22,8 +23,8 @@ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) -DATASET_OVERRIDES = ["max_samples=10000", "normalize=False"] -PECAN_OVERRIDES = ["time_series_dims=2", "user_group=pv_users"] +DATASET_OVERRIDES = ["normalize=False", "max_samples=10000"] +PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" @@ -150,6 +151,12 @@ def main() -> None: choices=("pecanstreet", "commercial", "airquality"), help="Dataset name (must match the one used to train the model).", ) + parser.add_argument( + "--device", + type=int, + default=0, + help="Device index to use for evaluation.", + ) parser.add_argument( "--dataset-overrides", type=str, @@ -234,6 +241,12 @@ def main() -> None: default=None, help="Epoch number to evaluate when using 
--run-path (e.g. 699). If omitted, use last.ckpt and save as metrics_last.json.", ) + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Limit evaluation to this many samples (applied as dataset max_samples override).", + ) args = parser.parse_args() use_run_path = args.run_path is not None @@ -256,15 +269,41 @@ def main() -> None: ) if context_path is not None: set_context_config_path(str(context_path)) + # Apply dataset overrides (e.g. max_samples) so eval uses the requested subset. + # normalize=False prevents the dataset from re-training or reloading a normalizer + # during init; the normalizer checkpoint is loaded explicitly below instead. + overrides = DATASET_OVERRIDES + if getattr(args, "max_samples", None) is not None: + overrides.append(f"max_samples={args.max_samples}") + if args.dataset_overrides: + overrides.extend(args.dataset_overrides) + dataset_cfg = apply_overrides(dataset_cfg, overrides) + logging.info("Applied dataset overrides: %s", overrides) dataset_name = dataset_cfg.get("name", "pecanstreet") dataset = _load_dataset(dataset_name, dataset_cfg, run_dir=str(run_path)) model_type = model_cfg.get("name", "diffusion_ts") + + # Resolve normalizer checkpoint from the run's normalizer directory so that + # z-space stats use the exact same normalizer the model was trained with. + normalizer_ckpts = sorted(normalizer_dir.glob("*.ckpt")) + if not normalizer_ckpts: + raise FileNotFoundError( + f"No normalizer checkpoint found in {normalizer_dir}. " + "Ensure the run was trained with the current code that saves normalizer/*.ckpt." 
+ ) + run_normalizer_ckpt = normalizer_ckpts[0] + if len(normalizer_ckpts) > 1: + logging.warning( + "Multiple normalizer checkpoints found in %s; using %s", + normalizer_dir, run_normalizer_ckpt.name, + ) + logging.info("Using normalizer checkpoint: %s", run_normalizer_ckpt) eval_cfg = load_yaml(args.evaluator_config) top_cfg = load_yaml(args.config) cfg = OmegaConf.create({}) cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) - cfg.device = "cuda:0" + cfg.device = f"cuda:{args.device}" cfg.model = model_cfg cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) cfg.model.use_ema_sampling = args.ema @@ -279,7 +318,7 @@ def main() -> None: logging.info("Loading dataset from run config (run_dir=%s)...", run_path) logging.info("Model checkpoint: %s", model_ckpt_path) gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) - gen.load_from_checkpoint(str(model_ckpt_path), normalizer_ckpt=None) + gen.load_from_checkpoint(str(model_ckpt_path), normalizer_ckpt=str(run_normalizer_ckpt)) args._metrics_epoch = metrics_epoch else: args._metrics_epoch = None @@ -293,6 +332,8 @@ def main() -> None: overrides = overrides + ["scale=True"] if args.dataset_overrides: overrides = overrides + list(args.dataset_overrides) + if getattr(args, "max_samples", None) is not None: + overrides = overrides + [f"max_samples={args.max_samples}"] dataset_cfg = _load_dataset_config(args.dataset, overrides) dataset = _load_dataset(args.dataset, dataset_cfg) @@ -307,7 +348,7 @@ def main() -> None: cfg = OmegaConf.create({}) cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) - cfg.device = "cuda:0" + cfg.device = f"cuda:{args.device}" cfg.model = OmegaConf.create( OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) ) @@ -394,15 +435,25 @@ def _print_metrics(m, prefix=" "): _print_metrics(normalized, prefix="") metrics["normalized_domain"] = normalized # restore for save + def 
_sanitize_for_json(obj): + """Recursively replace NaN/Inf with None so json.dump produces valid JSON.""" + if isinstance(obj, dict): + return {k: _sanitize_for_json(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_sanitize_for_json(v) for v in obj] + if isinstance(obj, float) and not math.isfinite(obj): + return None + return obj + # Results are automatically saved if save_results=True if use_run_path and getattr(args, "_metrics_epoch", None) is not None: metrics_file = cfg.save_dir / f"metrics_{args._metrics_epoch}.json" with open(metrics_file, "w") as f: - json.dump(metrics, f, indent=4) + json.dump(_sanitize_for_json(metrics), f, indent=4) print(f"\n✅ Results saved to {metrics_file}") elif args.save_dir: with open(Path(args.save_dir) / "metrics.json", "w") as f: - json.dump(metrics, f, indent=4) + json.dump(_sanitize_for_json(metrics), f, indent=4) print(f"\n✅ Results saved to {Path(args.save_dir) / "metrics.json"}") print("\n" + "=" * 60) diff --git a/scripts/train.py b/scripts/train.py index 80a1553..62310e2 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -78,9 +78,9 @@ def main(args) -> None: TC_LOSS_WEIGHT = args.tc_loss_weight run_name = args.run_name - # Create run directory under runs/ + # Create run directory under runs/{dataset}/{run_name} RUNS_DIR.mkdir(parents=True, exist_ok=True) - run_dir = RUNS_DIR / run_name + run_dir = RUNS_DIR / args.dataset / run_name run_dir.mkdir(parents=True, exist_ok=True) print(f"[Cents] Run directory: {run_dir}") @@ -189,7 +189,7 @@ def main(args) -> None: help="Path to checkpoint file (.ckpt) to resume training from", ) parser.add_argument("--run-name", type=str, required=True, - help="Name of this run. A directory runs/ will be created for checkpoints, cache, and summary.", + help="Name of this run. 
A directory runs// will be created for checkpoints, cache, and summary.", ) args = parser.parse_args() From a4756d13982a6565ee7f3046e8334b73a850f233 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 2 Mar 2026 11:12:27 -0500 Subject: [PATCH 40/50] cross-attention for dynamic context, cfg guidance, dropout regularization --- cents/config/model/diffusion_ts.yaml | 14 +- cents/config/trainer/diffusion_ts.yaml | 4 +- cents/config/trainer/normalizer.yaml | 4 +- cents/models/base.py | 17 +- cents/models/context.py | 193 ++++++------------- cents/models/diffusion_ts.py | 246 +++++++++++++------------ cents/models/model_utils.py | 51 +++-- scripts/eval_pretrained.py | 19 +- 8 files changed, 271 insertions(+), 277 deletions(-) diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 67f9bfa..2383006 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -3,12 +3,12 @@ name: diffusion_ts context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 noise_dim: 256 -cond_emb_dim: 16 +cond_emb_dim: 64 n_layer_enc: 4 n_layer_dec: 5 d_model: 128 n_steps: 1000 -sampling_timesteps: 1000 +sampling_timesteps: 200 sampling_batch_size: 4096 loss_type: l1 #l2 training_objective: v @@ -18,16 +18,16 @@ beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 eta: 0.0 -attn_pd: 0.0 -resid_pd: 0.0 +attn_pd: 0.1 +resid_pd: 0.1 kernel_size: null padding_size: null -use_ff: False +use_ff: True reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 -use_ema_sampling: False +use_ema_sampling: True k_bins: 20 # Reconstruction-guided sampling (Algorithms 1 & 2) recon_guide_eta: 0.1 # gradient scale for guidance @@ -37,7 +37,7 @@ recon_guide_K: 3 # inner steps per t for alg2 (int or list for K[t]) # Optional: dual head for x̂_a / x̂_b (set to cond_len used in recon-guided sampling) recon_cond_len: null # int or null; if set, use fc_a / fc_b 
for first vs rest of sequence # Context embedding dropout (training only) for more robust recon-guided sampling -context_embed_dropout: 0 # 0 = disabled +context_embed_dropout: 0.1 # probability of zeroing entire context embedding per sample (CFG-compatible) blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise # When true and time_series_dims > 1, noise is correlated across dimensions (same draw per timestep). correlated_noise: True \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index f401396..970321a 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -5,7 +5,7 @@ strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 -max_epochs: 1000 +max_epochs: 2500 base_lr: 1e-4 eval_after_training: False @@ -13,7 +13,7 @@ checkpoint: save_last: True # Save final model save_top_k: 0 # 0 = only periodic saves; use >0 to also save top-k by metric every_n_train_steps: null - every_n_epochs: 100 # Save a distinct checkpoint every 250 epochs (250, 500, 750, ...) + every_n_epochs: 250 # Save a distinct checkpoint every 250 epochs (250, 500, 750, ...) 
lr_scheduler_params: factor: 0.5 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index f34db7f..fa3e171 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,10 +1,10 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 1, +devices: 2, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 500 +n_epochs: 2000 batch_size: 4096 lr: 3e-4 save_cycle: 5000 diff --git a/cents/models/base.py b/cents/models/base.py index de03b79..17c8405 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -85,16 +85,25 @@ def __init__(self, cfg: DictConfig = None): seq_len=seq_len if num_ts_steps is None else num_ts_steps, ) - # Determine embedding dimension and create combine MLP if both exist + # Determine embedding dimension if self.static_context_module is not None: self.embedding_dim = self.static_context_module.embedding_dim elif self.dynamic_context_module is not None: self.embedding_dim = self.dynamic_context_module.embedding_dim else: raise ValueError("At least one of static_context_module or dynamic_context_module must be provided") - - # If both modules exist, create combine MLP - if self.static_context_module is not None and self.dynamic_context_module is not None: + + # combine_mlp is only needed when the dynamic module returns a pooled + # vector (returns_sequence=False) that must be fused with the static + # embedding before AdaLN conditioning. When returns_sequence=True the + # dynamic context is routed to cross-attention inside the backbone and + # must NOT be merged with the static embedding here. 
+ dyn_returns_seq = getattr(self.dynamic_context_module, "returns_sequence", False) \ + if self.dynamic_context_module is not None else False + + if (self.static_context_module is not None + and self.dynamic_context_module is not None + and not dyn_returns_seq): combined_dim = self.static_context_module.embedding_dim + self.dynamic_context_module.embedding_dim self.combine_mlp = nn.Sequential( nn.Linear(combined_dim, self.embedding_dim), diff --git a/cents/models/context.py b/cents/models/context.py index dcf7f72..671d73b 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -425,8 +425,15 @@ class DynamicContextModule_Transformer(BaseContextModule): """ Context module for processing dynamic (time series) context variables. Uses Transformer encoder to encode time series sequences into embeddings. + + Returns the full encoded sequence (B, T, embedding_dim) rather than a + pooled vector, so temporal structure is preserved for cross-attention + conditioning in the diffusion backbone. 
""" - + + # Signals to BaseModel that this module returns (B, T, emb_dim) not (B, emb_dim) + returns_sequence = True + def __init__( self, context_vars: dict[str, int], @@ -501,28 +508,12 @@ def __init__( for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars }) - # Pooling layer to get fixed-size embedding from sequence - # Use learnable weighted pooling (attention pooling) - self.pooling_layers = nn.ModuleDict({ - name: nn.Sequential( - nn.Linear(embedding_dim, embedding_dim), - nn.Tanh(), - nn.Linear(embedding_dim, 1, bias=False), - ) - for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars - }) - - # Mixing MLP to combine all time series embeddings - total_dim = embedding_dim * (len(self.categorical_ts_vars) + len(self.numeric_ts_vars)) - if total_dim > 0: - self.mixing_mlp = nn.Sequential( - nn.Linear(total_dim, 128), - nn.ReLU(), - nn.Linear(128, embedding_dim), - ) - else: - self.mixing_mlp = nn.Identity() - + # Output projection: sum contributions from all variables, then project + # to embedding_dim so the cross-attention key/value dim is consistent. + n_vars = len(self.categorical_ts_vars) + len(self.numeric_ts_vars) + # Per-variable weight (scalar) for the additive mixture across variables + self.var_mix = nn.Linear(n_vars * embedding_dim, embedding_dim) if n_vars > 1 else None + # Initialize weights self._initialize_weights() @@ -546,135 +537,73 @@ def _initialize_weights(self): def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Process dynamic (time series) context variables using Transformer. - + Args: context_vars: Dict mapping variable names to tensors. - For categorical TS: (batch, seq_len) with integer values - For numeric TS: (batch, seq_len) with float values - + For categorical TS: (batch, seq_len) integer values. + For numeric TS: (batch, seq_len) float values. 
+ Returns: - embedding: Combined embedding of shape (batch_size, embedding_dim) - outputs: Empty dict for compatibility + sequence: Combined sequence of shape (batch, seq_len, embedding_dim). + Temporal structure is preserved for downstream cross-attention. + outputs: Empty dict for interface compatibility. """ - embeddings = [] - + sequences = [] + # Process categorical time series for name in self.categorical_ts_vars.keys(): if name in context_vars: - # Input: (batch, seq_len) with integer indices - ts_data = context_vars[name] # (batch, seq_len) - # Check for NaN/Inf in input + ts_data = context_vars[name] # (B, T) if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): - raise ValueError(f"NaN/Inf detected in categorical time series input '{name}'") - - # Embed: (batch, seq_len) -> (batch, seq_len, embedding_dim) - embedded = self.ts_embeddings[name](ts_data) - - # Add positional encoding if available + raise ValueError(f"NaN/Inf in categorical TS input '{name}'") + embedded = self.ts_embeddings[name](ts_data) # (B, T, emb_dim) if self.pos_encodings is not None and name in self.pos_encodings: - seq_len_actual = embedded.size(1) - pos_enc = self.pos_encodings[name][:, :seq_len_actual, :] - embedded = embedded + pos_enc - - # Check for NaN after embedding - if torch.isnan(embedded).any() or torch.isinf(embedded).any(): - raise ValueError(f"NaN/Inf detected after embedding for '{name}'") - - # Encode with transformer: (batch, seq_len, embedding_dim) -> (batch, seq_len, embedding_dim) - encoded = self.ts_encoders[name](embedded) - - # Check for NaN after encoding + embedded = embedded + self.pos_encodings[name][:, :embedded.size(1)] + encoded = self.ts_encoders[name](embedded) # (B, T, emb_dim) if torch.isnan(encoded).any() or torch.isinf(encoded).any(): - raise ValueError(f"NaN/Inf detected after transformer encoding for '{name}'") - - # Pool to fixed size: (batch, seq_len, embedding_dim) -> (batch, embedding_dim) - # Use attention-based pooling - 
attention_weights = self.pooling_layers[name](encoded) # (batch, seq_len, 1) - attention_weights = torch.softmax(attention_weights, dim=1) - pooled = (encoded * attention_weights).sum(dim=1) # (batch, embedding_dim) - - # Normalize pooled embedding to prevent accumulation of large values - # Layer normalization: normalize across embedding dimension - pooled_mean = pooled.mean(dim=1, keepdim=True) # (batch, 1) - pooled_std = pooled.std(dim=1, keepdim=True) + 1e-8 # (batch, 1) - pooled = (pooled - pooled_mean) / pooled_std - - embeddings.append(pooled) - + raise ValueError(f"NaN/Inf after transformer encoding '{name}'") + sequences.append(encoded) + # Process numeric time series for name in self.numeric_ts_vars: if name in context_vars: - # Input: (batch, seq_len) with float values - ts_data = context_vars[name] # (batch, seq_len) - # Ensure numeric time series are float type (not long/int) + ts_data = context_vars[name] # (B, T) if not ts_data.is_floating_point(): ts_data = ts_data.float() - - # Check for NaN/Inf in input - if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): - raise ValueError(f"NaN/Inf detected in numeric time series input '{name}'") - - # Replace NaN/Inf with zeros to prevent propagation ts_data = torch.where(torch.isfinite(ts_data), ts_data, torch.zeros_like(ts_data)) - - # Normalize input to prevent numerical overflow - # Compute per-sample statistics to normalize each time series independently - ts_mean = ts_data.mean(dim=1, keepdim=True) # (batch, 1) - ts_std = ts_data.std(dim=1, keepdim=True) + 1e-8 # (batch, 1) - add epsilon to prevent division by zero - ts_data_normalized = (ts_data - ts_mean) / ts_std - - # Project to embedding_dim: (batch, seq_len) -> (batch, seq_len, embedding_dim) - ts_data_expanded = ts_data_normalized.unsqueeze(-1) # (batch, seq_len, 1) - embedded = self.ts_projections[name](ts_data_expanded) # (batch, seq_len, embedding_dim) - - # Add positional encoding if available + # Per-sample z-score normalisation over 
the time axis + ts_mean = ts_data.mean(dim=1, keepdim=True) + ts_std = ts_data.std(dim=1, keepdim=True) + 1e-8 + ts_data = (ts_data - ts_mean) / ts_std + embedded = self.ts_projections[name](ts_data.unsqueeze(-1)) # (B, T, emb_dim) if self.pos_encodings is not None and name in self.pos_encodings: - seq_len_actual = embedded.size(1) - pos_enc = self.pos_encodings[name][:, :seq_len_actual, :] - embedded = embedded + pos_enc - - # Check for NaN after projection + embedded = embedded + self.pos_encodings[name][:, :embedded.size(1)] if torch.isnan(embedded).any() or torch.isinf(embedded).any(): - raise ValueError(f"NaN/Inf detected after projection for '{name}'") - - # Encode with transformer: (batch, seq_len, embedding_dim) -> (batch, seq_len, embedding_dim) - encoded = self.ts_encoders[name](embedded) - - # Check for NaN after encoding + raise ValueError(f"NaN/Inf after projection for '{name}'") + encoded = self.ts_encoders[name](embedded) # (B, T, emb_dim) if torch.isnan(encoded).any() or torch.isinf(encoded).any(): - raise ValueError(f"NaN/Inf detected after transformer encoding numeric TS '{name}'") - - # Pool to fixed size: (batch, seq_len, embedding_dim) -> (batch, embedding_dim) - # Use attention-based pooling - attention_weights = self.pooling_layers[name](encoded) # (batch, seq_len, 1) - attention_weights = torch.softmax(attention_weights, dim=1) - pooled = (encoded * attention_weights).sum(dim=1) # (batch, embedding_dim) - - # Normalize pooled embedding to prevent accumulation of large values - # Layer normalization: normalize across embedding dimension - pooled_mean = pooled.mean(dim=1, keepdim=True) # (batch, 1) - pooled_std = pooled.std(dim=1, keepdim=True) + 1e-8 # (batch, 1) - pooled = (pooled - pooled_mean) / pooled_std - - embeddings.append(pooled) - - if not embeddings: - # No dynamic context variables, return zero embedding + raise ValueError(f"NaN/Inf after transformer encoding numeric TS '{name}'") + sequences.append(encoded) + + if not sequences: 
+ device = next(iter(context_vars.values())).device if context_vars else None batch_size = next(iter(context_vars.values())).size(0) if context_vars else 1 - embedding = torch.zeros(batch_size, self.embedding_dim, device=next(iter(context_vars.values())).device if context_vars else None) - return embedding, {} - - # Combine all time series embeddings - combined = torch.cat(embeddings, dim=1) # (batch, total_dim) - # Check for NaN before mixing - if torch.isnan(combined).any() or torch.isinf(combined).any(): - raise ValueError(f"NaN/Inf detected in combined embeddings before mixing MLP") - embedding = self.mixing_mlp(combined) # (batch, embedding_dim) - # Check for NaN after mixing and normalization - if torch.isnan(embedding).any() or torch.isinf(embedding).any(): - raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP and normalization") - - return embedding, {} + T = self.seq_len if self.seq_len is not None else 1 + return torch.zeros(batch_size, T, self.embedding_dim, device=device), {} + + if len(sequences) == 1: + out = sequences[0] + elif self.var_mix is not None: + # Learned combination across variables: concat along feature axis, project back + out = self.var_mix(torch.cat(sequences, dim=-1)) # (B, T, emb_dim) + else: + # Single-variable fallback (var_mix is None only when n_vars == 1) + out = sequences[0] + + if torch.isnan(out).any() or torch.isinf(out).any(): + raise ValueError("NaN/Inf in dynamic context sequence output") + + return out, {} def on_after_backward(self): unused = [n for n,p in self.named_parameters() if p.requires_grad and p.grad is None] diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 54aaa40..9be4879 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -159,9 +159,12 @@ def __init__(self, cfg: DictConfig): if not hasattr(self, 'static_context_module') and not hasattr(self, 'dynamic_context_module'): raise ValueError("At least one context module (static or 
dynamic) must be initialized") - # Context embedding dropout (training only) for robust reconstruction-guided sampling - context_embed_dropout = getattr(cfg.model, "context_embed_dropout", 0.0) - self.context_embed_dropout = nn.Dropout(p=context_embed_dropout) + # Context embedding dropout (training only): zeros the *entire* embedding for a random + # subset of samples (Bernoulli with prob p). This is CFG-compatible — the model learns + # to denoise both with and without context, enabling guidance-scale inference later. + self.context_embed_dropout_p = getattr(cfg.model, "context_embed_dropout", 0.0) + # Keep the nn.Dropout attribute so old checkpoints that saved it don't break on load. + self.context_embed_dropout = nn.Dropout(p=0.0) # Optional dual head for x̂_a / x̂_b: separate output heads for conditional vs rest of sequence self.recon_cond_len = getattr(cfg.model, "recon_cond_len", None) @@ -290,99 +293,85 @@ def __init__(self, cfg: DictConfig): self.continuous_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] self.categorical_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] - def _get_context_embedding(self, static_context_vars: dict, dynamic_context_vars: dict = None) -> Tuple[torch.Tensor, dict]: + def _get_context_embedding( + self, static_context_vars: dict, dynamic_context_vars: dict = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], dict]: """ - Get combined context embedding from static and/or dynamic context modules. - - Args: - context_vars: Dict of context tensors (static: single values, dynamic: time series) - + Get context embeddings from static and/or dynamic context modules. + Returns: - embedding: Combined embedding tensor of shape (batch_size, embedding_dim) - all_logits: Dict of classification/regression logits from both modules + static_emb: (B, embedding_dim) — fed into AdaLN via the Transformer's cond path. 
+ dyn_ctx_seq: (B, T, embedding_dim) or None — fed into cross-attention in each + DecoderBlock when the dynamic module returns_sequence=True. + all_logits: Dict of auxiliary classification/regression logits (static only). """ - embeddings = [] all_logits = {} - # for k, v in context_vars.items(): - # if isinstance(v, torch.Tensor): - # _nan_check(v, f"_get_context_embedding context_vars[{k}]") - - # Process static context variables - if self.static_context_module is not None: - # Filter static context variables - # static_vars = { - # k: v for k, v in context_vars.items() - # if k not in getattr(self, 'dynamic_context_vars', []) - # } - if static_context_vars: - device = next(self.static_context_module.parameters()).device - static_vars = { - k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in static_context_vars.items() - } - # Debug: print which static input has NaN (only when we see one) - for k, v in static_context_vars.items(): - if isinstance(v, torch.Tensor) and (torch.isnan(v).any() or torch.isinf(v).any()): - nan_c = torch.isnan(v).sum().item() - inf_c = torch.isinf(v).sum().item() - finite = v[~(torch.isnan(v) | torch.isinf(v))] - min_s = finite.min().item() if finite.numel() > 0 else float("nan") - max_s = finite.max().item() if finite.numel() > 0 else float("nan") - mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") - print( - f"[NaN/Inf] static_var '{k}': shape={tuple(v.shape)}, dtype={v.dtype}, " - f"nan_count={nan_c}, inf_count={inf_c}, finite_min={min_s:.6g}, finite_max={max_s:.6g}, finite_mean={mean_s:.6g}" - ) - # Replace NaN/Inf in static inputs so the context module does not produce NaN embeddings - # def _sanitize(t: torch.Tensor) -> torch.Tensor: - # if not isinstance(t, torch.Tensor): - # return t - # if not (torch.isnan(t).any() or torch.isinf(t).any()): - # return t - # if t.is_floating_point(): - # return torch.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0) - # return t - # 
static_vars = {k: _sanitize(v) for k, v in static_vars.items()} - static_embedding, static_logits = self.static_context_module(static_context_vars) - _nan_check(static_embedding, "_get_context_embedding static_embedding") - embeddings.append(static_embedding) - all_logits.update(static_logits) - - # Process dynamic context variables - if self.dynamic_context_module is not None: - # Filter dynamic context variables - # dynamic_var_names = getattr(self, 'dynamic_context_vars', []) - # dynamic_vars = { - # k: v for k, v in context_vars.items() - # if k in dynamic_var_names - # } - if dynamic_context_vars: - device = next(self.dynamic_context_module.parameters()).device - dynamic_context_vars = { - k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v - for k, v in dynamic_context_vars.items() - } - dynamic_embedding, dynamic_logits = self.dynamic_context_module(dynamic_context_vars) - _nan_check(dynamic_embedding, "_get_context_embedding dynamic_embedding") - embeddings.append(dynamic_embedding) - all_logits.update(dynamic_logits) - - # Combine embeddings if both exist - if len(embeddings) == 2: - combined = torch.cat(embeddings, dim=1) - _nan_check(combined, "_get_context_embedding combined") - embedding = self.combine_mlp(combined) - elif len(embeddings) == 1: - embedding = embeddings[0] - else: - raise ValueError("No context variables provided") - if embedding.is_floating_point(): - embedding = embedding.float() - _nan_check(embedding, "_get_context_embedding embedding (before dropout)") - if self.training and self.context_embed_dropout.p > 0: - embedding = self.context_embed_dropout(embedding) - _nan_check(embedding, "_get_context_embedding embedding (final)") - return embedding, all_logits + static_emb = None + dyn_ctx_seq = None + + # --- Static context (categorical + continuous) → (B, emb_dim) for AdaLN --- + if self.static_context_module is not None and static_context_vars: + device = next(self.static_context_module.parameters()).device + 
static_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in static_context_vars.items() + } + for k, v in static_vars.items(): + if isinstance(v, torch.Tensor) and (torch.isnan(v).any() or torch.isinf(v).any()): + nan_c = torch.isnan(v).sum().item() + inf_c = torch.isinf(v).sum().item() + finite = v[~(torch.isnan(v) | torch.isinf(v))] + min_s = finite.min().item() if finite.numel() > 0 else float("nan") + max_s = finite.max().item() if finite.numel() > 0 else float("nan") + mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") + print( + f"[NaN/Inf] static_var '{k}': shape={tuple(v.shape)}, dtype={v.dtype}, " + f"nan_count={nan_c}, inf_count={inf_c}, finite_min={min_s:.6g}, " + f"finite_max={max_s:.6g}, finite_mean={mean_s:.6g}" + ) + static_emb, static_logits = self.static_context_module(static_vars) + _nan_check(static_emb, "_get_context_embedding static_emb") + all_logits.update(static_logits) + + # --- Dynamic context (time series) → (B, T, emb_dim) for cross-attention --- + if self.dynamic_context_module is not None and dynamic_context_vars: + device = next(self.dynamic_context_module.parameters()).device + dyn_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in dynamic_context_vars.items() + } + dyn_out, dyn_logits = self.dynamic_context_module(dyn_vars) + _nan_check(dyn_out, "_get_context_embedding dyn_out") + all_logits.update(dyn_logits) + + if getattr(self.dynamic_context_module, "returns_sequence", False): + # (B, T, emb_dim) — routed to cross-attention, not AdaLN + dyn_ctx_seq = dyn_out.float() if dyn_out.is_floating_point() else dyn_out + else: + # Legacy pooled vector: fuse with static embedding via combine_mlp + if static_emb is not None and self.combine_mlp is not None: + combined = torch.cat([static_emb, dyn_out], dim=1) + static_emb = self.combine_mlp(combined) + elif static_emb is None: + static_emb = dyn_out + + if static_emb is 
None: + raise ValueError("No static context embedding could be produced") + + if static_emb.is_floating_point(): + static_emb = static_emb.float() + _nan_check(static_emb, "_get_context_embedding static_emb (before dropout)") + if self.training and self.context_embed_dropout_p > 0: + # Sample-wise mask: zero the entire embedding for ~p fraction of samples. + # Each sample independently gets its context dropped (not individual features), + # which teaches the model to work unconditionally — enabling CFG at inference. + mask = torch.bernoulli( + torch.full((static_emb.shape[0], 1), 1.0 - self.context_embed_dropout_p, + device=static_emb.device, dtype=static_emb.dtype) + ) + static_emb = static_emb * mask + _nan_check(static_emb, "_get_context_embedding static_emb (final)") + return static_emb, dyn_ctx_seq, all_logits def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: """ @@ -570,7 +559,7 @@ def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_va b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - embedding, cond_classification_logits = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, cond_classification_logits = self._get_context_embedding(static_context_vars, dynamic_context_vars) _nan_check(embedding, "forward embedding") noise = blueish_noise_like( @@ -582,7 +571,7 @@ def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_va + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) _nan_check(x_noisy, "forward x_noisy") - trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding) + trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx_seq) _nan_check(trend, "forward trend") _nan_check(season, "forward season") x_start_pred = self._decode_to_x0((trend + season).contiguous()) @@ -694,7 +683,7 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: # 
cond_loss /= len(cond_class_logits) - h, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch) + h, _, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch) _nan_check(h, "training_step h (for tc)") tc_term = ( self.cfg.model.tc_loss_weight * total_correlation(h) @@ -807,20 +796,22 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: # raise ValueError("No EMA keys found in checkpoint") def _predict_x0_from_xt_with_grad( - self, x_t: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor + self, x_t: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Predict x0 from x_t with gradients enabled (for reconstruction-guided sampling). Returns x_start of shape (B, L, C). Call with x_t.requires_grad_(True). """ - trend, season = self.model(x_t, t, padding_masks=None, cond=embedding) + trend, season = self.model(x_t, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx) x_start = self._decode_to_x0((trend + season).contiguous()) _nan_check(x_start, "_predict_x0_from_xt_with_grad x_start") return x_start @torch.no_grad() def model_predictions( - self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor + self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict both noise and clean sample from current x. @@ -829,7 +820,7 @@ def model_predictions( pred_noise: predicted noise tensor. x_start: predicted clean sample tensor. 
""" - trend, season = self.model(x, t, padding_masks=None, cond=embedding) + trend, season = self.model(x, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx) x_start = self._decode_to_x0((trend + season).contiguous()) pred_noise = self.predict_noise_from_start(x, t, x_start) _nan_check(x_start, "model_predictions x_start") @@ -850,7 +841,8 @@ def _replace_conditional( @torch.no_grad() def p_mean_variance( - self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor + self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute mean and variance for p(x_{t-1} | x_t). @@ -858,20 +850,21 @@ def p_mean_variance( Returns: pm, pv, plv, x_start: posterior parameters and predicted x0. """ - pred_noise, x_start = self.model_predictions(x, t, embedding) + pred_noise, x_start = self.model_predictions(x, t, embedding, dyn_ctx=dyn_ctx) pm, pv, plv = self.q_posterior(x_start, x, t) _nan_check(x_start, "p_mean_variance x_start") return pm, pv, plv, x_start @torch.no_grad() def p_sample( - self, x: torch.Tensor, t: int, embedding: torch.Tensor + self, x: torch.Tensor, t: int, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Sample x_{t-1} from x_t using posterior distribution. 
""" bt = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long) - pm, pv, plv, _ = self.p_mean_variance(x, bt, embedding) + pm, pv, plv, _ = self.p_mean_variance(x, bt, embedding, dyn_ctx=dyn_ctx) noise = ( blueish_noise_like(x, power=self.blue_noise_power, correlated=self.correlated_noise) if t > 0 @@ -890,6 +883,7 @@ def _reconstruction_guided_step_alg1( cond_len: int, eta: float, gamma: float, + dyn_ctx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ One step of Algorithm 1: predict x̂_0, compute L_1 + γ*L_2, then @@ -898,7 +892,7 @@ def _reconstruction_guided_step_alg1( bt = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long) x_t = x_t.detach().requires_grad_(True) - x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding) + x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding, dyn_ctx=dyn_ctx) _nan_check(x_start, "_reconstruction_guided_step_alg1 x_start") x_hat_a = x_start[:, :cond_len] L_1 = (x_a - x_hat_a).pow(2).mean() @@ -942,17 +936,19 @@ def _reconstruction_guided_step_alg2( eta: float, gamma: float, K: int, + dyn_ctx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ One step of Algorithm 2: K inner gradient updates on x_t, then one final sample and Replace. 
""" bt = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long) embedding_detach = embedding.detach() + dyn_ctx_detach = dyn_ctx.detach() if dyn_ctx is not None else None x_t = x_t.detach().clone() for _ in range(K): x_t = x_t.requires_grad_(True) - x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach) + x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach, dyn_ctx=dyn_ctx_detach) _nan_check(x_start, "_reconstruction_guided_step_alg2 x_start (inner)") x_hat_a = x_start[:, :cond_len] L_1 = (x_a - x_hat_a).pow(2).mean() @@ -973,7 +969,7 @@ def _reconstruction_guided_step_alg2( x_t = x_t.detach() with torch.no_grad(): - x_start_final = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach) + x_start_final = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach, dyn_ctx=dyn_ctx_detach) _nan_check(x_start_final, "_reconstruction_guided_step_alg2 x_start_final") pm_final, pv_final, plv_final = self.q_posterior(x_start_final, x_t, bt) noise_final = ( @@ -1026,7 +1022,7 @@ def sample_reconstruction_guided( x = _randn_shape_correlated( shape, self.device, torch.float32, self.correlated_noise ) - embedding, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) x_a = x_a.to(self.device) for t in reversed(range(self.num_timesteps)): @@ -1037,11 +1033,11 @@ def sample_reconstruction_guided( ) if algorithm == "alg1": x = self._reconstruction_guided_step_alg1( - x, t, embedding, x_a, cond_len, eta, gamma + x, t, embedding, x_a, cond_len, eta, gamma, dyn_ctx=dyn_ctx_seq ) else: x = self._reconstruction_guided_step_alg2( - x, t, embedding, x_a, cond_len, eta, gamma, K_t + x, t, embedding, x_a, cond_len, eta, gamma, K_t, dyn_ctx=dyn_ctx_seq ) return x @@ -1061,24 +1057,35 @@ def sample(self, shape: Tuple[int, int, int], static_context_vars: dict, dynamic x = _randn_shape_correlated( shape, 
self.device, torch.float32, self.correlated_noise ) - embedding, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) for t in reversed(range(self.num_timesteps)): - x = self.p_sample(x, t, embedding) + x = self.p_sample(x, t, embedding, dyn_ctx=dyn_ctx_seq) _nan_check(x, "sample() output") return x @torch.no_grad() def fast_sample( self, shape: Tuple[int, int, int], static_context_vars: dict, - dynamic_context_vars: dict = None + dynamic_context_vars: dict = None, cfg_scale: float = 1.0, ) -> torch.Tensor: """ - Faster sampling using a reduced number of timesteps. + DDIM sampling with optional classifier-free guidance. + + cfg_scale=1.0 → standard conditional sampling (no guidance). + cfg_scale>1.0 → CFG: runs both conditional and unconditional passes each step + and blends pred_noise = uncond + scale*(cond - uncond). + Requires the model to have been trained with context_embed_dropout > 0. 
""" x = _randn_shape_correlated( shape, self.device, torch.float32, self.correlated_noise ) - embedding, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + + use_cfg = cfg_scale > 1.0 + if use_cfg: + uncond_emb = torch.zeros_like(embedding) + uncond_dyn = torch.zeros_like(dyn_ctx_seq) if dyn_ctx_seq is not None else None + times = torch.linspace( -1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1 ) @@ -1086,7 +1093,13 @@ def fast_sample( pairs = list(zip(times[:-1], times[1:])) for time, time_next in pairs: bt = torch.full((x.shape[0],), time, device=self.device, dtype=torch.long) - pred_noise, x_start = self.model_predictions(x, bt, embedding) + if use_cfg: + pred_noise_u, _ = self.model_predictions(x, bt, uncond_emb, dyn_ctx=uncond_dyn) + pred_noise_c, _ = self.model_predictions(x, bt, embedding, dyn_ctx=dyn_ctx_seq) + pred_noise = pred_noise_u + cfg_scale * (pred_noise_c - pred_noise_u) + x_start = self.predict_start_from_noise(x, bt, pred_noise) + else: + pred_noise, x_start = self.model_predictions(x, bt, embedding, dyn_ctx=dyn_ctx_seq) if time_next < 0: x = x_start _nan_check(x, "fast_sample x (final step)") @@ -1152,8 +1165,9 @@ def generate(self, static_context_vars: dict, dynamic_context_vars: dict = None) shape = (current_bs, self.seq_len, self.time_series_dims) with torch.no_grad(): + cfg_scale = getattr(self, '_cfg_scale', 1.0) if self.fast_sampling: - samples = self.fast_sample(shape, batch_static_context_vars, batch_dynamic_context_vars) + samples = self.fast_sample(shape, batch_static_context_vars, batch_dynamic_context_vars, cfg_scale=cfg_scale) else: samples = self.sample(shape, batch_static_context_vars, batch_dynamic_context_vars) diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index dd95cb0..f8d3b41 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -680,7 
+680,7 @@ def __init__( super().__init__() self.ln1 = AdaLayerNorm(n_embd) - self.ln2 = AdaLayerNorm(n_embd) # Changed from nn.LayerNorm to AdaLayerNorm + self.ln2 = AdaLayerNorm(n_embd) self.attn1 = FullAttention( n_embd=n_embd, @@ -695,6 +695,17 @@ def __init__( attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, ) + # Cross-attention for temporally-aligned dynamic context (same seq_len as target). + # Conditioned via AdaLN on the same diffusion-timestep + static-context embedding + # as the rest of the block, so global conditioning still flows through here. + self.ln_dyn = AdaLayerNorm(n_embd) + self.attn_dyn = CrossAttention( + n_embd=n_embd, + condition_embd=condition_dim, + n_head=n_head, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) self.ln1_1 = AdaLayerNorm(n_embd) @@ -716,11 +727,14 @@ def __init__( self.proj = nn.Conv1d(n_channel, n_channel * 2, 1) self.linear = nn.Linear(n_embd, n_feat) - def forward(self, x, encoder_output, cond_emb, mask=None): + def forward(self, x, encoder_output, cond_emb, dyn_ctx=None, mask=None): a, att = self.attn1(self.ln1(x, cond_emb), mask=mask) x = x + a a, att = self.attn2(self.ln1_1(x, cond_emb), encoder_output, mask=mask) x = x + a + if dyn_ctx is not None: + a, _ = self.attn_dyn(self.ln_dyn(x, cond_emb), dyn_ctx) + x = x + a # FIX: chunk() returns views that are often non-contiguous. # Since self.proj and self.trend use Conv1d, this causes the DDP stride mismatch. @@ -772,19 +786,18 @@ def __init__( ] ) - def forward(self, x, cond_emb, enc, padding_masks=None): + def forward(self, x, cond_emb, enc, dyn_ctx=None, padding_masks=None): b, c, _ = x.shape mean = [] # Initialize accumulating tensors on the correct device - season = torch.zeros((b, c, x.shape[-1]), device=x.device) - # Note: Check if season dim is n_embd or n_feat. 
- # FourierLayer returns same dim as input x (n_embd) - + season = torch.zeros((b, c, x.shape[-1]), device=x.device) + # Note: FourierLayer returns same dim as input x (n_embd) + trend = torch.zeros((b, c, self.blocks[0].linear.out_features), device=x.device) - + for block in self.blocks: x, residual_mean, residual_trend, residual_season = block( - x, enc, cond_emb, mask=padding_masks + x, enc, cond_emb, dyn_ctx=dyn_ctx, mask=padding_masks ) season += residual_season trend += residual_trend @@ -818,10 +831,13 @@ def __init__( self.cond_dim = cond_dim if cond_dim is not None: - # Map context embedding (B, cond_dim) -> (B, n_embd) + # Map static context embedding (B, cond_dim) -> (B, n_embd) self.cond_proj = nn.Linear(cond_dim, n_embd) + # Map dynamic context sequence (B, T, cond_dim) -> (B, T, n_embd) + self.dyn_ctx_proj = nn.Linear(cond_dim, n_embd) else: self.cond_proj = None + self.dyn_ctx_proj = None self.time_emb = SinusoidalPosEmb(n_embd) @@ -887,13 +903,16 @@ def __init__( n_embd, dropout=resid_pdrop, max_len=max_len ) - def forward(self, input, t, padding_masks=None, return_res=False, cond=None): - # cond: (B, cond_dim) or None + def forward(self, input, t, padding_masks=None, return_res=False, cond=None, dyn_ctx=None): + # cond: (B, cond_dim) static context or None + # dyn_ctx: (B, T, cond_dim) dynamic context sequence or None # Ensure float32 so conv/linear (float32 params) never see double input if input.is_floating_point(): input = input.float() if cond is not None and cond.is_floating_point(): cond = cond.float() + if dyn_ctx is not None and dyn_ctx.is_floating_point(): + dyn_ctx = dyn_ctx.float() _nan_check(input, "forward input") t_emb = self.time_emb(t) _nan_check(t_emb, "forward t_emb") @@ -907,6 +926,12 @@ def forward(self, input, t, padding_masks=None, return_res=False, cond=None): total_cond_emb = t_emb _nan_check(total_cond_emb, "forward total_cond_emb") + # Project dynamic context sequence to n_embd for cross-attention + dyn_ctx_emb = None + if 
dyn_ctx is not None and self.dyn_ctx_proj is not None: + dyn_ctx_emb = self.dyn_ctx_proj(dyn_ctx) # (B, T, n_embd) + _nan_check(dyn_ctx_emb, "forward dyn_ctx_emb") + emb = self.emb(input) _nan_check(emb, "forward emb") inp_enc = self.pos_enc(emb) @@ -918,7 +943,7 @@ def forward(self, input, t, padding_masks=None, return_res=False, cond=None): inp_dec = self.pos_dec(emb) _nan_check(inp_dec, "forward inp_dec") output, mean, trend, season = self.decoder( - inp_dec, total_cond_emb, enc_cond, padding_masks=padding_masks + inp_dec, total_cond_emb, enc_cond, dyn_ctx=dyn_ctx_emb, padding_masks=padding_masks ) _nan_check(output, "forward decoder output") _nan_check(mean, "forward decoder mean") diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 9ab0951..90a4a19 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -247,6 +247,18 @@ def main() -> None: default=None, help="Limit evaluation to this many samples (applied as dataset max_samples override).", ) + parser.add_argument( + "--cfg-scale", + type=float, + default=1.0, + help=( + "Classifier-free guidance scale (default 1.0 = no guidance). " + "Values >1 blend unconditional and conditional noise predictions: " + "pred = uncond + scale*(cond - uncond). " + "Requires model trained with context_embed_dropout > 0. " + "Only applies to fast (DDIM) sampling." 
+ ), + ) args = parser.parse_args() use_run_path = args.run_path is not None @@ -393,7 +405,12 @@ def main() -> None: # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset - + + # Set CFG scale on the model instance (read by generate() at inference time) + if args.cfg_scale != 1.0: + logging.info("Classifier-free guidance scale: %.2f", args.cfg_scale) + gen.model._cfg_scale = args.cfg_scale + print("\n" + "=" * 60) print("EVALUATION RESULTS") print("=" * 60 + "\n") From fb5f64502992236850295d421bb04650db8e833f Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 2 Mar 2026 20:58:27 -0500 Subject: [PATCH 41/50] Select normalization, dropout regularization (massive improvement --- cents/config/dataset/airquality.yaml | 2 +- cents/config/dataset/pecanstreet.yaml | 3 ++- cents/config/model/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml index 2aeb931..3ce51c2 100644 --- a/cents/config/dataset/airquality.yaml +++ b/cents/config/dataset/airquality.yaml @@ -14,7 +14,7 @@ reduce_cardinality: False time_series_dims: 1 normalizer_stats_mode: group # Normalizer conditions only on these (e.g. 
per-station); diffusion still gets full context_vars -normalizer_group_vars: ["station", "year", "month"] +normalizer_group_vars: ["station"] # Targets (what becomes the merged "timeseries" dims) # NOTE: use PMcoarse instead of PM10 diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 2d16ace..3e1ae25 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -10,12 +10,13 @@ shuffle: True skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) max_samples: null # Limit dataset size (null = use all data) path: "./data/pecanstreet/csv" -time_series_columns: ["grid", "solar"] +time_series_columns: ["grid"] data_columns: ["dataid","local_15min","car1","grid","solar"] metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 normalizer_stats_mode: group +normalizer_group_vars: ["state", "city"] context_vars: diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 2383006..cae33d1 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -37,7 +37,7 @@ recon_guide_K: 3 # inner steps per t for alg2 (int or list for K[t]) # Optional: dual head for x̂_a / x̂_b (set to cond_len used in recon-guided sampling) recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs rest of sequence # Context embedding dropout (training only) for more robust recon-guided sampling -context_embed_dropout: 0.1 # probability of zeroing entire context embedding per sample (CFG-compatible) +context_embed_dropout: 0 # probability of zeroing entire context embedding per sample (CFG-compatible) blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise # When true and time_series_dims > 1, noise is correlated across dimensions (same draw per 
timestep). correlated_noise: True \ No newline at end of file diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index fa3e171..f34db7f 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,10 +1,10 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 2, +devices: 1, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 2000 +n_epochs: 500 batch_size: 4096 lr: 3e-4 save_cycle: 5000 From 7a58811c44d6a29b2593046b77f53fe4d1dcce6e Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Sun, 8 Mar 2026 17:57:35 -0400 Subject: [PATCH 42/50] Added metraq dataset --- cents/config/dataset/metraq.yaml | 52 +++++ cents/config/dataset/pecanstreet.yaml | 2 +- cents/config/model/diffusion_ts.yaml | 6 +- cents/config/trainer/diffusion_ts.yaml | 1 + cents/config/trainer/normalizer.yaml | 4 +- cents/datasets/metraq.py | 280 +++++++++++++++++++++++++ cents/datasets/timeseries_dataset.py | 13 +- cents/models/base.py | 1 + cents/models/context.py | 20 +- cents/models/diffusion_ts.py | 18 +- cents/models/model_utils.py | 38 ++-- cents/trainer.py | 1 + scripts/eval_pretrained.py | 22 +- scripts/train.py | 14 ++ 14 files changed, 427 insertions(+), 45 deletions(-) create mode 100644 cents/config/dataset/metraq.yaml create mode 100644 cents/datasets/metraq.py diff --git a/cents/config/dataset/metraq.yaml b/cents/config/dataset/metraq.yaml new file mode 100644 index 0000000..df87448 --- /dev/null +++ b/cents/config/dataset/metraq.yaml @@ -0,0 +1,52 @@ +name: metraq +geography: null +normalize: True +scale: False +use_learned_normalizer: True +threshold: 8 +seq_len: 24 +shuffle: True +skip_heavy_processing: False +max_samples: null +path: "./data/metraq" +numeric_context_bins: 1 +reduce_cardinality: False +time_series_dims: 1 +normalizer_stats_mode: group +# Normalizer conditions only on these (e.g. 
per-station); diffusion still gets full context_vars +normalizer_group_vars: ["sensor_name"] + +# Targets (what becomes the merged "timeseries" dims) +# NOTE: use PMcoarse instead of PM10 +time_series_columns: ["PM2.5"] + +# Raw CSV columns to load +# Keep wd/WSPM because we need them to engineer wind_u/wind_v +# Keep PM10 because we need it to engineer PMcoarse +data_columns: + - "entry_date" + - "magnitude_name" + - "sensor_name" + - "value" + # - "utm_x" + # - "utm_y" + +context_vars: + # static categorical + year: ["categorical", 6] + month: ["categorical", 12] + weekday: ["categorical", 7] + sensor_name: ["categorical", 24] + # utm_x: ["continuous", null] + # utm_y: ["continuous", null] + + # dynamic time-series context + # WS and WD are decomposed into wind_u/wind_v in preprocessing to handle + # the circularity of wind direction (WD=355° ≈ WD=5°, but z-score would give opposite signs). + T: ["time_series", null] + wind_u: ["time_series", null] + wind_v: ["time_series", null] + wd_valid: ["time_series", null] + RH: ["time_series", null] + AP: ["time_series", null] + R: ["time_series", null] \ No newline at end of file diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index 3e1ae25..ad30e36 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -16,7 +16,7 @@ metadata_columns: ["dataid","building_type","solar","car1","city","state","total user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 normalizer_stats_mode: group -normalizer_group_vars: ["state", "city"] +# normalizer_group_vars: ["state", "city"] context_vars: diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index cae33d1..6b46430 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -22,12 +22,12 @@ attn_pd: 0.1 resid_pd: 0.1 kernel_size: null padding_size: null -use_ff: True +use_ff: False reg_weight: null 
gradient_accumulate_every: 2 ema_decay: 0.99 ema_update_interval: 10 -use_ema_sampling: True +use_ema_sampling: False k_bins: 20 # Reconstruction-guided sampling (Algorithms 1 & 2) recon_guide_eta: 0.1 # gradient scale for guidance @@ -40,4 +40,4 @@ recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs context_embed_dropout: 0 # probability of zeroing entire context embedding per sample (CFG-compatible) blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise # When true and time_series_dims > 1, noise is correlated across dimensions (same draw per timestep). -correlated_noise: True \ No newline at end of file +correlated_noise: False \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 970321a..0f3cb48 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -3,6 +3,7 @@ accelerator: auto devices: auto strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 +gradient_clip_val: 1.0 log_every_n_steps: 1 batch_size: 512 max_epochs: 2500 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index f34db7f..d0d6b98 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,10 +1,10 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 1, +devices: 3, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 500 +n_epochs: 200 batch_size: 4096 lr: 3e-4 save_cycle: 5000 diff --git a/cents/datasets/metraq.py b/cents/datasets/metraq.py new file mode 100644 index 0000000..7cce0b0 --- /dev/null +++ b/cents/datasets/metraq.py @@ -0,0 +1,280 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import 
TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class MetraqDataset(TimeSeriesDataset): + """ + Dataset class for Metraq time series data. + + Handles loading and preprocessing including normalization and context variables. + Data: https://huggingface.co/datasets/dmariaa70/METRAQ-Air-Quality + + Attributes: + cfg (DictConfig): Hydra config for the dataset. + name (str): Dataset name. + geography (str): Geographic region selector. + normalize (bool): Whether to apply normalization. + threshold (Tuple[int, int]): Range filter for grid values. + include_generation (bool): If True, include solar series. + """ + + def __init__( + self, + cfg: Optional[DictConfig] = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, + ): + """ + Initialize and preprocess the PecanStreet dataset. + + Loads metadata and timeseries CSVs, then applies filtering, + grouping, user-subsetting, and calls the base class for + further preprocessing (normalization, merging, rarity flags). + + Args: + cfg (Optional[DictConfig]): Override Hydra config; if None, + load from `config/dataset/pecanstreet.yaml`. + overrides (Optional[List[str]]): Override Hydra config; if None, + load from `config/dataset/pecanstreet.yaml` and apply overrides. + + Raises: + FileNotFoundError: If required CSV files are missing. 
+ """ + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "pecanstreet.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.geography = cfg.geography + self.normalize = cfg.normalize + self.target_time_series_columns = cfg.time_series_columns + + self.threshold = (-1 * int(cfg.threshold), int(cfg.threshold)) + self.time_series_dims = cfg.time_series_dims + + self._load_data() + + ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] + + self.context_time_series_columns = {k:v[1] for k,v in self.cfg.context_vars.items() if v[0] == "time_series"} + self.context_series_names = list(self.context_time_series_columns.keys()) + + super().__init__( + data=self.data, + time_series_column_names=ts_cols, + context_var_column_names=list(self.cfg.context_vars.keys()), + seq_len=self.cfg.seq_len, + normalize=self.cfg.normalize, + scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, + ) + + def _load_data(self) -> None: + """ + Populates self.data DataFrames. + + Raises: + FileNotFoundError: If any required CSV file is missing. + """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + "/metraq_aq_processed.csv" + + data = pd.read_csv(path) + + if self.geography: + data = data.loc[data.sensor_name.isin(self.geography)].copy() + + self.data = data + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Convert timestamps, assemble sequences of length seq_len, and merge metadata. + + Args: + data (pd.DataFrame): Raw concatenated grid (and solar) rows. + + Returns: + pd.DataFrame: One row per sequence, with array-valued 'grid' and + optional 'solar' columns plus context and metadata fields. 
+ """ + data = data.copy() + data["timestamp"] = pd.to_datetime(data["entry_date"], utc=True) + data["year"] = data["timestamp"].dt.year + data["month"] = data["timestamp"].dt.month_name() + data["weekday"] = data["timestamp"].dt.day_name() + data["day"] = data["timestamp"].dt.day + + ctx_ts = list(self.context_series_names) + tgt_ts = list(self.target_time_series_columns) + + if "PM10" in data.columns and "PM2.5" in data.columns: + pm10 = pd.to_numeric(data["PM10"], errors="coerce") + pm25 = pd.to_numeric(data["PM2.5"], errors="coerce") + data["PMcoarse"] = (pm10 - pm25).clip(lower=0.0) + + if "PMcoarse" in data.columns: + tgt_ts = ["PMcoarse" if c == "PM10" else c for c in tgt_ts] + + # Decompose circular wind direction into Cartesian components so z-score + # normalization is meaningful. WD=355° and WD=5° are 10° apart but would + # get opposite signs after z-scoring — wind_u/wind_v avoids this. + if "WD" in data.columns and "WS" in data.columns: + wd_deg = pd.to_numeric(data["WD"], errors="coerce") + ws = pd.to_numeric(data["WS"], errors="coerce").clip(lower=0.0) + # Binary mask: 1 where WD is measured, 0 where it is missing + data["wd_valid"] = wd_deg.notna().astype(np.int8) + wd_deg = wd_deg.fillna(0.0) + ws = ws.fillna(0.0) + wd_rad = np.deg2rad(wd_deg) + data["wind_u"] = ws * np.sin(wd_rad) + data["wind_v"] = ws * np.cos(wd_rad) + # Replace WS/WD in ctx_ts with wind_u/wind_v (handles legacy configs that + # listed WS/WD; current config lists wind_u/wind_v/wd_valid directly). 
+ ctx_ts = [ + "wind_u" if c == "WS" else "wind_v" if c == "WD" else c + for c in ctx_ts + ] + + ts_cols = ctx_ts + tgt_ts + + missing = [c for c in ts_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing required time-series columns after preprocessing: {missing}") + + data = data.sort_values(["sensor_name", "timestamp"]) + + # print(data) + + group_keys = ["sensor_name", "year", "month", "day", "weekday"] + + # Continuous (scalar) context vars are constant per station — carry them through + # with "first" so they survive the groupby without being collapsed into lists. + static_continuous_cols = [ + k for k, v in self.cfg.context_vars.items() + if v[0] == "continuous" and k in data.columns + ] + agg_dict = {c: list for c in ts_cols} + agg_dict.update({c: "first" for c in static_continuous_cols}) + + grouped = ( + data.groupby(group_keys, as_index=False, sort=False) + .agg(agg_dict) + ) + + # print(grouped) + + for c in ts_cols: + grouped[c] = grouped[c].map(np.asarray) + + len_col = tgt_ts[0] if len(tgt_ts) > 0 else ts_cols[0] + grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) + + grouped = self._handle_missing_data(grouped) + + # print("POST CLEAN") + # print(grouped) + + ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] + + log1p_channels = {"R"} + binary_channels = {"wd_valid"} # already in [0, 1] — skip z-scoring + clip_bound = 5.0 + eps = 1e-8 + # Compute global mean/std per channel over all rows and timesteps + ctx_stats = {} + for c in ctx_numeric: + # stacked shape: (N, L) + X = np.stack(grouped[c].values).astype(np.float32) + + if c in binary_channels: + # Already in [0, 1] — pass through without z-scoring + grouped[c] = list(X) + continue + + if c in log1p_channels: + X = np.log1p(np.clip(X, a_min=0.0, a_max=None)) + + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 # avoid divide-by-zero; effectively makes it "center only" + ctx_stats[c] = (mu, 
sd) + + Xn = (X - mu) / (sd + eps) + Xn = np.clip(Xn, -clip_bound, clip_bound).astype(np.float32) + + grouped[c] = list(Xn) + + # (Optional) store for later inverse-transform / debugging + self.context_ts_stats_ = ctx_stats + + # arrays -> tuples (hashable) + for c in ts_cols: + grouped[c] = grouped[c].map(tuple) + + return grouped + + def _handle_missing_data(self, data): + numeric_series = [c for c in self.context_series_names if c not in self.categorical_time_series] + + mask = data[numeric_series].applymap(is_all_nan).any(axis=1) if numeric_series else pd.Series([False] * len(data)) + data = data[~mask] + + for col in numeric_series: + data[col] = data[col].apply(fill_with_row_mean) + + # categorical time series must have no NaNs + cat_cols = list(self.categorical_time_series.keys()) + if cat_cols: + mask = data[cat_cols].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + # ensure no NaNs in target series columns + for tcol in self.target_time_series_columns: + # If you replaced PM10->PMcoarse in cfg, this remains correct + if tcol in data.columns: + data = data.loc[data[tcol].apply(lambda x: not np.isnan(np.asarray(x, dtype=float)).any())] + + def row_has_low_std(row, cols, thresh=0.01): + for c in cols: + arr = np.asarray(row[c], dtype=np.float32) + if arr.std() < thresh: + return True + return False + + mask = data.apply( + lambda row: row_has_low_std(row, self.target_time_series_columns, thresh=0.01), + axis=1 + ) + + data = data[~mask] + return data + + +def is_all_nan(arr): + return pd.isna(arr).all() + +def is_any_nan(arr): + return pd.isna(arr).any() + +def fill_with_row_mean(lst): + s = pd.Series(lst, dtype=float) + m = s.mean(skipna=True) + return s.fillna(m).tolist() diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index b2bdd59..5f11dd7 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -509,12 +509,19 @@ def get_frequency_based_rarity(self) -> pd.DataFrame: 
Returns: pd.DataFrame: DataFrame with 'is_frequency_rare' column. """ - freq = self.data.groupby(self.context_vars).size().reset_index(name="count") + # Continuous vars are floats (e.g. UTM coordinates) — grouping by them would create + # one group per unique value, making rarity meaningless. Use only discrete vars. + continuous = set(getattr(self, "continuous_vars", [])) + groupby_vars = [v for v in self.context_vars if v not in continuous] + if not groupby_vars: + self.data["is_frequency_rare"] = False + return self.data + freq = self.data.groupby(groupby_vars).size().reset_index(name="count") threshold = freq["count"].quantile(0.1) freq["is_frequency_rare"] = freq["count"] < threshold self.data = self.data.merge( - freq[self.context_vars + ["is_frequency_rare"]], - on=self.context_vars, + freq[groupby_vars + ["is_frequency_rare"]], + on=groupby_vars, how="left", ) return self.data diff --git a/cents/models/base.py b/cents/models/base.py index 17c8405..10a82bb 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -62,6 +62,7 @@ def __init__(self, cfg: DictConfig = None): k: v for k, v in cfg.dataset.context_vars.items() if k in static_context_vars } + print(static_context_vars_dict) self.static_context_module = StaticContextModuleCls( static_context_vars_dict, emb_dim, diff --git a/cents/models/context.py b/cents/models/context.py index 671d73b..4dbf8ae 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -193,22 +193,10 @@ def forward(self, context_vars): # Process continuous variables (only those present in context_vars) for name, layer in self.continuous_projections.items(): if name in context_vars: - # # Reshape to (batch_size, 1) for linear layer - # # Ensure proper shape and gradient flow - continuous_val = context_vars[name] - # # Handle different input shapes - # if continuous_val.dim() == 0: - # # Scalar: add batch dimension - # continuous_val = continuous_val.unsqueeze(0) - # elif continuous_val.dim() == 1: - # # 1D tensor: 
add feature dimension - # continuous_val = continuous_val.unsqueeze(-1) - # # Ensure float type while preserving gradients - # if not continuous_val.is_floating_point(): - # continuous_val = continuous_val.float() - - # if continuous_val.dim() == 1: - # continuous_val = continuous_val.unsqueeze(-1) + continuous_val = context_vars[name].float() + # DataLoader stacks 0-dim scalars into (batch,); layer expects (batch, 1) + if continuous_val.dim() == 1: + continuous_val = continuous_val.unsqueeze(-1) encodings[name] = layer(continuous_val) embeddings = [] diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 9be4879..b2955c3 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -2,6 +2,8 @@ import math from typing import Any, Optional, Tuple +from pyparsing import alphas + import pytorch_lightning as pl import torch import torch.nn as nn @@ -194,6 +196,7 @@ def __init__(self, cfg: DictConfig): n_embd=cfg.model.d_model, conv_params=[cfg.model.kernel_size, cfg.model.padding_size], cond_dim=self.embedding_dim, + has_dynamic_ctx=self.dynamic_context_module is not None, ) self.blue_noise_power = cfg.model.blue_noise_power @@ -213,10 +216,14 @@ def __init__(self, cfg: DictConfig): raise ValueError("Unknown beta schedule") eps = 1e-5 - alphas = (1.0 - betas).double() - alphas_cumprod = torch.cumprod(alphas, dim=0).float() - alphas_cumprod = alphas_cumprod.clamp(min=eps, max=1.0 - eps) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0 - eps) + # alphas = (1.0 - betas).double() + # alphas_cumprod = torch.cumprod(alphas, dim=0).float() + # alphas_cumprod = alphas_cumprod.clamp(min=eps, max=1.0 - eps) + # alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0 - eps) + + alphas = 1.0 - betas # float32, no double + alphas_cumprod = torch.cumprod(alphas, dim=0) # no clamp + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) # exactly 1.0 self.num_timesteps = betas.shape[0] 
self.sampling_timesteps = default( @@ -656,6 +663,9 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: total_loss: Scalar training loss. """ ts_batch, static_context_batch, dynamic_context_batch = batch + # print("BEFORE PRINT I") + # print(ts_batch, static_context_batch, dynamic_context_batch) + # print("AFTER PRINT I") _nan_check(ts_batch, "training_step ts_batch") rec_loss, cond_class_logits, fourier_loss = self(ts_batch, static_context_batch, dynamic_context_batch) _nan_check(rec_loss, "training_step rec_loss") diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index f8d3b41..7ed194f 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -676,6 +676,7 @@ def __init__( mlp_hidden_times=4, activate="GELU", condition_dim=1024, + has_dynamic_ctx=False, ): super().__init__() @@ -695,17 +696,20 @@ def __init__( attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, ) - # Cross-attention for temporally-aligned dynamic context (same seq_len as target). - # Conditioned via AdaLN on the same diffusion-timestep + static-context embedding - # as the rest of the block, so global conditioning still flows through here. - self.ln_dyn = AdaLayerNorm(n_embd) - self.attn_dyn = CrossAttention( - n_embd=n_embd, - condition_embd=condition_dim, - n_head=n_head, - attn_pdrop=attn_pdrop, - resid_pdrop=resid_pdrop, - ) + # Cross-attention for temporally-aligned dynamic context — only created when dynamic + # context is actually present (has_dynamic_ctx=True). Saves ~495K params when absent. 
+ if has_dynamic_ctx: + self.ln_dyn = AdaLayerNorm(n_embd) + self.attn_dyn = CrossAttention( + n_embd=n_embd, + condition_embd=condition_dim, + n_head=n_head, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) + else: + self.ln_dyn = None + self.attn_dyn = None self.ln1_1 = AdaLayerNorm(n_embd) @@ -732,7 +736,7 @@ def forward(self, x, encoder_output, cond_emb, dyn_ctx=None, mask=None): x = x + a a, att = self.attn2(self.ln1_1(x, cond_emb), encoder_output, mask=mask) x = x + a - if dyn_ctx is not None: + if dyn_ctx is not None and self.attn_dyn is not None: a, _ = self.attn_dyn(self.ln_dyn(x, cond_emb), dyn_ctx) x = x + a @@ -765,6 +769,7 @@ def __init__( mlp_hidden_times=4, block_activate="GELU", condition_dim=512, + has_dynamic_ctx=False, ): super().__init__() self.d_model = n_embd @@ -781,6 +786,7 @@ def __init__( mlp_hidden_times=mlp_hidden_times, activate=block_activate, condition_dim=condition_dim, + has_dynamic_ctx=has_dynamic_ctx, ) for _ in range(n_layer) ] @@ -823,18 +829,19 @@ def __init__( max_len=2048, conv_params=None, cond_dim=None, + has_dynamic_ctx=False, **kwargs ): super().__init__() self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop) self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop) - + self.cond_dim = cond_dim if cond_dim is not None: # Map static context embedding (B, cond_dim) -> (B, n_embd) self.cond_proj = nn.Linear(cond_dim, n_embd) - # Map dynamic context sequence (B, T, cond_dim) -> (B, T, n_embd) - self.dyn_ctx_proj = nn.Linear(cond_dim, n_embd) + # Map dynamic context sequence (B, T, cond_dim) -> (B, T, n_embd); only when needed + self.dyn_ctx_proj = nn.Linear(cond_dim, n_embd) if has_dynamic_ctx else None else: self.cond_proj = None self.dyn_ctx_proj = None @@ -898,6 +905,7 @@ def __init__( mlp_hidden_times, block_activate, condition_dim=n_embd, + has_dynamic_ctx=has_dynamic_ctx, ) self.pos_dec = LearnablePositionalEncoding( n_embd, dropout=resid_pdrop, max_len=max_len diff --git a/cents/trainer.py 
b/cents/trainer.py index c5bfcd6..445e319 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -275,6 +275,7 @@ def _instantiate_trainer(self) -> pl.Trainer: precision=tc.precision, log_every_n_steps=tc.get("log_every_n_steps", 1), accumulate_grad_batches=tc.get("gradient_accumulate_every", 1), + gradient_clip_val=tc.get("gradient_clip_val", None), callbacks=callbacks, logger=logger, default_root_dir=self.cfg.run_dir, diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 90a4a19..877496e 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,9 +1,11 @@ import logging import math import os +import random from pathlib import Path import json +import numpy as np import torch import torch.nn.functional as F @@ -15,6 +17,7 @@ from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset +from cents.datasets.metraq import MetraqDataset from cents.eval.eval import Evaluator from cents.utils.config_loader import load_yaml, apply_overrides from cents.utils.utils import set_context_config_path @@ -54,6 +57,8 @@ def _load_dataset(name: str, dataset_cfg: OmegaConf, run_dir: str = None): return CommercialDataset(**kwargs) if name == "airquality": return AirQualityDataset(**kwargs) + if name == "metraq": + return MetraqDataset(**kwargs) raise ValueError(f"Dataset {name} not supported. 
Use: pecanstreet, commercial, airquality.") @@ -148,7 +153,7 @@ def main() -> None: "--dataset", type=str, default="pecanstreet", - choices=("pecanstreet", "commercial", "airquality"), + choices=("pecanstreet", "commercial", "airquality", "metraq"), help="Dataset name (must match the one used to train the model).", ) parser.add_argument( @@ -247,6 +252,12 @@ def main() -> None: default=None, help="Limit evaluation to this many samples (applied as dataset max_samples override).", ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed for reproducible sampling (sets Python, NumPy, and PyTorch seeds).", + ) parser.add_argument( "--cfg-scale", type=float, @@ -406,6 +417,15 @@ def main() -> None: # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset + # Set random seed for reproducible sampling + if args.seed is not None: + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + logging.info("Random seed set to %d", args.seed) + # Set CFG scale on the model instance (read by generate() at inference time) if args.cfg_scale != 1.0: logging.info("Classifier-free guidance scale: %.2f", args.cfg_scale) diff --git a/scripts/train.py b/scripts/train.py index 62310e2..ed79245 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,6 +5,8 @@ from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset +from cents.datasets.metraq import MetraqDataset + from cents.trainer import Trainer from cents.utils.utils import set_context_config_path, set_context_overrides, get_context_config from cents.utils.config_loader import load_yaml, apply_overrides @@ -118,6 +120,13 @@ def main(args) -> None: force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) + + elif 
args.dataset == "metraq": + dataset = MetraqDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) else: raise ValueError(f"Dataset {args.dataset} not supported") @@ -143,6 +152,8 @@ def main(args) -> None: f"model.tc_loss_weight={TC_LOSS_WEIGHT}", f"wandb.name=training_dai_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_L{CR_LOSS_WEIGHT}_TC_{TC_LOSS_WEIGHT}_dim2", ] + if args.model_overrides: + trainer_overrides.extend(args.model_overrides) trainer = Trainer( model_type=MODEL_NAME, @@ -188,6 +199,9 @@ def main(args) -> None: parser.add_argument("--resume-from-checkpoint", type=str, default=None, help="Path to checkpoint file (.ckpt) to resume training from", ) + parser.add_argument("--model-overrides", type=str, nargs="*", default=[], + help="Override model config values (e.g., 'model.cond_emb_dim=16' 'model.attn_pd=0.0')", + ) parser.add_argument("--run-name", type=str, required=True, help="Name of this run. A directory runs// will be created for checkpoints, cache, and summary.", ) From da10776ce96a2445a4ab4b0339a46ffbf12bd8b2 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 11 Mar 2026 10:27:01 -0400 Subject: [PATCH 43/50] metraq dataset implementation --- cents/datasets/metraq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cents/datasets/metraq.py b/cents/datasets/metraq.py index 7cce0b0..9221807 100644 --- a/cents/datasets/metraq.py +++ b/cents/datasets/metraq.py @@ -277,4 +277,4 @@ def is_any_nan(arr): def fill_with_row_mean(lst): s = pd.Series(lst, dtype=float) m = s.mean(skipna=True) - return s.fillna(m).tolist() + return s.fillna(m).tolist() \ No newline at end of file From eeff92933419d20e7bd8feffdec7e8af07b2d66f Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 11 Mar 2026 16:12:44 -0400 Subject: [PATCH 44/50] ema + gradient clipping --- cents/config/model/diffusion_ts.yaml | 16 +++--- cents/config/trainer/diffusion_ts.yaml | 2 +- 
cents/config/trainer/normalizer.yaml | 2 +- cents/models/diffusion_ts.py | 73 ++++++++++++-------------- scripts/eval_pretrained.py | 11 ++++ scripts/train.py | 12 +++++ 6 files changed, 67 insertions(+), 49 deletions(-) diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 6b46430..ce6949b 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -3,7 +3,7 @@ name: diffusion_ts context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 noise_dim: 256 -cond_emb_dim: 64 +cond_emb_dim: 16 n_layer_enc: 4 n_layer_dec: 5 d_model: 128 @@ -18,16 +18,16 @@ beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 eta: 0.0 -attn_pd: 0.1 -resid_pd: 0.1 +attn_pd: 0.0 +resid_pd: 0.0 kernel_size: null padding_size: null -use_ff: False +use_ff: True reg_weight: null gradient_accumulate_every: 2 -ema_decay: 0.99 -ema_update_interval: 10 -use_ema_sampling: False +ema_decay: 0.9999 +ema_update_interval: 1 +use_ema_sampling: True k_bins: 20 # Reconstruction-guided sampling (Algorithms 1 & 2) recon_guide_eta: 0.1 # gradient scale for guidance @@ -40,4 +40,4 @@ recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs context_embed_dropout: 0 # probability of zeroing entire context embedding per sample (CFG-compatible) blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise # When true and time_series_dims > 1, noise is correlated across dimensions (same draw per timestep). 
-correlated_noise: False \ No newline at end of file +correlated_noise: True \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 0f3cb48..c105cc3 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -3,12 +3,12 @@ accelerator: auto devices: auto strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 -gradient_clip_val: 1.0 log_every_n_steps: 1 batch_size: 512 max_epochs: 2500 base_lr: 1e-4 eval_after_training: False +gradient_clip_val: 1.0 checkpoint: save_last: True # Save final model diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index d0d6b98..c39525d 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -4,7 +4,7 @@ devices: 3, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 200 +n_epochs: 2000 batch_size: 4096 lr: 3e-4 save_cycle: 5000 diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index b2955c3..5b20448 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -732,36 +732,24 @@ def configure_optimizers(self) -> dict: self.parameters(), lr=self.cfg.trainer.base_lr, betas=(0.9, 0.96) ) scheduler = ReduceLROnPlateau(optimizer, **self.cfg.trainer.lr_scheduler_params) + print(self.cfg.trainer.get("gradient_clip_val", 0.0), "gradient clip val") return { "optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss", + "gradient_clip_val": self.cfg.trainer.get("gradient_clip_val", 0.0), } def on_train_start(self) -> None: """ Initialize EMA helper at start of training. """ - self._ema = EMA( - self.model, - beta=self.cfg.model.ema_decay, - update_every=self.cfg.model.ema_update_interval, - ) - - # def on_after_backward(self) -> None: - # """ - # Check gradients after backward pass but before optimizer step. 
- # This is the right place to inspect gradients before they're zeroed. - # """ - # # Get current batch index from trainer - # for name, p in self.named_parameters(): - # if p.grad is None: - # continue - # if p.grad.stride() != p.stride(): - # print("stride mismatch:", name, - # "param", tuple(p.shape), p.stride(), - # "grad", tuple(p.grad.shape), p.grad.stride()) - # break + if self._ema is None: + self._ema = EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + ) def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: @@ -804,6 +792,23 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: # raise ValueError("No EMA model weights found in checkpoint") # else: # raise ValueError("No EMA keys found in checkpoint") + def on_load_checkpoint(self, checkpoint: dict) -> None: + if 'ema_state_dict' in checkpoint: + if self._ema is None: + self._ema = EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + ) + self._ema.ema_model.load_state_dict(checkpoint['ema_state_dict']) + print(f"[EMA] Restored EMA weights from checkpoint") + else: + print(f"[EMA] No EMA weights in checkpoint, initializing fresh") + + def on_save_checkpoint(self, checkpoint: dict) -> None: + if self._ema is not None: + checkpoint['ema_state_dict'] = self._ema.ema_model.state_dict() + def _predict_x0_from_xt_with_grad( self, x_t: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, @@ -1214,40 +1219,30 @@ def stratified_timesteps(self, batch_size: int, num_timesteps: int, k_bins: int, class EMA(nn.Module): - """ - Exponential Moving Average (EMA) helper for model parameters. 
- """ - def __init__(self, model: nn.Module, beta: float = 0.9999, update_every: int = 10): + def __init__(self, model, beta, update_every): super().__init__() self.beta = beta self.update_every = update_every self.step = 0 - - # CRITICAL FIX 1: self.ema_model is the ONLY deepcopy. - # It holds the shadow weights. self.ema_model = copy.deepcopy(model) self.ema_model.eval() self.ema_model.requires_grad_(False) - # CRITICAL FIX 2: We keep a reference to the LIVE model (not a copy) - # so we can grab the latest trained weights during update(). - self.source_model = model + # Store as plain python attribute, not nn.Module attribute + # This prevents it being registered as a submodule and saved in state_dict + object.__setattr__(self, '_source_model', model) - # Buffer to store temporary weights for the context manager self.collected_params = [] - def update(self) -> None: - """ - Update the shadow parameters using the source model's current weights. - """ + def update(self): self.step += 1 if self.step % self.update_every != 0: return - with torch.no_grad(): - # Zip the shadow model (ema) against the live model (source) - for ema_p, src_p in zip(self.ema_model.parameters(), self.source_model.parameters()): - # ema_new = beta * ema_old + (1 - beta) * current_weight + for ema_p, src_p in zip( + self.ema_model.parameters(), + self._source_model.parameters() + ): ema_p.data.mul_(self.beta).add_(src_p.data, alpha=1.0 - self.beta) def store(self, parameters): diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 877496e..1b22d2d 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -270,6 +270,13 @@ def main() -> None: "Only applies to fast (DDIM) sampling." ), ) + parser.add_argument( + "--save-path", + type=str, + default=None, + help="Path to save evaluation results." 
+ ) + args = parser.parse_args() use_run_path = args.run_path is not None @@ -492,6 +499,10 @@ def _sanitize_for_json(obj): with open(Path(args.save_dir) / "metrics.json", "w") as f: json.dump(_sanitize_for_json(metrics), f, indent=4) print(f"\n✅ Results saved to {Path(args.save_dir) / "metrics.json"}") + elif args.save_path: + with open(args.save_path, "w") as f: + json.dump(_sanitize_for_json(metrics), f, indent=4) + print(f"\n✅ Results saved to {args.save_path}") print("\n" + "=" * 60) diff --git a/scripts/train.py b/scripts/train.py index ed79245..bf426bb 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,6 +1,10 @@ from datetime import datetime import yaml from pathlib import Path +import numpy as np +import os +import torch +import random from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset @@ -80,6 +84,12 @@ def main(args) -> None: TC_LOSS_WEIGHT = args.tc_loss_weight run_name = args.run_name + np.random.seed(args.random_seed) + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + torch.manual_seed(args.random_seed) + + # Create run directory under runs/{dataset}/{run_name} RUNS_DIR.mkdir(parents=True, exist_ok=True) run_dir = RUNS_DIR / args.dataset / run_name @@ -205,6 +215,8 @@ def main(args) -> None: parser.add_argument("--run-name", type=str, required=True, help="Name of this run. 
A directory runs// will be created for checkpoints, cache, and summary.", ) + parser.add_argument("--random-seed", type=int, default=42, + help="Random seed for reproducibility",) args = parser.parse_args() main(args) \ No newline at end of file From 051dd8e07d472ae4cea83b1294e793053b09d276 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Mon, 16 Mar 2026 19:46:49 -0400 Subject: [PATCH 45/50] Intermittent FID eval --- cents/config/model/diffusion_ts.yaml | 2 +- cents/config/trainer/diffusion_ts.yaml | 10 +- cents/config/trainer/normalizer.yaml | 2 +- cents/eval/eval_metrics.py | 8 - cents/models/diffusion_ts.py | 82 +++++--- cents/trainer.py | 258 ++++++++++++++++++++++++- scripts/train.py | 22 ++- 7 files changed, 336 insertions(+), 48 deletions(-) diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index ce6949b..edd4718 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -12,7 +12,7 @@ sampling_timesteps: 200 sampling_batch_size: 4096 loss_type: l1 #l2 training_objective: v -loss_weighting: snr +loss_weighting: min_snr min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index c105cc3..c11877e 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -5,8 +5,9 @@ strategy: ddp_find_unused_parameters_true gradient_accumulate_every: 4 log_every_n_steps: 1 batch_size: 512 -max_epochs: 2500 +max_epochs: 5000 base_lr: 1e-4 +warmup_epochs: 100 # linear warmup from 1% to 100% of base_lr over first N epochs eval_after_training: False gradient_clip_val: 1.0 @@ -23,3 +24,10 @@ lr_scheduler_params: threshold: 1.0e-1 threshold_mode: rel verbose: false + +intermediate_fid: + enabled: True + every_n_epochs: 20 # check context-FID every N epochs + n_samples: 3500 # number of samples to generate (subsample of 
dataset) + fast_timesteps: 50 # DDIM steps to use (vs 200 default) for speed + top_k: 3 # keep top-k FID checkpoints + their neighbors diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index c39525d..96b3d49 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -4,7 +4,7 @@ devices: 3, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 2000 +n_epochs: 750 batch_size: 4096 lr: 3e-4 save_cycle: 5000 diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index ccb0917..1dbe86b 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -202,14 +202,6 @@ def Context_FID(ori_data: np.ndarray, generated_data: np.ndarray) -> float: return float("nan") return calculate_fid(ori_rep, gen_rep) - - ori_represenation = model.encode(ori_data, encoding_window="full_series") - gen_represenation = model.encode(generated_data, encoding_window="full_series") - idx = np.random.permutation(ori_data.shape[0]) - ori_represenation = ori_represenation[idx] - gen_represenation = gen_represenation[idx] - results = calculate_fid(ori_represenation, gen_represenation) - return results def compute_mig( diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index 5b20448..c50a250 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -216,14 +216,9 @@ def __init__(self, cfg: DictConfig): raise ValueError("Unknown beta schedule") eps = 1e-5 - # alphas = (1.0 - betas).double() - # alphas_cumprod = torch.cumprod(alphas, dim=0).float() - # alphas_cumprod = alphas_cumprod.clamp(min=eps, max=1.0 - eps) - # alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0 - eps) - - alphas = 1.0 - betas # float32, no double - alphas_cumprod = torch.cumprod(alphas, dim=0) # no clamp - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) # exactly 1.0 + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, 
dim=0).clamp(min=eps, max=1.0 - eps) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0 - eps) self.num_timesteps = betas.shape[0] self.sampling_timesteps = default( @@ -684,14 +679,9 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: loss = self.auxiliary_loss(outputs, labels) _nan_check(loss, f"training_step cond_loss[{var_name}]") cond_loss += loss.mean() - - # # if var_name in self.continuous_context_vars: - # # print(var_name) - # # print(loss) - # # print(outputs.mean(), labels.mean()) - - - # cond_loss /= len(cond_class_logits) + # Normalize by number of context variables so the weight is dataset-independent + # if len(cond_class_logits) > 0: + # cond_loss = cond_loss / len(cond_class_logits) h, _, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch) _nan_check(h, "training_step h (for tc)") @@ -707,11 +697,16 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: ) _nan_check(total_loss, f"training_step total_loss batch_idx={batch_idx}") + # Skip this batch entirely if loss is bad — avoids corrupting weights before EMA can help + if not torch.isfinite(total_loss): + print(f"[training_step] Non-finite loss ({total_loss.item()}) at batch {batch_idx}, skipping.") + return None + self.log_dict( { "train_loss": total_loss.item(), "rec_loss": rec_loss.item(), - "cond_loss": cond_loss.item(), + "cond_loss": cond_loss.item() if isinstance(cond_loss, torch.Tensor) else float(cond_loss), "tc_loss": tc_term, "fourier_loss": fourier_loss.item(), }, @@ -732,12 +727,9 @@ def configure_optimizers(self) -> dict: self.parameters(), lr=self.cfg.trainer.base_lr, betas=(0.9, 0.96) ) scheduler = ReduceLROnPlateau(optimizer, **self.cfg.trainer.lr_scheduler_params) - print(self.cfg.trainer.get("gradient_clip_val", 0.0), "gradient clip val") return { "optimizer": optimizer, - "lr_scheduler": scheduler, - "monitor": "train_loss", - "gradient_clip_val": self.cfg.trainer.get("gradient_clip_val", 0.0), + 
"lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"}, } def on_train_start(self) -> None: @@ -745,17 +737,47 @@ def on_train_start(self) -> None: Initialize EMA helper at start of training. """ if self._ema is None: - self._ema = EMA( + object.__setattr__(self, '_ema', EMA( self.model, beta=self.cfg.model.ema_decay, update_every=self.cfg.model.ema_update_interval, - ) + )) + # EMA is not a registered submodule so PL won't move it automatically + self._ema.to(self.device) + def on_train_epoch_start(self) -> None: + """ + Apply linear LR warmup for the first warmup_epochs epochs. + After warmup, ReduceLROnPlateau takes over untouched. + """ + warmup_epochs = self.cfg.trainer.get("warmup_epochs", 0) + if warmup_epochs <= 0: + return + epoch = self.current_epoch + if epoch >= warmup_epochs: + return + target_lr = self.cfg.trainer.base_lr + warmup_lr = target_lr * max(0.01, epoch / warmup_epochs) + for opt in self.trainer.optimizers: + for pg in opt.param_groups: + pg["lr"] = warmup_lr + + def on_before_optimizer_step(self, optimizer) -> None: + """Log global gradient norm each step for observability.""" + total_norm_sq = 0.0 + for p in self.parameters(): + if p.grad is not None: + total_norm_sq += p.grad.detach().float().norm(2).item() ** 2 + grad_norm = total_norm_sq ** 0.5 + self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, prog_bar=False) def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: """ Apply EMA update after each batch end. + Skip if the step was skipped due to NaN loss (outputs is None). 
""" + if outputs is None: + return if hasattr(self, '_ema') and self._ema: self._ema.update() @@ -792,14 +814,20 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: # raise ValueError("No EMA model weights found in checkpoint") # else: # raise ValueError("No EMA keys found in checkpoint") + def load_state_dict(self, state_dict, strict=True): + # Strip legacy _ema.* keys — EMA is restored separately via on_load_checkpoint. + # Old checkpoints have these because _ema was previously a registered submodule. + filtered = {k: v for k, v in state_dict.items() if not k.startswith('_ema.')} + return super().load_state_dict(filtered, strict=strict) + def on_load_checkpoint(self, checkpoint: dict) -> None: if 'ema_state_dict' in checkpoint: if self._ema is None: - self._ema = EMA( + object.__setattr__(self, '_ema', EMA( self.model, beta=self.cfg.model.ema_decay, update_every=self.cfg.model.ema_update_interval, - ) + )) self._ema.ema_model.load_state_dict(checkpoint['ema_state_dict']) print(f"[EMA] Restored EMA weights from checkpoint") else: @@ -1197,11 +1225,11 @@ def _ensure_ema_helper(self) -> None: """ if not hasattr(self, '_ema') or self._ema is None: print("Initializing EMA helper for inference...") - self._ema = EMA( + object.__setattr__(self, '_ema', EMA( self.model, beta=self.cfg.model.ema_decay, update_every=self.cfg.model.ema_update_interval, - ) + )) def stratified_timesteps(self, batch_size: int, num_timesteps: int, k_bins: int, device=None) -> torch.Tensor: device = device or "cpu" k_bins = min(k_bins, batch_size) diff --git a/cents/trainer.py b/cents/trainer.py index 445e319..923cf99 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -254,18 +254,34 @@ def _instantiate_trainer(self) -> pl.Trainer: callbacks.append(EvalAfterTraining(self.cfg, self.dataset)) + fid_cfg = tc.get("intermediate_fid", {}) + if fid_cfg.get("enabled", False) and self.dataset is not None: + callbacks.append(IntermediateFIDCallback( + cfg=self.cfg, + 
dataset=self.dataset, + every_n_epochs=fid_cfg.get("every_n_epochs", 20), + n_samples=fid_cfg.get("n_samples", 3500), + fast_timesteps=fid_cfg.get("fast_timesteps", 50), + top_k=fid_cfg.get("top_k", 3), + )) + if getattr(self.cfg, "run_dir", None): callbacks.append(LogLossToCsv(self.cfg.run_dir)) # ---- Logger ---- logger = False if getattr(self.cfg, "wandb", None) and self.cfg.wandb.enabled: + wandb_id_file = Path(self.cfg.run_dir) / "wandb_run_id.txt" + existing_run_id = wandb_id_file.read_text().strip() if wandb_id_file.exists() else None logger = WandbLogger( project=self.cfg.wandb.project or "cents", entity=self.cfg.wandb.entity, name=self.cfg.wandb.name, save_dir=self.cfg.run_dir, + id=existing_run_id, + resume="must" if existing_run_id else None, ) + callbacks.append(_WandbRunIdSaver(self.cfg.run_dir)) return pl.Trainer( max_epochs=tc.max_epochs, @@ -290,7 +306,8 @@ def __init__(self, run_dir: str): super().__init__() self.run_dir = Path(run_dir) self._csv_path = self.run_dir / "train_losses.csv" - self._header_written = False + # Don't rewrite the header if the file already exists (resume case) + self._header_written = self._csv_path.exists() def _ensure_header(self, metric_names: List[str]) -> None: if self._header_written: @@ -318,6 +335,20 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) csv.writer(f).writerow(row) +class _WandbRunIdSaver(Callback): + """Persist the W&B run ID to disk so --resume-from-checkpoint can continue the same run.""" + + def __init__(self, run_dir: str): + super().__init__() + self._id_file = Path(run_dir) / "wandb_run_id.txt" + + def on_train_start(self, trainer, pl_module): # noqa: ARG002 + run = getattr(trainer.logger, "experiment", None) + if run is not None and not self._id_file.exists(): + self._id_file.write_text(run.id) + print(f"[Cents] Saved W&B run ID {run.id} to {self._id_file}") + + class EvalAfterTraining(Callback): """Run full evaluator at the *end* of training and log metrics to 
W&B.""" @@ -336,3 +367,228 @@ def on_train_end(self, trainer, pl_module): run = getattr(trainer.logger, "experiment", None) if run is not None: run.log(results["metrics"]) + + +class IntermediateFIDCallback(Callback): + """ + Every N epochs: generate samples, compute context-FID, save a checkpoint. + + Checkpoint retention policy: + - Top-k epochs by FID (lower = better) are kept permanently. + - The FID-check epochs immediately before and after each top-k epoch are + kept as "neighbors" (i.e. ±every_n_epochs). + - The two most recent FID-check checkpoints are always kept as a rolling + buffer so that a newly-crowned top-k epoch's before-neighbor is available. + - Everything else is deleted after each FID check. + + Results are written to {run_dir}/intermediate_fid.csv and logged to W&B. + Checkpoints live in {run_dir}/fid_checkpoints/. + """ + + def __init__( + self, + cfg, + dataset, + every_n_epochs: int = 20, + n_samples: int = 3500, + fast_timesteps: int = 50, + top_k: int = 3, + ): + super().__init__() + self.cfg = cfg + self.dataset = dataset + self.every_n_epochs = every_n_epochs + self.n_samples = n_samples + self.fast_timesteps = fast_timesteps + self.top_k = top_k + self._csv_path = Path(cfg.run_dir) / "intermediate_fid.csv" + self._fid_ckpt_dir = Path(cfg.run_dir) / "fid_checkpoints" + self._fid_ckpt_dir.mkdir(parents=True, exist_ok=True) + # (fid, epoch, ckpt_path) — only currently-kept records + self._fid_records: List[tuple] = [] + self._header_written = self._csv_path.exists() + self._reload_records() + + def _reload_records(self) -> None: + """On resume: rebuild _fid_records from the CSV, keeping only entries whose checkpoint still exists on disk.""" + if not self._csv_path.exists(): + return + import csv as _csv + with open(self._csv_path, newline="") as f: + reader = _csv.DictReader(f) + for row in reader: + try: + epoch = int(row["epoch"]) - 1 # CSV stores epoch+1 + fid = float(row["context_fid"]) + except (KeyError, ValueError): + continue + 
ckpt_path = str(self._fid_ckpt_dir / f"fid_epoch={epoch:04d}.ckpt") + if Path(ckpt_path).exists(): + self._fid_records.append((fid, epoch, ckpt_path)) + if self._fid_records: + print(f"[IntermediateFID] Resumed with {len(self._fid_records)} prior FID records " + f"(best={min(r[0] for r in self._fid_records):.4f})") + + # ------------------------------------------------------------------ + # Main hook + # ------------------------------------------------------------------ + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + epoch = trainer.current_epoch + if (epoch + 1) % self.every_n_epochs != 0: + return + + import numpy as np + import torch + from cents.eval.eval_metrics import Context_FID + + device = pl_module.device + dataset = self.dataset + + # Snapshot dataset.data, add rarity column, then restore after + orig_data = dataset.data + dataset.data = dataset.get_combined_rarity() + all_len = len(dataset.data) + n = min(self.n_samples, all_len) + rng = np.random.default_rng(42) + idx = rng.choice(all_len, size=n, replace=False) + real_data_subset = dataset.data.iloc[idx].reset_index(drop=True) + + # Build context tensors + continuous_vars = getattr(dataset, "continuous_vars", []) + static_context_vars = {} + for name in dataset.static_context_vars: + vals = real_data_subset[name].values + dtype = torch.float32 if name in continuous_vars else torch.long + static_context_vars[name] = torch.tensor(vals, dtype=dtype, device=device) + + dynamic_context_vars = {} + categorical_ts = getattr(dataset, "categorical_time_series", {}) + for name in dataset.dynamic_context_vars: + vals = real_data_subset[name].values + if len(vals) and hasattr(vals[0], "__len__") and not isinstance(vals[0], (str, bytes)): + arr = np.stack([ + np.asarray(v, dtype=np.float32 if name not in categorical_ts else np.int64) + for v in vals + ]) + else: + arr = np.asarray(vals, dtype=np.float32 if name not in categorical_ts else np.int64) + dtype = 
torch.long if name in categorical_ts else torch.float32 + dynamic_context_vars[name] = torch.tensor(arr, dtype=dtype, device=device) + + # Temporarily reduce sampling timesteps for speed, then restore + orig_sampling_timesteps = pl_module.sampling_timesteps + orig_fast_sampling = pl_module.fast_sampling + pl_module.sampling_timesteps = self.fast_timesteps + pl_module.fast_sampling = True + was_training = pl_module.training + pl_module.eval() + try: + generated_ts = pl_module.generate(static_context_vars, dynamic_context_vars).cpu().numpy() + finally: + pl_module.sampling_timesteps = orig_sampling_timesteps + pl_module.fast_sampling = orig_fast_sampling + dataset.data = orig_data + if was_training: + pl_module.train() + + if generated_ts.ndim == 2: + generated_ts = generated_ts.reshape(generated_ts.shape[0], -1, generated_ts.shape[1]) + + # Inverse-transform (mirrors evaluate_subset logic) + syn_data_subset = real_data_subset.copy() + syn_data_subset["timeseries"] = list(generated_ts) + normalizer = getattr(dataset, "_normalizer", None) + if not getattr(dataset, "normalize", True) and normalizer is not None: + def _inv(df): + split = dataset.split_timeseries(df.copy()) + split = normalizer.inverse_transform(split) + return dataset.merge_timeseries_columns(split) + real_data_inv = _inv(real_data_subset) + syn_data_inv = _inv(syn_data_subset) + else: + real_data_inv = dataset.inverse_transform(real_data_subset) + syn_data_inv = dataset.inverse_transform(syn_data_subset) + + real_data_array = np.stack(real_data_inv["timeseries"]) + syn_data_array = np.stack(syn_data_inv["timeseries"]) + + fid = Context_FID(real_data_array, syn_data_array) + print(f"[IntermediateFID] Epoch {epoch + 1}: Context-FID = {fid:.4f}") + + # Save checkpoint for this FID-check epoch, then prune + ckpt_path = str(self._fid_ckpt_dir / f"fid_epoch={epoch:04d}.ckpt") + trainer.save_checkpoint(ckpt_path) + self._fid_records.append((fid, epoch, ckpt_path)) + self._prune_fid_checkpoints() + + 
self._log_csv(epoch + 1, fid) + self._log_wandb(trainer, epoch, fid) + pl_module.log("intermediate_context_fid", fid, on_step=False, on_epoch=True, prog_bar=True) + + # ------------------------------------------------------------------ + # Checkpoint pruning + # ------------------------------------------------------------------ + + def _prune_fid_checkpoints(self) -> None: + """Delete FID checkpoints outside the keep-set.""" + if len(self._fid_records) < 2: + return + + all_epochs = [r[1] for r in self._fid_records] + epoch_set = set(all_epochs) + + # Top-k by FID (ascending — lower is better) + sorted_records = sorted(self._fid_records, key=lambda x: x[0]) + top_k_epochs = {r[1] for r in sorted_records[: self.top_k]} + + # Neighbors: ±1 FID-check interval around each top-k epoch + neighbor_epochs: set = set() + for ep in top_k_epochs: + neighbor_epochs.add(ep - self.every_n_epochs) + neighbor_epochs.add(ep + self.every_n_epochs) + neighbor_epochs &= epoch_set # only those we actually have on disk + + # Rolling buffer: always keep the two most recent FID-check epochs + recent_epochs = set(sorted(all_epochs)[-2:]) + + keep_epochs = top_k_epochs | neighbor_epochs | recent_epochs + + new_records = [] + for fid, ep, path in self._fid_records: + if ep in keep_epochs: + new_records.append((fid, ep, path)) + else: + p = Path(path) + if p.exists(): + p.unlink() + print(f"[IntermediateFID] Pruned checkpoint epoch={ep} (FID={fid:.4f})") + self._fid_records = new_records + + # ------------------------------------------------------------------ + # Logging helpers + # ------------------------------------------------------------------ + + def _log_wandb(self, trainer, epoch: int, fid: float) -> None: + run = getattr(trainer.logger, "experiment", None) + if run is None: + return + sorted_records = sorted(self._fid_records, key=lambda x: x[0]) + top_k_epochs = [r[1] for r in sorted_records[: self.top_k]] + run.log( + { + "intermediate_context_fid": fid, + "fid_best": 
sorted_records[0][0] if sorted_records else float("nan"), + "fid_top_k_epochs": str(top_k_epochs), + }, + step=trainer.global_step, + ) + + def _log_csv(self, epoch: int, fid: float) -> None: + Path(self.cfg.run_dir).mkdir(parents=True, exist_ok=True) + if not self._header_written: + with open(self._csv_path, "w", newline="") as f: + csv.writer(f).writerow(["epoch", "context_fid"]) + self._header_written = True + with open(self._csv_path, "a", newline="") as f: + csv.writer(f).writerow([epoch, float(fid)]) diff --git a/scripts/train.py b/scripts/train.py index bf426bb..9b9b1e8 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -39,13 +39,14 @@ def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: return cfg -def _write_run_summary(run_dir: Path, run_name: str, trainer: Trainer) -> None: +def _write_run_summary(run_dir: Path, run_name: str, trainer: Trainer, random_seed: int) -> None: """Write a summary YAML of run choices (context, model, dataset, trainer) to run_dir.""" cfg = trainer.cfg context_cfg = get_context_config() summary = { "run_name": run_name, "run_dir": str(run_dir), + "random_seed": random_seed, "dataset": OmegaConf.to_container(cfg.dataset, resolve=True) if hasattr(cfg, "dataset") and cfg.dataset else {}, "model": OmegaConf.to_container(cfg.model, resolve=True) if hasattr(cfg, "model") and cfg.model else {}, "context": OmegaConf.to_container(context_cfg, resolve=True) if context_cfg else {}, @@ -148,19 +149,16 @@ def main(args) -> None: f"trainer.checkpoint.every_n_epochs={args.every_n_epochs}", f"trainer.strategy={args.ddp_strategy}", f"trainer.devices={args.devices}", + f"trainer.accelerator={args.accelerator}", f"trainer.eval_after_training={args.eval_after_training}", - f"train.accelerator={args.accelerator}", - "trainer.early_stopping.patience=100", - "trainer.early_stopping.monitor=train_loss", - "trainer.early_stopping.mode=min", - f"trainer.enable_checkpointing={args.enable_checkpointing}", - "trainer.logger=False", 
+ f"trainer.intermediate_fid.enabled={args.intermediate_fid}", + f"trainer.intermediate_fid.every_n_epochs={args.fid_every_n_epochs}", f"wandb.enabled={args.wandb_enabled}", f"wandb.project={args.wandb_project}", f"wandb.entity={args.wandb_entity}", + f"wandb.name={run_name}_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}", f"model.context_reconstruction_loss_weight={CR_LOSS_WEIGHT}", f"model.tc_loss_weight={TC_LOSS_WEIGHT}", - f"wandb.name=training_dai_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_L{CR_LOSS_WEIGHT}_TC_{TC_LOSS_WEIGHT}_dim2", ] if args.model_overrides: trainer_overrides.extend(args.model_overrides) @@ -171,7 +169,7 @@ def main(args) -> None: overrides=trainer_overrides, ) - _write_run_summary(run_dir, run_name, trainer) + _write_run_summary(run_dir, run_name, trainer, args.random_seed) _write_run_configs(run_dir, trainer) trainer.fit(ckpt_path=args.resume_from_checkpoint) @@ -217,6 +215,12 @@ def main(args) -> None: ) parser.add_argument("--random-seed", type=int, default=42, help="Random seed for reproducibility",) + parser.add_argument("--intermediate-fid", action="store_true", default=True, + help="Enable intermediate context-FID checks during training (saves top-k checkpoints)") + parser.add_argument("--no-intermediate-fid", dest="intermediate_fid", action="store_false", + help="Disable intermediate context-FID checks") + parser.add_argument("--fid-every-n-epochs", type=int, default=20, + help="How often (in epochs) to compute context-FID during training") args = parser.parse_args() main(args) \ No newline at end of file From 15393de87e6dcfb418de1c6ba5993eca847640e3 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Wed, 18 Mar 2026 16:37:15 -0400 Subject: [PATCH 46/50] changes to generate. 
also, good 750 epoch aqr run --- cents/config/model/diffusion_ts.yaml | 4 +- cents/config/trainer/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 4 +- cents/datasets/utils.py | 12 +++-- cents/models/normalizer.py | 2 +- scripts/eval_pretrained.py | 13 ++++- scripts/generate.py | 70 +++++++++++++++++++------- 7 files changed, 80 insertions(+), 27 deletions(-) diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index edd4718..ce57214 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -3,7 +3,7 @@ name: diffusion_ts context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 noise_dim: 256 -cond_emb_dim: 16 +cond_emb_dim: 64 n_layer_enc: 4 n_layer_dec: 5 d_model: 128 @@ -12,7 +12,7 @@ sampling_timesteps: 200 sampling_batch_size: 4096 loss_type: l1 #l2 training_objective: v -loss_weighting: min_snr +loss_weighting: snr min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index c11877e..44dd34a 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -9,7 +9,7 @@ max_epochs: 5000 base_lr: 1e-4 warmup_epochs: 100 # linear warmup from 1% to 100% of base_lr over first N epochs eval_after_training: False -gradient_clip_val: 1.0 +# gradient_clip_val: 1.0 checkpoint: save_last: True # Save final model diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 96b3d49..f090bac 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,10 +1,10 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 3, +devices: 2, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 750 +n_epochs: 500 batch_size: 4096 lr: 3e-4 save_cycle: 5000 diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 
4fe06f7..45d7884 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -18,7 +18,7 @@ def split_timeseries( if n_dim != len(time_series_columns): raise ValueError("shape mismatch") for i, col in enumerate(time_series_columns): - df[col] = df["timeseries"].apply(lambda x: x[:, i]) + df[col] = df["timeseries"].apply(lambda x: x[:, i].tolist()) return df.drop(columns=["timeseries"]) @@ -241,8 +241,12 @@ def convert_generated_data_to_df( n_samples = data_np.shape[0] def _get_code_at(code: Any, i: int) -> Any: - if isinstance(code, torch.Tensor) and code.dim() == 1 and code.shape[0] == n_samples: - return code[i].item() + if isinstance(code, torch.Tensor): + if code.dim() == 2 and code.shape[0] == n_samples: + # Dynamic (time-series) context: return as numpy array + return code[i].cpu().numpy() + if code.dim() == 1 and code.shape[0] == n_samples: + return code[i].item() return code.item() if isinstance(code, torch.Tensor) else code records = [] @@ -254,6 +258,8 @@ def _get_code_at(code: Any, i: int) -> Any: if mapping is None: raise ValueError("Mapping must be provided when decode=True.") record[var] = mapping[var][v] + elif isinstance(v, np.ndarray): + record[var] = v.tolist() else: record[var] = v if isinstance(v, float) else int(v) record["timeseries"] = data_np[i] diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 37a4d54..245a569 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -726,7 +726,7 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: rng_eff = max(rng, self.min_scale_range) z = z * rng_eff + zmin_ arr = z * sigma_eff + mu[d] - df_out.at[i, col] = arr + df_out.at[i, col] = arr.tolist() return df_out def _build_training_samples( diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 1b22d2d..15c8501 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -20,7 +20,7 @@ from cents.datasets.metraq import MetraqDataset from cents.eval.eval 
import Evaluator from cents.utils.config_loader import load_yaml, apply_overrides -from cents.utils.utils import set_context_config_path +from cents.utils.utils import set_context_config_path, set_context_overrides logging.basicConfig( level=logging.INFO, @@ -229,6 +229,13 @@ def main() -> None: default=None, help="Path to custom context config YAML file (optional).", ) + parser.add_argument( + "--context-overrides", + type=str, + nargs="*", + default=[], + help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn').", + ) parser.add_argument( "--no-normalizer-global-preprocessing", action="store_true", @@ -291,6 +298,10 @@ def main() -> None: if args.context_config_path: set_context_config_path(args.context_config_path) + # Set context config overrides if provided + if args.context_overrides: + set_context_overrides(args.context_overrides) + if use_run_path: run_path = Path(args.run_path).resolve() logging.info("Using run-path: %s (epoch=%s)", run_path, args.epoch) diff --git a/scripts/generate.py b/scripts/generate.py index f40fa8d..93b43b0 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -5,6 +5,7 @@ Supports: - Random context: sample context from the dataset's support (including continuous). - Explicit context: provide context as JSON (categorical: int codes; continuous: z-scored floats). + - Sample rows: sample full context (static + dynamic) from real dataset rows, preserving correlations. - Output to Parquet (default) or CSV. 
""" @@ -12,8 +13,10 @@ import json import logging import os +import random from pathlib import Path +import numpy as np import torch from omegaconf import OmegaConf @@ -23,7 +26,7 @@ from cents.datasets.airquality import AirQualityDataset from cents.datasets.utils import convert_generated_data_to_df from cents.utils.config_loader import load_yaml -from cents.utils.utils import set_context_config_path +from cents.utils.utils import set_context_config_path, set_context_overrides logging.basicConfig( level=logging.INFO, @@ -87,18 +90,20 @@ def main() -> None: default="samples.parquet", help="Output path for generated samples.", ) - parser.add_argument( - "--format", - type=str, - choices=("parquet", "csv"), - default="parquet", - help="Output format. Parquet preserves array columns better.", - ) parser.add_argument( "--random-context", action="store_true", help="Sample context randomly from the dataset support (categorical and continuous).", ) + parser.add_argument( + "--sample-rows", + action="store_true", + help=( + "Sample full context (static + dynamic) from real dataset rows. " + "Preserves correlations between covariates. " + "Samples with replacement if n > len(dataset)." + ), + ) parser.add_argument( "--context", type=str, @@ -119,6 +124,13 @@ def main() -> None: default=[], help="Extra dataset overrides, e.g. 
time_series_dims=1.", ) + parser.add_argument( + "--context-overrides", + type=str, + nargs="*", + default=[], + help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn').", + ) parser.add_argument( "--no-ema", action="store_true", @@ -128,14 +140,19 @@ def main() -> None: use_random = args.random_context use_explicit = args.context is not None and args.context.strip() != "" - if not use_random and not use_explicit: - parser.error("Provide either --random-context or --context (JSON).") - if use_random and use_explicit: - parser.error("Provide only one of --random-context or --context.") + use_rows = args.sample_rows + n_modes = sum([use_random, use_explicit, use_rows]) + if n_modes == 0: + parser.error("Provide one of --random-context, --context (JSON), or --sample-rows.") + if n_modes > 1: + parser.error("Provide only one of --random-context, --context, or --sample-rows.") if args.context_config_path: set_context_config_path(args.context_config_path) + if args.context_overrides: + set_context_overrides(args.context_overrides) + overrides = list(args.dataset_overrides) if args.dataset_overrides else [] logging.info("Loading dataset %s...", args.dataset) @@ -154,7 +171,29 @@ def main() -> None: target.model.use_ema_sampling = cfg.model.use_ema_sampling gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) - if use_random: + if use_rows: + # Sample full context (static + dynamic) from real dataset rows. + # With-replacement sampling handles n > len(dataset). 
+ indices = random.choices(range(len(dataset)), k=args.num_samples) + samples = [dataset[i] for i in indices] + static_batch = { + k: torch.stack([s[1][k] for s in samples]).to(gen.device) + for k in samples[0][1].keys() + } + dynamic_batch = { + k: torch.stack([s[2][k] for s in samples]).to(gen.device) + for k in samples[0][2].keys() + } if samples[0][2] else {} + logging.info( + "Generating %d samples conditioned on %d real rows (static: %s, dynamic: %s)...", + args.num_samples, args.num_samples, + list(static_batch.keys()), list(dynamic_batch.keys()), + ) + with torch.no_grad(): + ts = gen.model.generate(static_batch, dynamic_batch or None) + ctx_batch = {**static_batch, **dynamic_batch} + df = convert_generated_data_to_df(ts, ctx_batch, decode=False) + elif use_random: # Sample a new random context for each of the n samples contexts = [dataset.sample_random_context_vars() for _ in range(args.num_samples)] ctx_batch = { @@ -184,10 +223,7 @@ def main() -> None: out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) - if args.format == "parquet": - df.to_parquet(out, index=False) - else: - df.to_csv(out, index=False) + df.to_parquet(out, index=False) logging.info("Wrote %d samples to %s", len(df), out.resolve()) From 10525f8147e8b1c1c5a02aa1ed6d89921d24a723 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Tue, 31 Mar 2026 13:39:37 -0400 Subject: [PATCH 47/50] Implemented walmart dataset; layernomr for dyn context ; config changes --- cents/config/dataset/metraq.yaml | 1 + cents/config/dataset/walmart.yaml | 37 ++++ cents/config/evaluator/default.yaml | 26 +++ cents/config/model/diffusion_ts.yaml | 2 +- cents/datasets/metraq.py | 2 +- cents/datasets/timeseries_dataset.py | 2 +- cents/datasets/walmart.py | 272 +++++++++++++++++++++++++++ cents/eval/eval.py | 139 ++++++++++++++ cents/eval/eval_metrics.py | 179 +++++++++++++++++- cents/models/context.py | 13 ++ cents/models/diffusion_ts.py | 2 +- scripts/eval_pretrained.py | 13 +- scripts/train.py | 7 
+ 13 files changed, 685 insertions(+), 10 deletions(-) create mode 100644 cents/config/dataset/walmart.yaml create mode 100644 cents/datasets/walmart.py diff --git a/cents/config/dataset/metraq.yaml b/cents/config/dataset/metraq.yaml index df87448..1ed5877 100644 --- a/cents/config/dataset/metraq.yaml +++ b/cents/config/dataset/metraq.yaml @@ -15,6 +15,7 @@ time_series_dims: 1 normalizer_stats_mode: group # Normalizer conditions only on these (e.g. per-station); diffusion still gets full context_vars normalizer_group_vars: ["sensor_name"] +max_z_threshold: 15.0 # Targets (what becomes the merged "timeseries" dims) # NOTE: use PMcoarse instead of PM10 diff --git a/cents/config/dataset/walmart.yaml b/cents/config/dataset/walmart.yaml new file mode 100644 index 0000000..a11ac30 --- /dev/null +++ b/cents/config/dataset/walmart.yaml @@ -0,0 +1,37 @@ +name: walmart +geography: null +normalize: True +scale: False +use_learned_normalizer: True +threshold: 8 +seq_len: 30 +shuffle: True +skip_heavy_processing: False +max_samples: null +path: "./data/walmart" +numeric_context_bins: 1 +reduce_cardinality: False +time_series_dims: 1 +normalizer_stats_mode: group +# Normalizer conditions on category × store to capture per-group sales distributions +normalizer_group_vars: ["cat_id", "store_id"] +max_z_threshold: 15.0 + +# Target: daily unit sales +time_series_columns: ["sales"] + +context_vars: + # Static categorical — characterise the window by when it starts + year: ["categorical", 6] # 2011–2016 + month: ["categorical", 12] + weekday: ["categorical", 7] + # Static categorical — item / store identity + cat_id: ["categorical", 3] # FOODS, HOBBIES, HOUSEHOLD + dept_id: ["categorical", 7] # e.g. 
FOODS_1 … HOUSEHOLD_2 + store_id: ["categorical", 10] # CA_1 … WI_3 + state_id: ["categorical", 3] # CA, TX, WI + + # Dynamic time-series context (co-occurring with target, length = seq_len) + sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored + snap: ["time_series", null] # binary SNAP eligibility for the item's state + event_binary: ["time_series", null] # 1 if a named calendar event falls on that day diff --git a/cents/config/evaluator/default.yaml b/cents/config/evaluator/default.yaml index 2f3467f..7d09580 100644 --- a/cents/config/evaluator/default.yaml +++ b/cents/config/evaluator/default.yaml @@ -9,3 +9,29 @@ save_results: False eval_disentanglement: True job_name: diffusion_ts_commercial save_dir: outputs/diffusion_ts_commercial/eval + +# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP). +# Runs only when enabled=True AND either: +# - the generated signal has multiple dimensions (multivariate), OR +# - the dataset uses dynamic (time-series) context variables. +# +# pairs: list of {x, c} dicts specifying which time series to evaluate against each other. +# x — name of a generated output dimension (must match time_series_columns in dataset config) +# c — name of a dynamic context variable (from context_vars with type "time_series") +# OR another generated output dimension (multivariate case, GCP only) +# +# CFS is computed only when c is a dynamic context variable (shared context). +# GCP is computed for all pairs. 
+# +# Example for airquality dataset (PM2.5 generated, TEMP/DEWP as context): +# pairs: +# - {x: "PM2.5", c: "TEMP"} +# - {x: "PM2.5", c: "DEWP"} +# +eval_context_faithfulness: + enabled: true + gcp_max_lag: 5 + pairs: + - {x: "PM2.5", c: "T"} + - {x: "PM2.5", c: "R"} + # - {x: "PM2.5", c: "TEMP"} diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index ce57214..44c9a54 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -11,7 +11,7 @@ n_steps: 1000 sampling_timesteps: 200 sampling_batch_size: 4096 loss_type: l1 #l2 -training_objective: v +training_objective: x0 loss_weighting: snr min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule diff --git a/cents/datasets/metraq.py b/cents/datasets/metraq.py index 9221807..43c9b57 100644 --- a/cents/datasets/metraq.py +++ b/cents/datasets/metraq.py @@ -53,7 +53,7 @@ def __init__( FileNotFoundError: If required CSV files are missing. 
""" if cfg is None: - cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "pecanstreet.yaml")) + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "metraq.yaml")) if overrides: cfg = apply_overrides(cfg, overrides) diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 5f11dd7..347fead 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -150,7 +150,7 @@ def __init__( print(f"[Main Process] Cached normalized data for subprocesses") self.data = self.merge_timeseries_columns(self.data) self.data = self.data.reset_index() - + # Check if we should skip heavy processing for DDP if is_ddp_subprocess and skip_heavy_processing: print("skipped rarity computation for DDP compatibility") diff --git a/cents/datasets/walmart.py b/cents/datasets/walmart.py new file mode 100644 index 0000000..2e2bf51 --- /dev/null +++ b/cents/datasets/walmart.py @@ -0,0 +1,272 @@ +import os +import warnings +from typing import List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +_MONTHS = [ + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December", +] + + +class WalmartDataset(TimeSeriesDataset): + """ + Dataset class for Walmart M5 daily unit sales time series. + + Data: https://www.kaggle.com/competitions/m5-forecasting-accuracy + + Each sample is a 30-day non-overlapping window of daily unit sales for a + single item-store pair, filtered to high-velocity items (top 20% by mean + daily sales). Static context encodes category, department, store, and + state. 
Dynamic context includes sell price, SNAP eligibility, and a + binary calendar-event indicator. + """ + + def __init__( + self, + cfg: Optional[DictConfig] = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, + ): + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "walmart.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + self.target_time_series_columns = list(cfg.time_series_columns) + self.time_series_dims = cfg.time_series_dims + self.geography = cfg.get("geography", None) + + self.context_time_series_columns = { + k: v[1] for k, v in cfg.context_vars.items() if v[0] == "time_series" + } + self.context_series_names = list(self.context_time_series_columns.keys()) + # No categorical time series for this dataset (all dynamic vars are continuous/binary) + self.categorical_time_series = { + k: v[1] for k, v in cfg.context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + + self._load_data() + + ts_cols: List[str] = self.target_time_series_columns[: self.time_series_dims] + + super().__init__( + data=self.data, + time_series_column_names=ts_cols, + context_var_column_names=list(cfg.context_vars.keys()), + seq_len=cfg.seq_len, + normalize=cfg.normalize, + scale=cfg.scale, + skip_heavy_processing=cfg.get("skip_heavy_processing", False), + size=cfg.get("max_samples", None), + categorical_time_series=self.categorical_time_series, + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, + ) + + def _load_data(self) -> None: + """ + Loads and joins sales, calendar, and price CSVs; filters to + high-velocity item-store pairs; constructs per-state SNAP flags + and binary event indicators. 
+ """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + + # --- Sales: wide → long --- + sales_path = os.path.join(path, "sales_train_evaluation.csv") + if not os.path.exists(sales_path): + raise FileNotFoundError(f"Sales file not found: {sales_path}") + + sales = pd.read_csv(sales_path) + meta_cols = ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"] + day_cols = [c for c in sales.columns if c.startswith("d_")] + sales_long = sales[meta_cols + day_cols].melt( + id_vars=meta_cols, var_name="d", value_name="sales" + ) + del sales # free memory before merges + + # --- Calendar --- + calendar = pd.read_csv( + os.path.join(path, "calendar.csv"), + usecols=["d", "date", "wm_yr_wk", "weekday", "month", "year", + "event_name_1", "snap_CA", "snap_TX", "snap_WI"], + ) + calendar["date"] = pd.to_datetime(calendar["date"]) + sales_long = sales_long.merge(calendar, on="d", how="left") + del calendar + + # --- Sell prices (weekly → broadcast to daily via merge, then ffill) --- + prices = pd.read_csv(os.path.join(path, "sell_prices.csv")) + sales_long = sales_long.merge(prices, on=["store_id", "item_id", "wm_yr_wk"], how="left") + del prices + + sales_long = sales_long.sort_values(["id", "date"]) + sales_long["sell_price"] = ( + sales_long.groupby("id")["sell_price"] + .transform(lambda x: x.ffill().bfill()) + ) + + # --- State-specific SNAP eligibility (vectorised) --- + sales_long["snap"] = np.where( + sales_long["state_id"] == "CA", sales_long["snap_CA"], + np.where( + sales_long["state_id"] == "TX", sales_long["snap_TX"], + sales_long["snap_WI"], + ), + ).astype(np.int8) + + # --- Binary calendar-event indicator --- + sales_long["event_binary"] = sales_long["event_name_1"].notna().astype(np.int8) + + # --- Filter to high-velocity item-store pairs (top 20% by mean daily sales) --- + mean_sales = sales_long.groupby("id")["sales"].mean() + threshold = mean_sales.quantile(0.80) + 
high_vel_ids = mean_sales.index[mean_sales >= threshold] + sales_long = sales_long[sales_long["id"].isin(high_vel_ids)].copy() + + # --- Month name (consistent with other datasets) --- + sales_long["month"] = sales_long["month"].map(lambda x: _MONTHS[int(x) - 1]) + + # Drop columns that are no longer needed after engineering + sales_long.drop( + columns=["snap_CA", "snap_TX", "snap_WI", "event_name_1", + "wm_yr_wk", "d", "item_id"], + errors="ignore", + inplace=True, + ) + + self.data = sales_long.reset_index(drop=True) + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Assigns non-overlapping 30-day window IDs within each item-store + series, groups into windows, and z-score normalises continuous + dynamic context channels. + """ + data = data.copy() + + ts_ctx = list(self.context_series_names) # ["sell_price", "snap", "event_binary"] + tgt_ts = list(self.target_time_series_columns) # ["sales"] + all_ts = ts_ctx + tgt_ts + + # Assign window indices: non-overlapping blocks of seq_len days per series + data = data.sort_values(["id", "date"]).reset_index(drop=True) + data["_row"] = data.groupby("id").cumcount() + data["_window"] = data["_row"] // self.cfg.seq_len + + group_keys = ["id", "_window"] + static_cols = ["cat_id", "dept_id", "store_id", "state_id", "year", "month", "weekday"] + + agg_dict = {c: list for c in all_ts} + agg_dict.update({c: "first" for c in static_cols}) + + grouped = ( + data.groupby(group_keys, as_index=False, sort=False) + .agg(agg_dict) + ) + grouped.drop(columns=["_window"], inplace=True, errors="ignore") + + for c in all_ts: + grouped[c] = grouped[c].map(np.asarray) + + # Keep only complete windows + len_col = tgt_ts[0] + grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) + + grouped = self._handle_missing_data(grouped) + + # Z-score normalise continuous dynamic context; pass binaries through + binary_channels = {"snap", "event_binary"} + clip_bound = 5.0 + eps = 1e-8 + 
ctx_stats = {} + + for c in ts_ctx: + X = np.stack(grouped[c].values).astype(np.float32) + + if c in binary_channels: + grouped[c] = list(X) + continue + + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 + ctx_stats[c] = (mu, sd) + + Xn = np.clip((X - mu) / (sd + eps), -clip_bound, clip_bound).astype(np.float32) + grouped[c] = list(Xn) + + self.context_ts_stats_ = ctx_stats + + # Convert arrays → tuples (hashable, required by base class) + for c in all_ts: + grouped[c] = grouped[c].map(tuple) + + return grouped + + def _handle_missing_data(self, data: pd.DataFrame) -> pd.DataFrame: + numeric_series = [ + c for c in self.context_series_names if c not in self.categorical_time_series + ] + + # Drop windows where any numeric context series is entirely NaN + if numeric_series: + mask = data[numeric_series].applymap(is_all_nan).any(axis=1) + data = data[~mask] + + # Fill isolated NaNs in numeric context series with within-window mean + for col in numeric_series: + data[col] = data[col].apply(fill_with_row_mean) + + # Drop windows with NaN in any categorical time series (none expected here) + cat_cols = list(self.categorical_time_series.keys()) + if cat_cols: + mask = data[cat_cols].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + # Drop windows with any NaN in target + for tcol in self.target_time_series_columns: + if tcol in data.columns: + data = data.loc[ + data[tcol].apply( + lambda x: not np.isnan(np.asarray(x, dtype=float)).any() + ) + ] + + # Drop near-constant windows (std < 0.01 → degenerate for diffusion) + def _low_std(row, cols, thresh=0.01): + return any(np.asarray(row[c], dtype=np.float32).std() < thresh for c in cols) + + mask = data.apply(lambda row: _low_std(row, self.target_time_series_columns), axis=1) + data = data[~mask] + return data + + +def is_all_nan(arr): + return pd.isna(arr).all() + + +def is_any_nan(arr): + return pd.isna(arr).any() + + +def fill_with_row_mean(lst): + s = pd.Series(lst, dtype=float) + m = 
s.mean(skipna=True) + return s.fillna(m).tolist() diff --git a/cents/eval/eval.py b/cents/eval/eval.py index cd30b02..2d72dbf 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -18,6 +18,8 @@ from cents.eval.eval_metrics import ( Context_FID, calculate_mmd, + compute_cfs, + compute_gcp, compute_mig, compute_sap, dynamic_time_warping_dist, @@ -189,6 +191,8 @@ def compute_quality_metrics( mask: Optional[np.ndarray] = None, target: Optional[Dict] = None, log_prefix: str = "", + context_arrays: Optional[Dict[str, np.ndarray]] = None, + ts_column_names: Optional[list] = None, ) -> Dict: """ Compute evaluation metrics and store them in current_results (or in target if provided). @@ -200,6 +204,8 @@ def compute_quality_metrics( mask (Optional[np.ndarray]): Boolean array indicating which rows are "rare" target (Optional[Dict]): If set, write metrics into this dict instead of current_results (for normalized_domain). log_prefix (str): Prefix for log messages (e.g. "[normalized]"). + context_arrays (Optional[Dict[str, np.ndarray]]): Named dynamic context arrays (N, T) each. + ts_column_names (Optional[list]): Names of the time-series output dimensions, aligned with real_data axis-2. 
""" logger.info(f"[Cents] --- {log_prefix}Full-Subset Metrics ---") @@ -226,6 +232,14 @@ def compute_quality_metrics( metrics["Pred_Score"] = pred_score logger.info(f"[Cents] Pred Score completed") + # CFS / GCP — only when dynamic context or multivariate signal is present + cf_cfg = self.cfg.evaluator.get("eval_context_faithfulness", None) + if cf_cfg and cf_cfg.get("enabled", False) and context_arrays is not None: + cf_metrics = self._compute_context_faithfulness_metrics( + real_data, syn_data, context_arrays, ts_column_names, cf_cfg + ) + metrics["context_faithfulness"] = cf_metrics + if mask is not None: logger.info("[Cents] Starting Rare-Subset Metrics") rare_metrics = {} @@ -277,6 +291,108 @@ def compute_quality_metrics( self.current_results["metrics"] = metrics return metrics + def _compute_context_faithfulness_metrics( + self, + real_data: np.ndarray, + syn_data: np.ndarray, + context_arrays: Dict[str, np.ndarray], + ts_column_names: Optional[list], + cf_cfg, + ) -> Dict: + """ + Compute CFS and GCP for configured (x_dim, c_dim) pairs. + + Each pair specifies one x dimension (by name from ts_column_names, or by index) + and one c dimension (by name from context_arrays, or by name from ts_column_names + for within-signal multivariate pairs). + + CFS is only computed when c comes from context_arrays (shared context). + GCP is computed for all pairs. 
+ """ + results: Dict = {} + max_lag: int = int(cf_cfg.get("gcp_max_lag", 5)) + pairs_cfg = cf_cfg.get("pairs", []) + + if not pairs_cfg: + logger.warning("[Cents] eval_context_faithfulness.pairs is empty — skipping CFS/GCP.") + return results + + ts_names = list(ts_column_names) if ts_column_names else [] + + def _resolve_x(name) -> Optional[Tuple[np.ndarray, np.ndarray]]: + """Return (x_real_slice, x_synth_slice) of shape (N, T, 1).""" + if isinstance(name, int): + idx = name + elif name in ts_names: + idx = ts_names.index(name) + else: + logger.warning(f"[Cents] CFS/GCP: x dim '{name}' not found in ts_column_names {ts_names}; skipping.") + return None + return real_data[:, :, idx : idx + 1], syn_data[:, :, idx : idx + 1] + + def _resolve_c(name) -> Optional[Tuple[np.ndarray, np.ndarray, bool]]: + """Return (c_real, c_synth, is_shared) of shape (N, T, 1). is_shared=True if dynamic context.""" + if name in context_arrays: + arr = context_arrays[name] # (N, T) + if arr.ndim == 1: + arr = arr[:, np.newaxis] # edge-case: (N,) → (N, 1) + if arr.ndim == 2: + arr = arr[:, :, np.newaxis] # (N, T, 1) + return arr, arr, True # same array for real and synth + elif name in ts_names: + idx = ts_names.index(name) + c_r = real_data[:, :, idx : idx + 1] + c_s = syn_data[:, :, idx : idx + 1] + return c_r, c_s, False + else: + logger.warning(f"[Cents] CFS/GCP: c dim '{name}' not found in context_arrays or ts_column_names; skipping.") + return None + + for pair in pairs_cfg: + x_name = pair.get("x") + c_name = pair.get("c") + if x_name is None or c_name is None: + logger.warning(f"[Cents] CFS/GCP: pair {pair} missing 'x' or 'c' key; skipping.") + continue + + x_resolved = _resolve_x(x_name) + c_resolved = _resolve_c(c_name) + if x_resolved is None or c_resolved is None: + continue + + x_r, x_s = x_resolved + c_r, c_s, is_shared = c_resolved + pair_key = f"{x_name}_vs_{c_name}" + + pair_result: Dict = {} + + # CFS — only meaningful when context is shared (dynamic context) + if 
is_shared: + try: + cfs = compute_cfs(x_r, x_s, c_r) + pair_result["CFS"] = cfs + logger.info(f"[Cents] CFS completed for {pair_key}: {cfs:.4f}") + except Exception as e: + logger.warning(f"[Cents] CFS failed for {pair_key}: {e}") + pair_result["CFS"] = float("nan") + else: + logger.info(f"[Cents] CFS skipped for {pair_key} (c is a generated dim, not shared context).") + + # GCP — works for both shared and generated-dim context + try: + gcp, diag = compute_gcp(x_r, x_s, c_r, c_s, max_lag=max_lag) + pair_result["GCP"] = gcp + pair_result["GCP_diagnostics"] = diag + logger.info(f"[Cents] GCP completed for {pair_key}: {gcp:.4f}") + except Exception as e: + logger.warning(f"[Cents] GCP failed for {pair_key}: {e}") + pair_result["GCP"] = float("nan") + pair_result["GCP_diagnostics"] = {} + + results[pair_key] = pair_result + + return results + def compute_disentanglement_metrics( self, context_vars: Dict[str, torch.Tensor], @@ -401,6 +517,25 @@ def _inv(df): real_data_array = np.stack(real_data_inv["timeseries"]) syn_data_array = np.stack(syn_data_inv["timeseries"]) + # Extract dynamic context as (N, T) numpy arrays for CFS/GCP + context_np: Dict[str, np.ndarray] = {} + for name, tensor in dynamic_context_vars.items(): + arr = tensor.cpu().numpy() # (N,) static scalar or (N, T) time series + if arr.ndim == 2: # (N, T) — keep time-series context only + context_np[name] = arr + + ts_col_names = list(getattr(dataset, "time_series_column_names", [])) + + # Decide whether CFS/GCP should run: multivariate signal OR dynamic context present + has_dynamic_context = len(context_np) > 0 + has_multivariate_signal = real_data_array.shape[-1] > 1 + cf_cfg = self.cfg.evaluator.get("eval_context_faithfulness", None) + run_cf = ( + cf_cfg is not None + and cf_cfg.get("enabled", False) + and (has_dynamic_context or has_multivariate_signal) + ) + if self.cfg.evaluator.eval_metrics: rare_mask = None @@ -414,6 +549,8 @@ def _inv(df): self.compute_quality_metrics( real_data_array, 
syn_data_array, real_data_inv, rare_mask, log_prefix="[raw] ", + context_arrays=context_np if run_cf else None, + ts_column_names=ts_col_names if run_cf else None, ) # Metrics in normalized (z) domain for cross-domain comparability. @@ -431,6 +568,8 @@ def _inv(df): real_data_norm, syn_data_norm, real_data_inv, rare_mask, target=normalized_metrics, log_prefix="[normalized] ", + context_arrays=context_np if run_cf else None, + ts_column_names=ts_col_names if run_cf else None, ) self.current_results["metrics"]["normalized_domain"] = normalized_metrics diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index 1dbe86b..7c78487 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -1,14 +1,19 @@ import warnings from functools import partial -from typing import Dict, Tuple +from itertools import product +from typing import Dict, List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy from dtaidistance import dtw -from sklearn.linear_model import Ridge -from sklearn.metrics import mutual_info_score, r2_score +from scipy.stats import f as f_dist +from scipy.stats import wasserstein_distance +from sklearn.linear_model import LogisticRegression, Ridge +from sklearn.metrics import mutual_info_score, r2_score, roc_auc_score +from sklearn.model_selection import StratifiedKFold +from statsmodels.tsa.tsatools import lagmat from cents.eval.eval_utils import ( gaussian_kernel_matrix, @@ -204,6 +209,174 @@ def Context_FID(ori_data: np.ndarray, generated_data: np.ndarray) -> float: return calculate_fid(ori_rep, gen_rep) +def compute_cfs( + x_real: np.ndarray, + x_synth: np.ndarray, + c: np.ndarray, + n_folds: int = 5, +) -> float: + """ + Compute Context Faithfulness Score (CFS). + + Trains a classifier to distinguish real (x, c) pairs from synthetic (x, c) pairs + using cross-validation, then returns 2 * |AUROC - 0.5|. 
+ + Args: + x_real: (N, T, D_x) real time series + x_synth: (N, T, D_x) synthetic time series + c: (N, T, D_c) shared context (same for real and synthetic) + n_folds: number of cross-validation folds + + Returns: + float: CFS in [0, 1]. 0 = indistinguishable (perfect), 1 = fully separable (failed). + """ + N = x_real.shape[0] + + real_pairs = np.concatenate([x_real, c], axis=-1) # (N, T, D_x+D_c) + synth_pairs = np.concatenate([x_synth, c], axis=-1) # (N, T, D_x+D_c) + + # Mean pool over time → fixed-size vectors + X_real_enc = real_pairs.mean(axis=1) # (N, D_x+D_c) + X_synth_enc = synth_pairs.mean(axis=1) # (N, D_x+D_c) + + X_all = np.concatenate([X_real_enc, X_synth_enc], axis=0) # (2N, D) + y_all = np.concatenate([np.ones(N), np.zeros(N)]) # (2N,) + + # Drop rows with NaN + valid = ~np.isnan(X_all).any(axis=1) + X_all = X_all[valid] + y_all = y_all[valid] + + if len(X_all) < 2 * n_folds or len(np.unique(y_all)) < 2: + warnings.warn("compute_cfs: insufficient valid samples; returning nan.") + return float("nan") + + skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42) + auroc_scores = [] + for train_idx, val_idx in skf.split(X_all, y_all): + clf = LogisticRegression(max_iter=1000, random_state=42) + clf.fit(X_all[train_idx], y_all[train_idx]) + proba = clf.predict_proba(X_all[val_idx])[:, 1] + auroc_scores.append(roc_auc_score(y_all[val_idx], proba)) + + mean_auroc = float(np.mean(auroc_scores)) + return float(2.0 * abs(mean_auroc - 0.5)) + + +def _build_lag_matrix(x: np.ndarray, max_lag: int) -> np.ndarray: + """Return (T-max_lag, max_lag) lag matrix with rows [x_{t-1}, ..., x_{t-L}].""" + return lagmat(x, maxlag=max_lag, trim="forward", original="ex")[max_lag:] + + +def _compute_f_stats_batch( + x_arr: np.ndarray, + c_arr: np.ndarray, + max_lag: int, + pairs: List[Tuple[int, int]], +) -> np.ndarray: + """Compute mean F-statistic per sample for the given (dx, dc) pairs.""" + N = x_arr.shape[0] + f_per_sample = [] + for i in range(N): + f_vals = 
[] + for dx, dc in pairs: + xi = x_arr[i, :, dx] + ci = c_arr[i, :, dc] + + if np.isnan(xi).any() or np.isnan(ci).any(): + continue + + X_x = _build_lag_matrix(xi, max_lag) # (T-max_lag, max_lag) + X_c = _build_lag_matrix(ci, max_lag) # (T-max_lag, max_lag) + y = xi[max_lag:] + + ones = np.ones((len(y), 1)) + X_r = np.hstack([ones, X_x]) # restricted: own lags only + X_u = np.hstack([ones, X_x, X_c]) # unrestricted: own + c lags + + beta_r, _, _, _ = np.linalg.lstsq(X_r, y, rcond=None) + beta_u, _, _, _ = np.linalg.lstsq(X_u, y, rcond=None) + + rss_r = float(np.sum((y - X_r @ beta_r) ** 2)) + rss_u = float(np.sum((y - X_u @ beta_u) ** 2)) + + df1 = max_lag + df2 = len(y) - 2 * max_lag - 1 + + if df2 <= 0 or rss_u < 1e-12: + continue + + F = ((rss_r - rss_u) / df1) / (rss_u / df2) + f_vals.append(F) + + if f_vals: + f_per_sample.append(float(np.mean(f_vals))) + + return np.array(f_per_sample) + + +def compute_gcp( + x_real: np.ndarray, + x_synth: np.ndarray, + c_real: np.ndarray, + c_synth: np.ndarray, + max_lag: int = 5, + alpha: float = 0.05, +) -> Tuple[float, Dict]: + """ + Compute Granger Causality Preservation (GCP). + + Measures how well synthetic data preserves the Granger-causal structure from + context c → signal x, via Wasserstein distance between F-statistic distributions. + + Args: + x_real: (N, T, D_x) real signal + x_synth: (N, T, D_x) synthetic signal + c_real: (N, T, D_c) context for real data + c_synth: (N, T, D_c) context for synthetic data + max_lag: maximum lag order (capped at T // 10 for short series) + alpha: significance threshold for diagnostic sig-rate computation + + Returns: + gcp: float >= 0. 0 = perfect preservation. Higher = more divergence. 
+ diagnostics: dict with sig_rate_real, sig_rate_synth, sig_rate_delta + """ + N, T, D_x = x_real.shape + D_c = c_real.shape[-1] + + # Cap lag for short series + max_lag = min(max_lag, max(1, T // 10)) + + pairs = list(product(range(D_x), range(D_c))) + + F_real = _compute_f_stats_batch(x_real, c_real, max_lag, pairs) + F_synth = _compute_f_stats_batch(x_synth, c_synth, max_lag, pairs) + + if len(F_real) == 0 or len(F_synth) == 0: + warnings.warn("compute_gcp: no valid F-statistics computed; returning nan.") + return float("nan"), {} + + gcp = float(wasserstein_distance(F_real, F_synth)) + + # Diagnostics: significance rates + df2 = T - 2 * max_lag - 1 + if df2 > 0: + pvals_real = f_dist.sf(F_real, max_lag, df2) + pvals_synth = f_dist.sf(F_synth, max_lag, df2) + sig_real = float(np.mean(pvals_real < alpha)) + sig_synth = float(np.mean(pvals_synth < alpha)) + else: + sig_real = sig_synth = float("nan") + + diagnostics = { + "sig_rate_real": sig_real, + "sig_rate_synth": sig_synth, + "sig_rate_delta": float(abs(sig_real - sig_synth)) if not np.isnan(sig_real) else float("nan"), + } + + return gcp, diagnostics + + def compute_mig( embeddings: np.ndarray, context_vars: Dict[str, np.ndarray], diff --git a/cents/models/context.py b/cents/models/context.py index 4dbf8ae..dbbb543 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -502,6 +502,15 @@ def __init__( # Per-variable weight (scalar) for the additive mixture across variables self.var_mix = nn.Linear(n_vars * embedding_dim, embedding_dim) if n_vars > 1 else None + # Per-variable layer norm applied after each transformer encoder output + all_var_names = list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + self.post_encoder_norms = nn.ModuleDict({ + name: nn.LayerNorm(embedding_dim) + for name in all_var_names + }) + # Final layer norm applied after var_mix (or single-variable output) + self.post_mix_norm = nn.LayerNorm(embedding_dim) + # Initialize weights self._initialize_weights() @@ 
-548,6 +557,7 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, if self.pos_encodings is not None and name in self.pos_encodings: embedded = embedded + self.pos_encodings[name][:, :embedded.size(1)] encoded = self.ts_encoders[name](embedded) # (B, T, emb_dim) + encoded = self.post_encoder_norms[name](encoded) if torch.isnan(encoded).any() or torch.isinf(encoded).any(): raise ValueError(f"NaN/Inf after transformer encoding '{name}'") sequences.append(encoded) @@ -569,6 +579,7 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, if torch.isnan(embedded).any() or torch.isinf(embedded).any(): raise ValueError(f"NaN/Inf after projection for '{name}'") encoded = self.ts_encoders[name](embedded) # (B, T, emb_dim) + encoded = self.post_encoder_norms[name](encoded) if torch.isnan(encoded).any() or torch.isinf(encoded).any(): raise ValueError(f"NaN/Inf after transformer encoding numeric TS '{name}'") sequences.append(encoded) @@ -588,6 +599,8 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, # Single-variable fallback (var_mix is None only when n_vars == 1) out = sequences[0] + out = self.post_mix_norm(out) + if torch.isnan(out).any() or torch.isinf(out).any(): raise ValueError("NaN/Inf in dynamic context sequence output") diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index c50a250..dbc34e0 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -821,7 +821,7 @@ def load_state_dict(self, state_dict, strict=True): return super().load_state_dict(filtered, strict=strict) def on_load_checkpoint(self, checkpoint: dict) -> None: - if 'ema_state_dict' in checkpoint: + if 'ema_state_dict' in checkpoint and False: if self._ema is None: object.__setattr__(self, '_ema', EMA( self.model, diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 15c8501..5ff4275 100644 --- a/scripts/eval_pretrained.py +++ 
b/scripts/eval_pretrained.py @@ -262,7 +262,7 @@ def main() -> None: parser.add_argument( "--seed", type=int, - default=None, + default=42, help="Random seed for reproducible sampling (sets Python, NumPy, and PyTorch seeds).", ) parser.add_argument( @@ -278,11 +278,17 @@ def main() -> None: ), ) parser.add_argument( - "--save-path", + "--save-path", type=str, default=None, help="Path to save evaluation results." ) + parser.add_argument( + "--model-config", + type=str, + default=None, + help="Path to a model config YAML file. Overrides the default cents/config/model/{model_type}.yaml when using --model-ckpt.", + ) args = parser.parse_args() @@ -390,8 +396,9 @@ def main() -> None: cfg.evaluator = eval_cfg cfg.wandb = top_cfg.get("wandb", {}) cfg.device = f"cuda:{args.device}" + model_config_path = args.model_config if args.model_config else f"cents/config/model/{model_type}.yaml" cfg.model = OmegaConf.create( - OmegaConf.to_container(OmegaConf.load(f"cents/config/model/{model_type}.yaml"), resolve=True) + OmegaConf.to_container(OmegaConf.load(model_config_path), resolve=True) ) cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) if args.no_normalizer_global_preprocessing: diff --git a/scripts/train.py b/scripts/train.py index 9b9b1e8..11961a5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -10,6 +10,7 @@ from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset from cents.datasets.metraq import MetraqDataset +from cents.datasets.walmart import WalmartDataset from cents.trainer import Trainer from cents.utils.utils import set_context_config_path, set_context_overrides, get_context_config @@ -138,6 +139,12 @@ def main(args) -> None: force_retrain_normalizer=args.force_retrain_normalizer, run_dir=str(run_dir), ) + elif args.dataset == "walmart": + dataset = WalmartDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) 
else: raise ValueError(f"Dataset {args.dataset} not supported") From bb5f213448d8337e9e78c9d4821858bdf5e82355 Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Sun, 12 Apr 2026 14:12:27 -0400 Subject: [PATCH 48/50] Joint dyn transformer, new eval metrics --- cents/config/dataset/airquality.yaml | 4 +- cents/config/dataset/metraq.yaml | 13 +-- cents/config/dataset/walmart.yaml | 8 +- cents/config/evaluator/airquality.yaml | 32 +++++++ cents/config/evaluator/default.yaml | 4 +- cents/config/evaluator/walmart.yaml | 34 +++++++ cents/config/model/diffusion_ts.yaml | 4 +- cents/datasets/metraq.py | 53 ++++++----- cents/datasets/walmart.py | 78 +++++++++++----- cents/eval/eval.py | 21 +++++ cents/eval/eval_metrics.py | 85 ++++++++++++++++++ cents/eval/predictive_score.py | 4 +- cents/models/context.py | 120 ++++++++++++++++++++++++- cents/models/diffusion_ts.py | 17 ++-- scripts/eval_pretrained.py | 23 ++--- 15 files changed, 422 insertions(+), 78 deletions(-) create mode 100644 cents/config/evaluator/airquality.yaml create mode 100644 cents/config/evaluator/walmart.yaml diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml index 3ce51c2..cb14197 100644 --- a/cents/config/dataset/airquality.yaml +++ b/cents/config/dataset/airquality.yaml @@ -11,14 +11,14 @@ max_samples: null path: "./data/airquality" numeric_context_bins: 1 reduce_cardinality: False -time_series_dims: 1 +time_series_dims: 4 normalizer_stats_mode: group # Normalizer conditions only on these (e.g. 
per-station); diffusion still gets full context_vars normalizer_group_vars: ["station"] # Targets (what becomes the merged "timeseries" dims) # NOTE: use PMcoarse instead of PM10 -time_series_columns: ["PM2.5"] +time_series_columns: ["PM2.5", "SO2", "NO2", "CO"] # Raw CSV columns to load # Keep wd/WSPM because we need them to engineer wind_u/wind_v diff --git a/cents/config/dataset/metraq.yaml b/cents/config/dataset/metraq.yaml index 1ed5877..439f762 100644 --- a/cents/config/dataset/metraq.yaml +++ b/cents/config/dataset/metraq.yaml @@ -45,9 +45,10 @@ context_vars: # WS and WD are decomposed into wind_u/wind_v in preprocessing to handle # the circularity of wind direction (WD=355° ≈ WD=5°, but z-score would give opposite signs). T: ["time_series", null] - wind_u: ["time_series", null] - wind_v: ["time_series", null] - wd_valid: ["time_series", null] - RH: ["time_series", null] - AP: ["time_series", null] - R: ["time_series", null] \ No newline at end of file + # wind_u: ["time_series", null] + # wind_v: ["time_series", null] + # wd_valid: ["time_series", null] + # RH, AP, R dropped — per-sample correlation with PM2.5 < 0.025 across all stations + # Traffic: TI = vehicles/hour (Kriging interpolation); SP = avg speed km/h + TI: ["time_series", null] + SP: ["time_series", null] \ No newline at end of file diff --git a/cents/config/dataset/walmart.yaml b/cents/config/dataset/walmart.yaml index a11ac30..37f5eb6 100644 --- a/cents/config/dataset/walmart.yaml +++ b/cents/config/dataset/walmart.yaml @@ -4,7 +4,7 @@ normalize: True scale: False use_learned_normalizer: True threshold: 8 -seq_len: 30 +seq_len: 28 shuffle: True skip_heavy_processing: False max_samples: null @@ -24,7 +24,6 @@ context_vars: # Static categorical — characterise the window by when it starts year: ["categorical", 6] # 2011–2016 month: ["categorical", 12] - weekday: ["categorical", 7] # Static categorical — item / store identity cat_id: ["categorical", 3] # FOODS, HOBBIES, HOUSEHOLD dept_id: 
["categorical", 7] # e.g. FOODS_1 … HOUSEHOLD_2 @@ -32,6 +31,7 @@ context_vars: state_id: ["categorical", 3] # CA, TX, WI # Dynamic time-series context (co-occurring with target, length = seq_len) - sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored - snap: ["time_series", null] # binary SNAP eligibility for the item's state + # sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored + # snap: ["time_series", null] # binary SNAP eligibility for the item's state event_binary: ["time_series", null] # 1 if a named calendar event falls on that day + weekday: ["time_series", null] # day of week encoded as 0 (Mon) – 6 (Sun), z-scored diff --git a/cents/config/evaluator/airquality.yaml b/cents/config/evaluator/airquality.yaml new file mode 100644 index 0000000..9b8017e --- /dev/null +++ b/cents/config/evaluator/airquality.yaml @@ -0,0 +1,32 @@ +model: + name: diffusion_ts +dataset: + name: airquality +eval_pv_shift: False +eval_metrics: True +eval_context_sparse: True +save_results: False +eval_disentanglement: True +eval_context_recovery: True +job_name: diffusion_ts_airquality +save_dir: outputs/diffusion_ts_airquality/eval + +# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP). +# Runs only when enabled=True AND either: +# - the generated signal has multiple dimensions (multivariate), OR +# - the dataset uses dynamic (time-series) context variables. +# +# pairs: list of {x, c} dicts specifying which time series to evaluate against each other. +# x — name of a generated output dimension (must match time_series_columns in dataset config) +# c — name of a dynamic context variable (from context_vars with type "time_series") +# OR another generated output dimension (multivariate case, GCP only) +# +# CFS is computed only when c is a dynamic context variable (shared context). +# GCP is computed for all pairs. 
+# +eval_context_faithfulness: + enabled: true + gcp_max_lag: 5 + pairs: + - {x: "PM2.5", c: "TEMP"} + - {x: "PM2.5", c: "DEWP"} diff --git a/cents/config/evaluator/default.yaml b/cents/config/evaluator/default.yaml index 7d09580..b7d1523 100644 --- a/cents/config/evaluator/default.yaml +++ b/cents/config/evaluator/default.yaml @@ -7,6 +7,7 @@ eval_metrics: True eval_context_sparse: True save_results: False eval_disentanglement: True +eval_context_recovery: True job_name: diffusion_ts_commercial save_dir: outputs/diffusion_ts_commercial/eval @@ -33,5 +34,6 @@ eval_context_faithfulness: gcp_max_lag: 5 pairs: - {x: "PM2.5", c: "T"} - - {x: "PM2.5", c: "R"} + - {x: "PM2.5", c: "TI"} + - {x: "PM2.5", c: "SP"} # - {x: "PM2.5", c: "TEMP"} diff --git a/cents/config/evaluator/walmart.yaml b/cents/config/evaluator/walmart.yaml new file mode 100644 index 0000000..5785d0a --- /dev/null +++ b/cents/config/evaluator/walmart.yaml @@ -0,0 +1,34 @@ +model: + name: diffusion_ts +dataset: + name: walmart +eval_pv_shift: False +eval_metrics: True +eval_context_sparse: True +save_results: False +eval_disentanglement: True +eval_context_recovery: True +job_name: diffusion_ts_walmart +save_dir: outputs/diffusion_ts_walmart/eval + +# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP). +# Runs only when enabled=True AND either: +# - the generated signal has multiple dimensions (multivariate), OR +# - the dataset uses dynamic (time-series) context variables. +# +# pairs: list of {x, c} dicts specifying which time series to evaluate against each other. +# x — name of a generated output dimension (must match time_series_columns in dataset config) +# c — name of a dynamic context variable (from context_vars with type "time_series") +# OR another generated output dimension (multivariate case, GCP only) +# +# CFS is computed only when c is a dynamic context variable (shared context). +# GCP is computed for all pairs. 
+# +eval_context_faithfulness: + enabled: true + gcp_max_lag: 5 + pairs: + - {x: "sales", c: "sell_price"} + - {x: "sales", c: "snap"} + - {x: "sales", c: "event_binary"} + - {x: "sales", c: "weekday"} diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 44c9a54..73daad9 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -11,7 +11,7 @@ n_steps: 1000 sampling_timesteps: 200 sampling_batch_size: 4096 loss_type: l1 #l2 -training_objective: x0 +training_objective: v loss_weighting: snr min_snr_gamma: 5.0 beta_schedule: cosine #linear diffusion ts paper uses linear schedule @@ -25,7 +25,7 @@ padding_size: null use_ff: True reg_weight: null gradient_accumulate_every: 2 -ema_decay: 0.9999 +ema_decay: 0.999 ema_update_interval: 1 use_ema_sampling: True k_bins: 20 diff --git a/cents/datasets/metraq.py b/cents/datasets/metraq.py index 43c9b57..59bbbe8 100644 --- a/cents/datasets/metraq.py +++ b/cents/datasets/metraq.py @@ -193,36 +193,45 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] - log1p_channels = {"R"} + # TI (traffic intensity) is strictly non-negative — log1p compresses the + # heavy right tail (rush-hour spikes) before z-scoring. + log1p_channels = {"TI"} binary_channels = {"wd_valid"} # already in [0, 1] — skip z-scoring clip_bound = 5.0 eps = 1e-8 - # Compute global mean/std per channel over all rows and timesteps - ctx_stats = {} - for c in ctx_numeric: - # stacked shape: (N, L) - X = np.stack(grouped[c].values).astype(np.float32) + # Per-station z-score normalization: compute (mu, sd) separately for each + # sensor_name so the model sees locally-relative deviations. A global + # z-score would conflate cross-station level differences with within-station + # variation, obscuring the context–target relationship the model needs to learn. 
+ ctx_stats = {} # {channel: {sensor_name: (mu, sd)}} + for c in ctx_numeric: if c in binary_channels: - # Already in [0, 1] — pass through without z-scoring - grouped[c] = list(X) + grouped[c] = list(np.stack(grouped[c].values).astype(np.float32)) continue - if c in log1p_channels: - X = np.log1p(np.clip(X, a_min=0.0, a_max=None)) - - mu = float(X.mean()) - sd = float(X.std()) - if sd < 1e-6: - sd = 1.0 # avoid divide-by-zero; effectively makes it "center only" - ctx_stats[c] = (mu, sd) + ctx_stats[c] = {} + col_arrays = grouped[c].map(np.asarray) - Xn = (X - mu) / (sd + eps) - Xn = np.clip(Xn, -clip_bound, clip_bound).astype(np.float32) - - grouped[c] = list(Xn) - - # (Optional) store for later inverse-transform / debugging + if c in log1p_channels: + col_arrays = col_arrays.map(lambda x: np.log1p(np.clip(x, 0.0, None))) + + normalized = col_arrays.copy() + for stn, idx in grouped.groupby("sensor_name").groups.items(): + # idx contains label-based indices (not positional) — use .loc + X = np.stack(col_arrays.loc[idx].values).astype(np.float32) + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 + ctx_stats[c][stn] = (mu, sd) + Xn = np.clip((X - mu) / (sd + eps), -clip_bound, clip_bound) + for arr_i, row_i in enumerate(idx): + normalized.loc[row_i] = Xn[arr_i] + + grouped[c] = list(normalized) + + # Store for later inverse-transform / debugging self.context_ts_stats_ = ctx_stats # arrays -> tuples (hashable) diff --git a/cents/datasets/walmart.py b/cents/datasets/walmart.py index 2e2bf51..52a4837 100644 --- a/cents/datasets/walmart.py +++ b/cents/datasets/walmart.py @@ -1,3 +1,4 @@ +import logging import os import warnings from typing import List, Optional @@ -24,11 +25,12 @@ class WalmartDataset(TimeSeriesDataset): Data: https://www.kaggle.com/competitions/m5-forecasting-accuracy - Each sample is a 30-day non-overlapping window of daily unit sales for a - single item-store pair, filtered to high-velocity items (top 20% by mean - daily sales). 
Static context encodes category, department, store, and - state. Dynamic context includes sell price, SNAP eligibility, and a - binary calendar-event indicator. + Each sample is the first 28 days of a calendar month for a single + item-store pair (days 29–31 are discarded), filtered to high-velocity + items (top 20% by mean daily sales). The 28-day window aligns exactly + with four weeks, eliminating month-length drift and making month a clean + static context variable. Dynamic context includes sell price, SNAP + eligibility, and a binary calendar-event indicator. """ def __init__( @@ -95,6 +97,29 @@ def _load_data(self) -> None: sales = pd.read_csv(sales_path) meta_cols = ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"] day_cols = [c for c in sales.columns if c.startswith("d_")] + + # --- High-velocity filter computed from wide format (before the expensive melt) --- + day_means = sales[day_cols].mean(axis=1) + vel_threshold = day_means.quantile(0.80) + keep_ids = set(sales.loc[day_means >= vel_threshold, "id"].values) + + # --- Early ID subsampling when max_samples is set --- + # Each high-velocity ID produces ~60 monthly windows across the ~5-year dataset. + # Keeping 2× the IDs needed gives a safe buffer for windows dropped during + # preprocessing (missing data, near-constant sequences, etc.). + max_samples = self.cfg.get("max_samples", None) + if max_samples is not None and len(keep_ids) > 0: + est_windows_per_id = 60 + n_ids = min(len(keep_ids), int(np.ceil(max_samples * 2.0 / est_windows_per_id))) + keep_ids = set(np.random.choice(sorted(keep_ids), size=n_ids, replace=False)) + logging.info( + "WalmartDataset: subsampling %d IDs (est. 
~%d windows) for max_samples=%d", + n_ids, n_ids * est_windows_per_id, max_samples, + ) + + # Filter BEFORE the melt — dramatically reduces the wide→long expansion cost + sales = sales[sales["id"].isin(keep_ids)] + sales_long = sales[meta_cols + day_cols].melt( id_vars=meta_cols, var_name="d", value_name="sales" ) @@ -133,15 +158,16 @@ def _load_data(self) -> None: # --- Binary calendar-event indicator --- sales_long["event_binary"] = sales_long["event_name_1"].notna().astype(np.int8) - # --- Filter to high-velocity item-store pairs (top 20% by mean daily sales) --- - mean_sales = sales_long.groupby("id")["sales"].mean() - threshold = mean_sales.quantile(0.80) - high_vel_ids = mean_sales.index[mean_sales >= threshold] - sales_long = sales_long[sales_long["id"].isin(high_vel_ids)].copy() - # --- Month name (consistent with other datasets) --- sales_long["month"] = sales_long["month"].map(lambda x: _MONTHS[int(x) - 1]) + # --- Weekday as integer 0 (Mon) – 6 (Sun) for use as dynamic context --- + _WEEKDAY_MAP = { + "Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3, + "Friday": 4, "Saturday": 5, "Sunday": 6, + } + sales_long["weekday"] = sales_long["weekday"].map(_WEEKDAY_MAP).astype(np.int8) + # Drop columns that are no longer needed after engineering sales_long.drop( columns=["snap_CA", "snap_TX", "snap_WI", "event_name_1", @@ -154,9 +180,11 @@ def _load_data(self) -> None: def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: """ - Assigns non-overlapping 30-day window IDs within each item-store - series, groups into windows, and z-score normalises continuous - dynamic context channels. + Groups data into calendar-month windows using only the first 28 days + of each month (days 29–31 are discarded). Each (id, year, month) + group therefore has exactly 28 days, giving fixed-length sequences + with no end-of-month drift and month as a clean context variable. + Continuous dynamic context channels are z-score normalised. 
""" data = data.copy() @@ -164,27 +192,26 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: tgt_ts = list(self.target_time_series_columns) # ["sales"] all_ts = ts_ctx + tgt_ts - # Assign window indices: non-overlapping blocks of seq_len days per series + # Keep only the first 28 days of each calendar month data = data.sort_values(["id", "date"]).reset_index(drop=True) - data["_row"] = data.groupby("id").cumcount() - data["_window"] = data["_row"] // self.cfg.seq_len + data = data[data["date"].dt.day <= 28].reset_index(drop=True) - group_keys = ["id", "_window"] - static_cols = ["cat_id", "dept_id", "store_id", "state_id", "year", "month", "weekday"] + # Group by calendar month; year and month are already columns + group_keys = ["id", "year", "month"] + static_cols = ["cat_id", "dept_id", "store_id", "state_id", "year", "month"] agg_dict = {c: list for c in all_ts} - agg_dict.update({c: "first" for c in static_cols}) + agg_dict.update({c: "first" for c in static_cols if c not in group_keys}) grouped = ( data.groupby(group_keys, as_index=False, sort=False) .agg(agg_dict) ) - grouped.drop(columns=["_window"], inplace=True, errors="ignore") for c in all_ts: grouped[c] = grouped[c].map(np.asarray) - # Keep only complete windows + # Keep only complete 28-day windows len_col = tgt_ts[0] grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) @@ -249,6 +276,13 @@ def _handle_missing_data(self, data: pd.DataFrame) -> pd.DataFrame: ) ] + # Drop windows where any target series sums to zero (all-zero sequences) + for tcol in self.target_time_series_columns: + if tcol in data.columns: + data = data.loc[ + data[tcol].apply(lambda x: np.asarray(x, dtype=float).sum() > 0) + ] + # Drop near-constant windows (std < 0.01 → degenerate for diffusion) def _low_std(row, cols, thresh=0.01): return any(np.asarray(row[c], dtype=np.float32).std() < thresh for c in cols) diff --git a/cents/eval/eval.py b/cents/eval/eval.py index 
2d72dbf..045ec1f 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -19,6 +19,7 @@ Context_FID, calculate_mmd, compute_cfs, + compute_context_recovery_score, compute_gcp, compute_mig, compute_sap, @@ -193,6 +194,8 @@ def compute_quality_metrics( log_prefix: str = "", context_arrays: Optional[Dict[str, np.ndarray]] = None, ts_column_names: Optional[list] = None, + static_context_arrays: Optional[Dict[str, np.ndarray]] = None, + continuous_vars: Optional[list] = None, ) -> Dict: """ Compute evaluation metrics and store them in current_results (or in target if provided). @@ -240,6 +243,14 @@ def compute_quality_metrics( ) metrics["context_faithfulness"] = cf_metrics + # Context Recovery Score — tests whether static context is encoded in generated outputs + if self.cfg.evaluator.get("eval_context_recovery", False) and static_context_arrays: + crs_score, crs_per_var = compute_context_recovery_score( + real_data, syn_data, static_context_arrays, continuous_vars=continuous_vars + ) + metrics["context_recovery"] = {"overall": crs_score, "per_var": crs_per_var} + logger.info(f"[Cents] Context Recovery Score completed: {crs_score:.4f}") + if mask is not None: logger.info("[Cents] Starting Rare-Subset Metrics") rare_metrics = {} @@ -536,6 +547,12 @@ def _inv(df): and (has_dynamic_context or has_multivariate_signal) ) + # Build static context numpy arrays for Context Recovery Score + static_context_np: Dict[str, np.ndarray] = { + name: tensor.cpu().numpy() for name, tensor in static_context_vars.items() + } + dataset_continuous_vars: list = getattr(dataset, "continuous_vars", []) + if self.cfg.evaluator.eval_metrics: rare_mask = None @@ -551,6 +568,8 @@ def _inv(df): log_prefix="[raw] ", context_arrays=context_np if run_cf else None, ts_column_names=ts_col_names if run_cf else None, + static_context_arrays=static_context_np if static_context_np else None, + continuous_vars=dataset_continuous_vars, ) # Metrics in normalized (z) domain for cross-domain comparability. 
@@ -570,6 +589,8 @@ def _inv(df): log_prefix="[normalized] ", context_arrays=context_np if run_cf else None, ts_column_names=ts_col_names if run_cf else None, + static_context_arrays=static_context_np if static_context_np else None, + continuous_vars=dataset_continuous_vars, ) self.current_results["metrics"]["normalized_domain"] = normalized_metrics diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index 7c78487..a6e8ab4 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -377,6 +377,91 @@ def compute_gcp( return gcp, diagnostics +def compute_context_recovery_score( + real_data: np.ndarray, + syn_data: np.ndarray, + context_labels: Dict[str, np.ndarray], + continuous_vars: Optional[List[str]] = None, + test_ratio: float = 0.2, +) -> Tuple[float, Dict[str, Dict]]: + """ + Context Recovery Score: measures whether static context is reflected in generated outputs. + + Trains a predictor f: timeseries -> context_label on real data, then evaluates + accuracy on synthetic data conditioned on the same labels. High score means + the model correctly encodes the conditioning variable in the output. + + For categorical variables: classification accuracy (chance = 1/n_classes). + For continuous variables: R² (0 = no recovery, 1 = perfect). 
+ + Args: + real_data: (N, T, D) real time series + syn_data: (N, T, D) synthetic time series (same conditioning as real) + context_labels: {name: (N,) array of integer labels or float values} + continuous_vars: names of continuous context variables; all others treated as categorical + test_ratio: fraction of real data held out for real_baseline evaluation + + Returns: + overall_score: mean synth_score across all context variables + per_var: {name: {"synth_score": float, "real_baseline": float, "type": str}} + """ + if continuous_vars is None: + continuous_vars = [] + + N, T, D = real_data.shape + + def _features(data: np.ndarray) -> np.ndarray: + """Mean + std pool over time → (N, 2*D) fixed-size representation.""" + return np.concatenate([data.mean(axis=1), data.std(axis=1)], axis=-1) + + X_real = _features(real_data) + X_synth = _features(syn_data) + + rng = np.random.RandomState(42) + idx = rng.permutation(N) + test_size = max(1, int(N * test_ratio)) + train_idx = idx[test_size:] + test_idx = idx[:test_size] + + per_var: Dict[str, Dict] = {} + scores: List[float] = [] + + for name, labels in context_labels.items(): + labels_f = labels.astype(float) + if np.isnan(labels_f).any(): + per_var[name] = {"synth_score": float("nan"), "real_baseline": float("nan"), "type": "unknown"} + continue + + if name in continuous_vars: + y = labels_f + clf = Ridge(alpha=1e-3, fit_intercept=True) + clf.fit(X_real[train_idx], y[train_idx]) + real_baseline = float(r2_score(y[test_idx], clf.predict(X_real[test_idx]))) + synth_score = float(r2_score(y, clf.predict(X_synth))) + score_type = "r2" + else: + y = labels.astype(int) + n_classes = len(np.unique(y)) + if n_classes < 2: + per_var[name] = {"synth_score": float("nan"), "real_baseline": float("nan"), "type": "accuracy"} + continue + clf = LogisticRegression(max_iter=1000, random_state=42) + clf.fit(X_real[train_idx], y[train_idx]) + real_baseline = float(np.mean(clf.predict(X_real[test_idx]) == y[test_idx])) + synth_score = 
float(np.mean(clf.predict(X_synth) == y)) + score_type = "accuracy" + + per_var[name] = { + "synth_score": synth_score, + "real_baseline": real_baseline, + "type": score_type, + } + scores.append(synth_score) + + overall = float(np.mean(scores)) if scores else float("nan") + return overall, per_var + + def compute_mig( embeddings: np.ndarray, context_vars: Dict[str, np.ndarray], diff --git a/cents/eval/predictive_score.py b/cents/eval/predictive_score.py index e4371c9..68b9ae0 100644 --- a/cents/eval/predictive_score.py +++ b/cents/eval/predictive_score.py @@ -68,7 +68,7 @@ def predictive_score_metrics(ori_data, generated_data): X_mb = [ generated_data[i][:-1, :] for i in train_idx ] # Use all dimensions for input - T_mb = [generated_time[i] - 1 for i in train_idx] + T_mb = [max(generated_time[i] - 1, 1) for i in train_idx] Y_mb = [ generated_data[i][1:, :].reshape(-1, dim) for i in train_idx ] # Predict all dimensions @@ -84,7 +84,7 @@ def predictive_score_metrics(ori_data, generated_data): optimizer.step() X_mb = [ori_data[i][:-1, :] for i in range(no)] - T_mb = [ori_time[i] - 1 for i in range(no)] + T_mb = [max(ori_time[i] - 1, 1) for i in range(no)] Y_mb = [ori_data[i][1:, :].reshape(-1, dim) for i in range(no)] X_mb = torch.tensor(np.array(X_mb), dtype=torch.float32).to(device) diff --git a/cents/models/context.py b/cents/models/context.py index dbbb543..246488c 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -609,4 +609,122 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, def on_after_backward(self): unused = [n for n,p in self.named_parameters() if p.requires_grad and p.grad is None] if unused: - print("UNUSED:", unused[:50]) \ No newline at end of file + print("UNUSED:", unused[:50]) + + +@register_context_module("dynamic_joint_transformer") +class DynamicContextModule_JointTransformer(BaseContextModule): + """ + Joint multi-channel encoder for dynamic context. 
+ + All numeric time-series variables are stacked as channels (B, T, n_vars), + projected jointly to (B, T, embedding_dim), then encoded by a single shared + Transformer. Self-attention operates across time steps while seeing all + variables simultaneously, allowing it to learn non-linear variable + interactions (e.g. high TI + low wind → elevated PM2.5). + + Compare to DynamicContextModule_Transformer which runs one independent + transformer per variable and combines them with a linear mix — that + architecture can only learn additive contributions. + """ + + returns_sequence = True + + def __init__( + self, + context_vars: dict, + embedding_dim: int, + seq_len: int = None, + n_layers: int = 2, + n_heads: int = 4, + dropout: float = 0.1, + dim_feedforward: int = 256, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + self.numeric_ts_vars = [ + k for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is None + ] + self.categorical_ts_vars = { + k: v[1] for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + + n_numeric = len(self.numeric_ts_vars) + + # Project all numeric channels jointly: (B, T, n_vars) → (B, T, emb_dim) + # This single linear sees every variable at every timestep simultaneously. 
+ if n_numeric > 0: + self.numeric_input_proj = nn.Linear(n_numeric, embedding_dim) + else: + self.numeric_input_proj = None + + # Categorical time-series: embed each to emb_dim and add into the joint repr + self.ts_cat_embeddings = nn.ModuleDict({ + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in self.categorical_ts_vars.items() + }) + + if seq_len is not None: + self.pos_encoding = nn.Parameter(torch.zeros(1, seq_len, embedding_dim)) + else: + self.pos_encoding = None + + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=n_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation='gelu', + batch_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + self.post_norm = nn.LayerNorm(embedding_dim) + + self._init_weights() + + def _init_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + if self.pos_encoding is not None: + nn.init.normal_(self.pos_encoding, std=0.02) + + def forward(self, context_vars: dict) -> tuple: + numeric_tensors = [] + for name in self.numeric_ts_vars: + if name in context_vars: + ts = context_vars[name] + if not ts.is_floating_point(): + ts = ts.float() + ts = torch.where(torch.isfinite(ts), ts, torch.zeros_like(ts)) + ts_mean = ts.mean(dim=1, keepdim=True) + ts_std = ts.std(dim=1, keepdim=True) + 1e-8 + ts = (ts - ts_mean) / ts_std + numeric_tensors.append(ts) + + if numeric_tensors and self.numeric_input_proj is not None: + # (B, T, n_vars) → (B, T, emb_dim) + x = self.numeric_input_proj(torch.stack(numeric_tensors, dim=-1)) + else: + device = next(iter(context_vars.values())).device + B = next(iter(context_vars.values())).size(0) + T = self.seq_len or 1 + x = torch.zeros(B, T, self.embedding_dim, device=device) + + for name, emb_layer in self.ts_cat_embeddings.items(): + if name in 
context_vars: + x = x + emb_layer(context_vars[name]) # (B, T, emb_dim) + + if self.pos_encoding is not None: + x = x + self.pos_encoding[:, :x.size(1)] + + x = self.encoder(x) + x = self.post_norm(x) + + return x, {} \ No newline at end of file diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index dbc34e0..fcb3495 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -586,12 +586,17 @@ def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_va _nan_check(pred_noise, "forward pred_noise (eps)") loss_per_elem = self.recon_loss_fn(pred_noise, noise, reduction="none") else: # v - pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) - _nan_check(pred_noise, "forward pred_noise (v)") - pred_v = ( - self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * pred_noise - - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_start_pred - ) + # Compute pred_v directly from x_start_pred and x_noisy, avoiding the + # two-step path through predict_noise_from_start which divides by + # sqrt_recipm1_alphas_cumprod — a value near 0 at low t (cosine schedule + # gives ~0.01 at t=0), amplifying prediction errors ~100x into pred_noise + # before they land in pred_v. 
The algebraic identity: + # v = sqrt(α_bar)*ε - sqrt(1-α_bar)*x0 + # ε = (x_noisy - sqrt(α_bar)*x0) / sqrt(1-α_bar) + # => pred_v = (sqrt(α_bar)*x_noisy - x0) / sqrt(1-α_bar).clamp(min=1e-3) + sqrt_ab = self.sqrt_alphas_cumprod[t].view(-1, 1, 1) + sqrt_1mab = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1).clamp(min=1e-3) + pred_v = (sqrt_ab * x_noisy - x_start_pred) / sqrt_1mab true_v = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * noise - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 5ff4275..2f7a260 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -18,6 +18,7 @@ from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset from cents.datasets.metraq import MetraqDataset +from cents.datasets.walmart import WalmartDataset from cents.eval.eval import Evaluator from cents.utils.config_loader import load_yaml, apply_overrides from cents.utils.utils import set_context_config_path, set_context_overrides @@ -59,6 +60,8 @@ def _load_dataset(name: str, dataset_cfg: OmegaConf, run_dir: str = None): return AirQualityDataset(**kwargs) if name == "metraq": return MetraqDataset(**kwargs) + if name == "walmart": + return WalmartDataset(**kwargs) raise ValueError(f"Dataset {name} not supported. 
Use: pecanstreet, commercial, airquality.") @@ -153,7 +156,7 @@ def main() -> None: "--dataset", type=str, default="pecanstreet", - choices=("pecanstreet", "commercial", "airquality", "metraq"), + choices=("pecanstreet", "commercial", "airquality", "metraq", "walmart"), help="Dataset name (must match the one used to train the model).", ) parser.add_argument( @@ -300,6 +303,15 @@ def main() -> None: if use_run_path and args.model_ckpt: parser.error("Do not use --model-ckpt with --run-path; checkpoint is resolved from run-path and --epoch.") + # Set random seed before any dataset loading so that subsampling is reproducible + if args.seed is not None: + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + logging.info("Random seed set to %d", args.seed) + # Set custom context config path if provided if args.context_config_path: set_context_config_path(args.context_config_path) @@ -442,15 +454,6 @@ def main() -> None: # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) cfg.dataset = gen.model.cfg.dataset - # Set random seed for reproducible sampling - if args.seed is not None: - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(args.seed) - logging.info("Random seed set to %d", args.seed) - # Set CFG scale on the model instance (read by generate() at inference time) if args.cfg_scale != 1.0: logging.info("Classifier-free guidance scale: %.2f", args.cfg_scale) From b07a57f6f4f3e253e4eee8629f1cf5d090ffbefa Mon Sep 17 00:00:00 2001 From: Pieter Feenstra Date: Fri, 24 Apr 2026 10:50:41 -0400 Subject: [PATCH 49/50] Context Recovery Score, rounding for walmart dataset, additional static embedder --- cents/config/context/default.yaml | 8 +- cents/config/dataset/airquality.yaml | 4 +- cents/config/dataset/pecanstreet.yaml | 1 - 
cents/config/dataset/walmart.yaml | 4 +- cents/config/model/diffusion_ts.yaml | 2 +- cents/config/trainer/normalizer.yaml | 2 +- cents/data_generator.py | 21 ++++- cents/datasets/commercial.py | 10 +++ cents/datasets/timeseries_dataset.py | 7 +- cents/eval/eval_metrics.py | 4 +- cents/models/base.py | 8 +- cents/models/context.py | 118 ++++++++++++++++++++++++++ cents/models/model_utils.py | 2 +- cents/trainer.py | 13 ++- scripts/generate.py | 77 ++++++++++++++--- 15 files changed, 249 insertions(+), 32 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index 4a413ca..58ab35d 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -3,10 +3,12 @@ # Static context: used by generative models (ACGAN, Diffusion_TS) for conditioning static_context: - type: mlp # Context module type (e.g., "mlp", "sep_mlp") - # Future parameters can be added here: + type: mlp # Options: "mlp", "sep_mlp", "transformer" + # TransformerStaticContextModule hyperparameters (ignored by mlp/sep_mlp): + # n_heads: 4 # n_layers: 2 - # hidden_dim: 256 + # dropout: 0.1 + # dim_feedforward: 256 # Normalizer: stats head configuration for the normalizer normalizer: diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml index cb14197..8088cba 100644 --- a/cents/config/dataset/airquality.yaml +++ b/cents/config/dataset/airquality.yaml @@ -11,14 +11,14 @@ max_samples: null path: "./data/airquality" numeric_context_bins: 1 reduce_cardinality: False -time_series_dims: 4 +time_series_dims: 1 normalizer_stats_mode: group # Normalizer conditions only on these (e.g. 
per-station); diffusion still gets full context_vars normalizer_group_vars: ["station"] # Targets (what becomes the merged "timeseries" dims) # NOTE: use PMcoarse instead of PM10 -time_series_columns: ["PM2.5", "SO2", "NO2", "CO"] +time_series_columns: ["PM2.5" ] # "SO2", "NO2", "CO"] # Raw CSV columns to load # Keep wd/WSPM because we need them to engineer wind_u/wind_v diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index ad30e36..bc1139a 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -16,7 +16,6 @@ metadata_columns: ["dataid","building_type","solar","car1","city","state","total user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 normalizer_stats_mode: group -# normalizer_group_vars: ["state", "city"] context_vars: diff --git a/cents/config/dataset/walmart.yaml b/cents/config/dataset/walmart.yaml index 37f5eb6..41c0a05 100644 --- a/cents/config/dataset/walmart.yaml +++ b/cents/config/dataset/walmart.yaml @@ -31,7 +31,7 @@ context_vars: state_id: ["categorical", 3] # CA, TX, WI # Dynamic time-series context (co-occurring with target, length = seq_len) - # sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored - # snap: ["time_series", null] # binary SNAP eligibility for the item's state + sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored + snap: ["time_series", null] # binary SNAP eligibility for the item's state event_binary: ["time_series", null] # 1 if a named calendar event falls on that day weekday: ["time_series", null] # day of week encoded as 0 (Mon) – 6 (Sun), z-scored diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 73daad9..1de3cbf 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -2,7 +2,7 @@ _target_: generator.diffusion_ts.gaussian_diffusion.Diffusion_TS name: diffusion_ts 
context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 -noise_dim: 256 +noise_dim: 128 cond_emb_dim: 64 n_layer_enc: 4 n_layer_dec: 5 diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index f090bac..ad08d89 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,6 +1,6 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 2, +devices: 0, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 diff --git a/cents/data_generator.py b/cents/data_generator.py index 81f7038..883e769 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -172,12 +172,16 @@ def set_context(self, auto_fill_missing: bool = False, **context_vars: Union[int self._ctx_buff[var] = torch.tensor(code, dtype=torch.long, device=self.device) @torch.no_grad() - def generate(self, n: int = 128) -> "pd.DataFrame": + def generate(self, n: int = 128, stochastic_round: bool = False) -> "pd.DataFrame": """ Produce n synthetic samples under the previously set context. Args: n: Number of samples to generate. + stochastic_round: If True, round output to non-negative integers using + stochastic rounding (floor + bernoulli on fractional part). Applied + after inverse_transform so normalizer stats are unaffected. Useful + for count-valued datasets (e.g. Walmart unit sales). Returns: DataFrame with context columns + 'timeseries'. 
@@ -195,7 +199,20 @@ def generate(self, n: int = 128) -> "pd.DataFrame": ctx_batch = {k: v.repeat(n) for k, v in self._ctx_buff.items()} ts = self.model.generate(ctx_batch) df = convert_generated_data_to_df(ts, self._ctx_buff, decode=False) - return self.normalizer.inverse_transform(df) if self.normalizer else df + df = self.normalizer.inverse_transform(df) if self.normalizer else df + + if stochastic_round: + import numpy as np + + def _stochastic_round(x): + x = np.clip(x, 0, None) + floor = np.floor(x).astype(int) + frac = x - floor + return floor + (np.random.random(x.shape) < frac).astype(int) + + df["timeseries"] = df["timeseries"].apply(_stochastic_round) + + return df def load_from_checkpoint( self, diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py index 7bb7b66..75739b8 100644 --- a/cents/datasets/commercial.py +++ b/cents/datasets/commercial.py @@ -91,6 +91,16 @@ def _load_data(self): data = data[data["site_id"] == self.geography] metadata = metadata[metadata["site_id"] == self.geography] + max_samples = self.cfg.get('max_samples', None) + if max_samples is not None: + unique_ids = data['dataid'].unique() + # Each building produces ~700 daily sequences; 2x buffer accounts for filtering + est_seqs_per_id = 700 + n_ids = min(len(unique_ids), max(1, int(np.ceil(max_samples * 2.0 / est_seqs_per_id)))) + kept_ids = np.random.choice(unique_ids, size=n_ids, replace=False) + data = data[data['dataid'].isin(kept_ids)] + print(f"CommercialDataset: subsampling {n_ids} building IDs (est. 
~{n_ids * est_seqs_per_id} sequences) for max_samples={max_samples}") + self.data = data self.metadata = metadata diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 347fead..2e68f05 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -125,8 +125,11 @@ def __init__( self.context_cfg = get_context_config() - self.dynamic_module_type = self.context_cfg.dynamic_context.type - self.static_module_type = self.context_cfg.static_context.type + self.dynamic_module_type = self.context_cfg.dynamic_context.type + # Normalizer uses its own context type (defaults to mlp) so that switching the + # diffusion model to a heavier static embedder (e.g. transformer) doesn't affect + # the much simpler normalizer training. + self.static_module_type = getattr(self.context_cfg.normalizer, "context_type", "mlp") self.stats_head_type = self.context_cfg.normalizer.stats_head_type is_ddp_subprocess = self._is_ddp_subprocess() diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index a6e8ab4..8b5e846 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -13,6 +13,8 @@ from sklearn.linear_model import LogisticRegression, Ridge from sklearn.metrics import mutual_info_score, r2_score, roc_auc_score from sklearn.model_selection import StratifiedKFold +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler from statsmodels.tsa.tsatools import lagmat from cents.eval.eval_utils import ( @@ -445,7 +447,7 @@ def _features(data: np.ndarray) -> np.ndarray: if n_classes < 2: per_var[name] = {"synth_score": float("nan"), "real_baseline": float("nan"), "type": "accuracy"} continue - clf = LogisticRegression(max_iter=1000, random_state=42) + clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000, random_state=42)) clf.fit(X_real[train_idx], y[train_idx]) real_baseline = float(np.mean(clf.predict(X_real[test_idx]) == y[test_idx])) 
synth_score = float(np.mean(clf.predict(X_synth) == y)) diff --git a/cents/models/base.py b/cents/models/base.py index 10a82bb..0221f73 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -6,7 +6,7 @@ import torch.nn as nn from omegaconf import DictConfig -from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule, TransformerStaticContextModule # Import to trigger registration from cents.models.context_registry import get_context_module_cls from cents.utils.utils import get_context_config @@ -63,9 +63,15 @@ def __init__(self, cfg: DictConfig = None): if k in static_context_vars } print(static_context_vars_dict) + static_ctx_kwargs = { + k: getattr(context_cfg.static_context, k) + for k in ("n_heads", "n_layers", "dropout", "dim_feedforward") + if hasattr(context_cfg.static_context, k) + } self.static_context_module = StaticContextModuleCls( static_context_vars_dict, emb_dim, + **static_ctx_kwargs, ) # Create dynamic context module (for time_series) diff --git a/cents/models/context.py b/cents/models/context.py index 246488c..06a0e12 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -238,6 +238,124 @@ def forward(self, context_vars): return embedding, all_outputs +@register_context_module("transformer") +class TransformerStaticContextModule(BaseContextModule): + """ + Transformer-based static context embedder. + + Each context variable is projected to a token of size embedding_dim, + augmented with a per-variable type embedding, then normalised before + being fed into a shared Transformer encoder. Mean-pooling across the + variable tokens produces the final (B, embedding_dim) conditioning vector. 
+ + Compared to MLPContextModule: + - Attention captures interactions between context variables + - pre-LN (norm_first=True) and GELU throughout → more stable gradients + - No hardcoded bottleneck; width is controlled by dim_feedforward + """ + + def __init__( + self, + context_vars: dict[str, list], + embedding_dim: int, + n_heads: int = 4, + n_layers: int = 2, + dropout: float = 0.1, + dim_feedforward: int = 256, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.continuous_vars = [k for k, v in context_vars.items() if v[0] == "continuous"] + self.categorical_vars = {k: v[1] for k, v in context_vars.items() if v[0] == "categorical"} + self.var_names = list(self.categorical_vars.keys()) + self.continuous_vars + + self.cat_embeddings = nn.ModuleDict({ + name: nn.Embedding(n_cats, embedding_dim) + for name, n_cats in self.categorical_vars.items() + }) + self.cont_projections = nn.ModuleDict({ + name: nn.Linear(1, embedding_dim) + for name in self.continuous_vars + }) + # Per-variable learnable offset so attention can distinguish variable identity + self.type_embeddings = nn.Embedding(len(self.var_names), embedding_dim) + self.register_buffer("_var_indices", torch.arange(len(self.var_names))) + + # Normalise tokens before encoder to equalize scales across embed/projection types + self.token_norm = nn.LayerNorm(embedding_dim) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=n_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + self.output_norm = nn.LayerNorm(embedding_dim) + + self.classification_heads = nn.ModuleDict({ + name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.GELU(), + nn.Linear(embedding_dim, n_cats), + ) + for name, n_cats in self.categorical_vars.items() + }) + self.regression_heads = nn.ModuleDict({ + name: nn.Sequential( + 
nn.Linear(embedding_dim, embedding_dim), + nn.GELU(), + nn.Linear(embedding_dim, 1), + ) + for name in self.continuous_vars + }) + + self._init_weights() + + def _init_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + nn.init.normal_(self.type_embeddings.weight, std=0.02) + + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + all_type_embs = self.type_embeddings(self._var_indices) # (N_vars, D) + tokens = [] + for i, name in enumerate(self.var_names): + if name in self.categorical_vars: + idx = context_vars[name] + if idx.dtype in (torch.long, torch.int, torch.int32, torch.int64): + idx = idx.clamp(0, self.cat_embeddings[name].num_embeddings - 1) + tok = self.cat_embeddings[name](idx) + all_type_embs[i] + else: + val = context_vars[name].float() + if val.dim() == 1: + val = val.unsqueeze(-1) + tok = self.cont_projections[name](val) + all_type_embs[i] + tokens.append(self.token_norm(tok)) + + x = torch.stack(tokens, dim=1) # (B, N_vars, D) + x = self.encoder(x) # (B, N_vars, D) + embedding = self.output_norm(x.mean(dim=1)) # (B, D) + + classification_logits = { + name: head(embedding) + for name, head in self.classification_heads.items() + if name in context_vars + } + regression_outputs = { + name: head(embedding).squeeze(-1) + for name, head in self.regression_heads.items() + if name in context_vars + } + return embedding, {**classification_logits, **regression_outputs} + + @register_context_module("dynamic_cnn") class DynamicContextModule_CNN(BaseContextModule): """ diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index 7ed194f..5006bfa 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -788,7 +788,7 @@ def __init__( condition_dim=condition_dim, has_dynamic_ctx=has_dynamic_ctx, ) - for _ in range(n_layer) + for _ in 
range(n_layer) ] ) diff --git a/cents/trainer.py b/cents/trainer.py index 923cf99..a8797b9 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -438,6 +438,14 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) if (epoch + 1) % self.every_n_epochs != 0: return + # Checkpoint saving must happen on all ranks — DDP collects state across processes. + ckpt_path = str(self._fid_ckpt_dir / f"fid_epoch={epoch:04d}.ckpt") + trainer.save_checkpoint(ckpt_path) + + # Everything else (generation, FID, logging, pruning) only on rank 0. + if not trainer.is_global_zero: + return + import numpy as np import torch from cents.eval.eval_metrics import Context_FID @@ -516,15 +524,12 @@ def _inv(df): fid = Context_FID(real_data_array, syn_data_array) print(f"[IntermediateFID] Epoch {epoch + 1}: Context-FID = {fid:.4f}") - # Save checkpoint for this FID-check epoch, then prune - ckpt_path = str(self._fid_ckpt_dir / f"fid_epoch={epoch:04d}.ckpt") - trainer.save_checkpoint(ckpt_path) self._fid_records.append((fid, epoch, ckpt_path)) self._prune_fid_checkpoints() self._log_csv(epoch + 1, fid) self._log_wandb(trainer, epoch, fid) - pl_module.log("intermediate_context_fid", fid, on_step=False, on_epoch=True, prog_bar=True) + pl_module.log("intermediate_context_fid", fid, on_step=False, on_epoch=True, prog_bar=True, sync_dist=False) # ------------------------------------------------------------------ # Checkpoint pruning diff --git a/scripts/generate.py b/scripts/generate.py index 93b43b0..9d88682 100644 --- a/scripts/generate.py +++ b/scripts/generate.py @@ -24,26 +24,40 @@ from cents.datasets.pecanstreet import PecanStreetDataset from cents.datasets.commercial import CommercialDataset from cents.datasets.airquality import AirQualityDataset +from cents.datasets.walmart import WalmartDataset from cents.datasets.utils import convert_generated_data_to_df -from cents.utils.config_loader import load_yaml +from cents.utils.config_loader import load_yaml, 
apply_overrides from cents.utils.utils import set_context_config_path, set_context_overrides +CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" + logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", ) -DATASET_OVERRIDES = ["max_samples=10000", "skip_heavy_processing=True"] +DATASET_OVERRIDES = ["normalize=False", "max_samples=10000", "skip_heavy_processing=True"] PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] -def _load_dataset(name: str, overrides: list): +def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: + config_path = CONFIG_DATASET_DIR / f"{dataset_name}.yaml" + cfg = load_yaml(str(config_path)) + if overrides: + cfg = apply_overrides(cfg, overrides) + return cfg + + +def _load_dataset(name: str, dataset_cfg: OmegaConf): + kwargs = {"cfg": dataset_cfg} if name == "pecanstreet": - return PecanStreetDataset(overrides=DATASET_OVERRIDES + PECAN_OVERRIDES + (overrides or [])) + return PecanStreetDataset(**kwargs) if name == "commercial": - return CommercialDataset(overrides=DATASET_OVERRIDES + (overrides or [])) + return CommercialDataset(**kwargs) if name == "airquality": - return AirQualityDataset(overrides=DATASET_OVERRIDES + (overrides or [])) - raise ValueError(f"Dataset {name} not supported. Use: pecanstreet, commercial, airquality.") + return AirQualityDataset(**kwargs) + if name == "walmart": + return WalmartDataset(**kwargs) + raise ValueError(f"Dataset {name} not supported. 
Use: pecanstreet, commercial, airquality, walmart.") def main() -> None: @@ -73,7 +87,7 @@ def main() -> None: "--dataset", type=str, default="pecanstreet", - choices=("pecanstreet", "commercial", "airquality"), + choices=("pecanstreet", "commercial", "airquality", "walmart"), help="Dataset name (must match the one used to train the model).", ) parser.add_argument( @@ -136,6 +150,27 @@ def main() -> None: action="store_true", help="Disable EMA sampling (EMA is used by default when present in the checkpoint).", ) + parser.add_argument( + "--stochastic-round", + action="store_true", + help=( + "Round output to non-negative integers using stochastic rounding " + "(floor + Bernoulli on fractional part). Applied after inverse-transform. " + "Useful for count-valued datasets such as Walmart unit sales." + ), + ) + parser.add_argument( + "--model-config", + type=str, + default=None, + help="Path to a model config YAML file. Overrides the default cents/config/model/{model_type}.yaml.", + ) + parser.add_argument( + "--dataset-config", + type=str, + default=None, + help="Path to a dataset config YAML file. 
Overrides the default cents/config/dataset/{dataset}.yaml.", + ) args = parser.parse_args() use_random = args.random_context @@ -153,12 +188,22 @@ def main() -> None: if args.context_overrides: set_context_overrides(args.context_overrides) - overrides = list(args.dataset_overrides) if args.dataset_overrides else [] + base_overrides = list(DATASET_OVERRIDES) + if args.dataset == "pecanstreet": + base_overrides += PECAN_OVERRIDES + if args.dataset_overrides: + base_overrides += list(args.dataset_overrides) logging.info("Loading dataset %s...", args.dataset) - dataset = _load_dataset(args.dataset, overrides) + if args.dataset_config: + dataset_cfg = OmegaConf.load(args.dataset_config) + dataset_cfg = apply_overrides(dataset_cfg, base_overrides) + else: + dataset_cfg = _load_dataset_config(args.dataset, base_overrides) + dataset = _load_dataset(args.dataset, dataset_cfg) cfg = OmegaConf.create({}) - cfg.model = load_yaml(Path("cents/config/model") / f"{args.model_type}.yaml") + model_config_path = args.model_config if args.model_config else f"cents/config/model/{args.model_type}.yaml" + cfg.model = OmegaConf.create(OmegaConf.to_container(OmegaConf.load(model_config_path), resolve=True)) cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) cfg.model.use_ema_sampling = not args.no_ema @@ -221,6 +266,16 @@ def main() -> None: else: logging.warning("No normalizer loaded; outputs are in normalized space.") + if args.stochastic_round: + def _stochastic_round(x): + x = np.clip(x, 0, None) + floor = np.floor(x).astype(int) + return floor + (np.random.random(x.shape) < (x - floor)).astype(int) + + col_name = dataset_cfg.time_series_columns[0] + df[col_name] = df[col_name].apply(_stochastic_round) + logging.info("Applied stochastic rounding to integer counts.") + out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) df.to_parquet(out, index=False) From e442ddfc7d223fd560a3a8ab842aec0964fe315e Mon Sep 17 00:00:00 2001 From: Pieter 
Feenstra Date: Mon, 4 May 2026 13:58:03 -0400 Subject: [PATCH 50/50] Removed unused code, removes nan checks, removed print statements --- cents/config/context/default.yaml | 6 - cents/config/dataset/airquality.yaml | 2 +- cents/config/evaluator/default.yaml | 1 + cents/config/model/diffusion_ts.yaml | 4 +- cents/config/trainer/diffusion_ts.yaml | 1 - cents/config/trainer/normalizer.yaml | 2 +- cents/data_generator.py | 7 +- cents/datasets/airquality.py | 35 +--- cents/datasets/metraq.py | 43 +---- cents/datasets/timeseries_dataset.py | 6 - cents/datasets/utils.py | 13 +- cents/datasets/walmart.py | 25 +-- cents/eval/eval.py | 8 +- cents/eval/eval_metrics.py | 164 +++++++++++++++---- cents/eval/eval_utils.py | 117 -------------- cents/eval/predictive_score.py | 101 +++++++----- cents/models/context.py | 188 +--------------------- cents/models/diffusion_ts.py | 211 +++---------------------- cents/models/model_utils.py | 53 ------- cents/models/normalizer.py | 57 ++----- 20 files changed, 265 insertions(+), 779 deletions(-) diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml index 58ab35d..e8fb2a7 100644 --- a/cents/config/context/default.yaml +++ b/cents/config/context/default.yaml @@ -1,7 +1,6 @@ # Context configuration # This file defines the context modules used across the codebase -# Static context: used by generative models (ACGAN, Diffusion_TS) for conditioning static_context: type: mlp # Options: "mlp", "sep_mlp", "transformer" # TransformerStaticContextModule hyperparameters (ignored by mlp/sep_mlp): @@ -13,14 +12,9 @@ static_context: # Normalizer: stats head configuration for the normalizer normalizer: stats_head_type: mlp # Stats head type (e.g., "mlp") - # Future parameters can be added here: n_layers: 5 # hidden_dim: 512 # Dynamic context: context module used by the normalizer for time series context variables dynamic_context: type: null # Context module type for dynamic context (e.g., "cnn") - # Future parameters can be 
added here: - # n_layers: 2 - # hidden_dim: 256 - diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml index 8088cba..3ce51c2 100644 --- a/cents/config/dataset/airquality.yaml +++ b/cents/config/dataset/airquality.yaml @@ -18,7 +18,7 @@ normalizer_group_vars: ["station"] # Targets (what becomes the merged "timeseries" dims) # NOTE: use PMcoarse instead of PM10 -time_series_columns: ["PM2.5" ] # "SO2", "NO2", "CO"] +time_series_columns: ["PM2.5"] # Raw CSV columns to load # Keep wd/WSPM because we need them to engineer wind_u/wind_v diff --git a/cents/config/evaluator/default.yaml b/cents/config/evaluator/default.yaml index b7d1523..f39eca1 100644 --- a/cents/config/evaluator/default.yaml +++ b/cents/config/evaluator/default.yaml @@ -4,6 +4,7 @@ dataset: name: commercial # Set this to your dataset name (e.g., "commercial") eval_pv_shift: False eval_metrics: True +pred_score_trtr: True # If True, also trains on real data (TRTR) and reports MAE delta alongside TSTR MAE eval_context_sparse: True save_results: False eval_disentanglement: True diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 1de3cbf..911005c 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -27,7 +27,7 @@ reg_weight: null gradient_accumulate_every: 2 ema_decay: 0.999 ema_update_interval: 1 -use_ema_sampling: True +use_ema_sampling: False k_bins: 20 # Reconstruction-guided sampling (Algorithms 1 & 2) recon_guide_eta: 0.1 # gradient scale for guidance @@ -40,4 +40,4 @@ recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs context_embed_dropout: 0 # probability of zeroing entire context embedding per sample (CFG-compatible) blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise # When true and time_series_dims > 1, noise is correlated across dimensions (same draw per timestep). 
-correlated_noise: True \ No newline at end of file +correlated_noise: False \ No newline at end of file diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 44dd34a..1fb90f2 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -9,7 +9,6 @@ max_epochs: 5000 base_lr: 1e-4 warmup_epochs: 100 # linear warmup from 1% to 100% of base_lr over first N epochs eval_after_training: False -# gradient_clip_val: 1.0 checkpoint: save_last: True # Save final model diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index ad08d89..f34db7f 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,6 +1,6 @@ strategy: ddp_find_unused_parameters_true accelerator: gpu -devices: 0, +devices: 1, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 diff --git a/cents/data_generator.py b/cents/data_generator.py index 883e769..30f1ae8 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -151,7 +151,7 @@ def set_context(self, auto_fill_missing: bool = False, **context_vars: Union[int context_vars[var] = random.uniform(0.0, 1.0) else: context_vars.setdefault(var, random.randrange(n)) - else: + elif context_vars: missing = set(required) - set(context_vars) if missing: raise ValueError(f"Missing context vars: {missing}") @@ -193,11 +193,9 @@ def generate(self, n: int = 128, stochastic_round: bool = False) -> "pd.DataFram raise RuntimeError( "No model loaded. Call `load_from_checkpoint(...)` first." 
) - if not self._ctx_buff: - raise RuntimeError("No context set – call `set_context()` first.") ctx_batch = {k: v.repeat(n) for k, v in self._ctx_buff.items()} - ts = self.model.generate(ctx_batch) + ts = self.model.generate(ctx_batch, n=n) df = convert_generated_data_to_df(ts, self._ctx_buff, decode=False) df = self.normalizer.inverse_transform(df) if self.normalizer else df @@ -238,7 +236,6 @@ def load_from_checkpoint( ckpt_path, state = self._resolve_ckpt(model_ckpt) ModelCls = get_model_cls(self.model_type) - print(self.cfg) if ckpt_path.suffix == ".ckpt": print(f"[Cents] Loading model from checkpoint: {ckpt_path}") diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py index 7b81f85..0f9f9eb 100644 --- a/cents/datasets/airquality.py +++ b/cents/datasets/airquality.py @@ -7,6 +7,7 @@ import pandas as pd from omegaconf import DictConfig from cents.utils.config_loader import load_yaml, apply_overrides +from cents.datasets.utils import is_all_nan, is_any_nan, fill_with_row_mean from cents.datasets.timeseries_dataset import TimeSeriesDataset @@ -115,22 +116,16 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: wd_clean = data["wd"].astype(str).str.strip().str.upper() wd_deg = wd_clean.map(wd_deg_map) - # indicator can help if any weird labels slip in data["wd_valid"] = wd_deg.notna().astype(np.int8) theta = np.deg2rad(wd_deg.fillna(0.0).to_numpy(dtype=float)) wspm = pd.to_numeric(data["WSPM"], errors="coerce").fillna(0.0) - # u = speed * cos(theta), v = speed * sin(theta) - # (note: choice of axes is arbitrary here; consistency matters more than convention) data["wind_u"] = wspm * np.cos(theta) data["wind_v"] = wspm * np.sin(theta) - # Drop raw wind columns after engineering data.drop(columns=["wd", "WSPM"], inplace=True) else: - # If one is missing, don't silently create nonsense - # You can choose to raise instead if this should never happen. 
if "wd" in data.columns: data.drop(columns=["wd"], inplace=True) if "WSPM" in data.columns: @@ -147,38 +142,30 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: # ------------------------- # Choose time-series columns # ------------------------- - # Context TS columns come from cfg; targets come from cfg.time_series_columns ctx_ts = list(self.context_series_names) tgt_ts = list(self.target_time_series_columns) - # Replace context wind variables: remove wd/WSPM if present, add wind_u/wind_v (+wd_valid if you want) ctx_ts = [c for c in ctx_ts if c not in ("wd", "WSPM")] for c in ("wind_u", "wind_v"): if c in data.columns and c not in ctx_ts: ctx_ts.append(c) - # optional if "wd_valid" in data.columns and "wd_valid" not in ctx_ts: ctx_ts.append("wd_valid") - # Replace PM10 target with PMcoarse if PM10 is in targets if "PMcoarse" in data.columns: tgt_ts = ["PMcoarse" if c == "PM10" else c for c in tgt_ts] - # (optional) if you *also* want to drop PM10 if it still exists tgt_ts = [c for c in tgt_ts if c != "PM10"] ts_cols = ctx_ts + tgt_ts - # Ensure all ts_cols exist missing = [c for c in ts_cols if c not in data.columns] if missing: raise ValueError(f"Missing required time-series columns after preprocessing: {missing}") - # Sort data = data.sort_values(["station", "year", "month", "day", "hour"]) - # Month name mapping (keeps your categorical month encoding behavior) months = [ "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December", @@ -192,25 +179,20 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: .agg({c: list for c in ts_cols}) ) - # lists -> numpy arrays for c in ts_cols: grouped[c] = grouped[c].map(np.asarray) - # Keep only full-length sequences - # Use the first target if possible, else fall back to first ts col len_col = tgt_ts[0] if len(tgt_ts) > 0 else ts_cols[0] grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) grouped 
= self._handle_missing_data(grouped) ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] - # Optional: handle heavy-tailed / zero-inflated channels log1p_channels = {"RAIN"} # add more if needed clip_bound = 5.0 eps = 1e-8 - # Compute global mean/std per channel over all rows and timesteps ctx_stats = {} for c in ctx_numeric: # stacked shape: (N, L) @@ -230,10 +212,8 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: grouped[c] = list(Xn) - # (Optional) store for later inverse-transform / debugging self.context_ts_stats_ = ctx_stats - # arrays -> tuples (hashable) for c in ts_cols: grouped[c] = grouped[c].map(tuple) @@ -249,15 +229,12 @@ def _handle_missing_data(self, data): for col in numeric_series: data[col] = data[col].apply(fill_with_row_mean) - # categorical time series must have no NaNs cat_cols = list(self.categorical_time_series.keys()) if cat_cols: mask = data[cat_cols].applymap(is_any_nan).any(axis=1) data = data[~mask] - # ensure no NaNs in target series columns for tcol in self.target_time_series_columns: - # If you replaced PM10->PMcoarse in cfg, this remains correct if tcol in data.columns: data = data.loc[data[tcol].apply(lambda x: not np.isnan(np.asarray(x, dtype=float)).any())] @@ -278,13 +255,3 @@ def row_has_low_std(row, cols, thresh=0.01): -def is_all_nan(arr): - return pd.isna(arr).all() - -def is_any_nan(arr): - return pd.isna(arr).any() - -def fill_with_row_mean(lst): - s = pd.Series(lst, dtype=float) - m = s.mean(skipna=True) - return s.fillna(m).tolist() diff --git a/cents/datasets/metraq.py b/cents/datasets/metraq.py index 59bbbe8..a9bc2f6 100644 --- a/cents/datasets/metraq.py +++ b/cents/datasets/metraq.py @@ -8,6 +8,7 @@ from cents.utils.config_loader import load_yaml, apply_overrides from cents.datasets.timeseries_dataset import TimeSeriesDataset +from cents.datasets.utils import is_all_nan, is_any_nan, fill_with_row_mean warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) 
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -132,21 +133,15 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: if "PMcoarse" in data.columns: tgt_ts = ["PMcoarse" if c == "PM10" else c for c in tgt_ts] - # Decompose circular wind direction into Cartesian components so z-score - # normalization is meaningful. WD=355° and WD=5° are 10° apart but would - # get opposite signs after z-scoring — wind_u/wind_v avoids this. if "WD" in data.columns and "WS" in data.columns: wd_deg = pd.to_numeric(data["WD"], errors="coerce") ws = pd.to_numeric(data["WS"], errors="coerce").clip(lower=0.0) - # Binary mask: 1 where WD is measured, 0 where it is missing data["wd_valid"] = wd_deg.notna().astype(np.int8) wd_deg = wd_deg.fillna(0.0) ws = ws.fillna(0.0) wd_rad = np.deg2rad(wd_deg) data["wind_u"] = ws * np.sin(wd_rad) data["wind_v"] = ws * np.cos(wd_rad) - # Replace WS/WD in ctx_ts with wind_u/wind_v (handles legacy configs that - # listed WS/WD; current config lists wind_u/wind_v/wd_valid directly). ctx_ts = [ "wind_u" if c == "WS" else "wind_v" if c == "WD" else c for c in ctx_ts @@ -160,12 +155,7 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: data = data.sort_values(["sensor_name", "timestamp"]) - # print(data) - group_keys = ["sensor_name", "year", "month", "day", "weekday"] - - # Continuous (scalar) context vars are constant per station — carry them through - # with "first" so they survive the groupby without being collapsed into lists. 
static_continuous_cols = [ k for k, v in self.cfg.context_vars.items() if v[0] == "continuous" and k in data.columns @@ -178,7 +168,6 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: .agg(agg_dict) ) - # print(grouped) for c in ts_cols: grouped[c] = grouped[c].map(np.asarray) @@ -188,22 +177,12 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: grouped = self._handle_missing_data(grouped) - # print("POST CLEAN") - # print(grouped) - ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] - - # TI (traffic intensity) is strictly non-negative — log1p compresses the - # heavy right tail (rush-hour spikes) before z-scoring. log1p_channels = {"TI"} binary_channels = {"wd_valid"} # already in [0, 1] — skip z-scoring clip_bound = 5.0 eps = 1e-8 - # Per-station z-score normalization: compute (mu, sd) separately for each - # sensor_name so the model sees locally-relative deviations. A global - # z-score would conflate cross-station level differences with within-station - # variation, obscuring the context–target relationship the model needs to learn. 
ctx_stats = {} # {channel: {sensor_name: (mu, sd)}} for c in ctx_numeric: if c in binary_channels: @@ -218,7 +197,6 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: normalized = col_arrays.copy() for stn, idx in grouped.groupby("sensor_name").groups.items(): - # idx contains label-based indices (not positional) — use .loc X = np.stack(col_arrays.loc[idx].values).astype(np.float32) mu = float(X.mean()) sd = float(X.std()) @@ -231,10 +209,8 @@ def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: grouped[c] = list(normalized) - # Store for later inverse-transform / debugging self.context_ts_stats_ = ctx_stats - # arrays -> tuples (hashable) for c in ts_cols: grouped[c] = grouped[c].map(tuple) @@ -249,15 +225,12 @@ def _handle_missing_data(self, data): for col in numeric_series: data[col] = data[col].apply(fill_with_row_mean) - # categorical time series must have no NaNs cat_cols = list(self.categorical_time_series.keys()) if cat_cols: mask = data[cat_cols].applymap(is_any_nan).any(axis=1) data = data[~mask] - # ensure no NaNs in target series columns for tcol in self.target_time_series_columns: - # If you replaced PM10->PMcoarse in cfg, this remains correct if tcol in data.columns: data = data.loc[data[tcol].apply(lambda x: not np.isnan(np.asarray(x, dtype=float)).any())] @@ -274,16 +247,4 @@ def row_has_low_std(row, cols, thresh=0.01): ) data = data[~mask] - return data - - -def is_all_nan(arr): - return pd.isna(arr).all() - -def is_any_nan(arr): - return pd.isna(arr).any() - -def fill_with_row_mean(lst): - s = pd.Series(lst, dtype=float) - m = s.mean(skipna=True) - return s.fillna(m).tolist() \ No newline at end of file + return data \ No newline at end of file diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 2e68f05..9e6940c 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -107,9 +107,6 @@ def __init__( # Store categorical time series info 
self.categorical_time_series = categorical_time_series or {} - # if self.scale: - # assert self.normalize, "Normalization must be enabled if scaling is enabled" - # Preprocess and optionally encode context self.data = self._preprocess_data(data) @@ -126,9 +123,6 @@ def __init__( self.context_cfg = get_context_config() self.dynamic_module_type = self.context_cfg.dynamic_context.type - # Normalizer uses its own context type (defaults to mlp) so that switching the - # diffusion model to a heavier static embedder (e.g. transformer) doesn't affect - # the much simpler normalizer training. self.static_module_type = getattr(self.context_cfg.normalizer, "context_type", "mlp") self.stats_head_type = self.context_cfg.normalizer.stats_head_type diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 45d7884..bf615cb 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -278,4 +278,15 @@ def encode_list_column(series: pd.Series): encoded = series.apply(lambda x: [tok2id[t] for t in x]) encoded = encoded.map(tuple) mapping = dict(enumerate(vocab)) # id -> token - return encoded, mapping \ No newline at end of file + return encoded, mapping + +def is_all_nan(arr): + return pd.isna(arr).all() + +def is_any_nan(arr): + return pd.isna(arr).any() + +def fill_with_row_mean(lst): + s = pd.Series(lst, dtype=float) + m = s.mean(skipna=True) + return s.fillna(m).tolist() diff --git a/cents/datasets/walmart.py b/cents/datasets/walmart.py index 52a4837..99fda3e 100644 --- a/cents/datasets/walmart.py +++ b/cents/datasets/walmart.py @@ -9,6 +9,7 @@ from cents.utils.config_loader import load_yaml, apply_overrides from cents.datasets.timeseries_dataset import TimeSeriesDataset +from cents.datasets.utils import is_all_nan, is_any_nan, fill_with_row_mean warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -103,10 +104,6 @@ def _load_data(self) -> None: vel_threshold = 
day_means.quantile(0.80) keep_ids = set(sales.loc[day_means >= vel_threshold, "id"].values) - # --- Early ID subsampling when max_samples is set --- - # Each high-velocity ID produces ~60 monthly windows across the ~5-year dataset. - # Keeping 2× the IDs needed gives a safe buffer for windows dropped during - # preprocessing (missing data, near-constant sequences, etc.). max_samples = self.cfg.get("max_samples", None) if max_samples is not None and len(keep_ids) > 0: est_windows_per_id = 60 @@ -117,7 +114,6 @@ def _load_data(self) -> None: n_ids, n_ids * est_windows_per_id, max_samples, ) - # Filter BEFORE the melt — dramatically reduces the wide→long expansion cost sales = sales[sales["id"].isin(keep_ids)] sales_long = sales[meta_cols + day_cols].melt( @@ -146,7 +142,6 @@ def _load_data(self) -> None: .transform(lambda x: x.ffill().bfill()) ) - # --- State-specific SNAP eligibility (vectorised) --- sales_long["snap"] = np.where( sales_long["state_id"] == "CA", sales_long["snap_CA"], np.where( @@ -155,20 +150,16 @@ def _load_data(self) -> None: ), ).astype(np.int8) - # --- Binary calendar-event indicator --- sales_long["event_binary"] = sales_long["event_name_1"].notna().astype(np.int8) - # --- Month name (consistent with other datasets) --- sales_long["month"] = sales_long["month"].map(lambda x: _MONTHS[int(x) - 1]) - # --- Weekday as integer 0 (Mon) – 6 (Sun) for use as dynamic context --- _WEEKDAY_MAP = { "Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3, "Friday": 4, "Saturday": 5, "Sunday": 6, } sales_long["weekday"] = sales_long["weekday"].map(_WEEKDAY_MAP).astype(np.int8) - # Drop columns that are no longer needed after engineering sales_long.drop( columns=["snap_CA", "snap_TX", "snap_WI", "event_name_1", "wm_yr_wk", "d", "item_id"], @@ -290,17 +281,3 @@ def _low_std(row, cols, thresh=0.01): mask = data.apply(lambda row: _low_std(row, self.target_time_series_columns), axis=1) data = data[~mask] return data - - -def is_all_nan(arr): - return 
pd.isna(arr).all() - - -def is_any_nan(arr): - return pd.isna(arr).any() - - -def fill_with_row_mean(lst): - s = pd.Series(lst, dtype=float) - m = s.mean(skipna=True) - return s.fillna(m).tolist() diff --git a/cents/eval/eval.py b/cents/eval/eval.py index 045ec1f..5dab1e9 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -17,6 +17,7 @@ from cents.eval.discriminative_score import discriminative_score_metrics from cents.eval.eval_metrics import ( Context_FID, + calculate_banded_mse, calculate_mmd, compute_cfs, compute_context_recovery_score, @@ -223,6 +224,10 @@ def compute_quality_metrics( metrics["MMD"] = {"mean": mmd_mean, "std": mmd_std} logger.info(f"[Cents] MMD completed") + banded_mse = calculate_banded_mse(real_data, syn_data) + metrics["Banded_MSE"] = banded_mse + logger.info(f"[Cents] Banded MSE completed") + fid_score = Context_FID(real_data, syn_data) metrics["Context_FID"] = fid_score logger.info(f"[Cents] Context-FID completed") @@ -231,7 +236,8 @@ def compute_quality_metrics( metrics["Disc_Score"] = discr_score logger.info(f"[Cents] Discr Score completed") - pred_score = predictive_score_metrics(real_data, syn_data) + trtr = self.cfg.evaluator.get("pred_score_trtr", False) + pred_score = predictive_score_metrics(real_data, syn_data, trtr=trtr) metrics["Pred_Score"] = pred_score logger.info(f"[Cents] Pred Score completed") diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index 8b5e846..792d4c0 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -1,22 +1,37 @@ import warnings from functools import partial from itertools import product -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy +import torch +import torch.nn as nn +import torch.optim as optim from dtaidistance import dtw from scipy.stats import f as f_dist from scipy.stats import wasserstein_distance from 
sklearn.linear_model import LogisticRegression, Ridge from sklearn.metrics import mutual_info_score, r2_score, roc_auc_score -from sklearn.model_selection import StratifiedKFold from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from statsmodels.tsa.tsatools import lagmat +_cfs_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class _CFSDiscriminator(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int): + super().__init__() + self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True) + self.fc = nn.Linear(hidden_dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, h_n = self.gru(x) + return self.fc(h_n.squeeze(0)) + from cents.eval.eval_utils import ( gaussian_kernel_matrix, get_period_bounds, @@ -54,6 +69,56 @@ def dynamic_time_warping_dist(X: np.ndarray, Y: np.ndarray) -> Tuple[float, floa return np.mean(dtw_distances), np.std(dtw_distances) +def calculate_banded_mse( + real_data: np.ndarray, + syn_data: np.ndarray, + n_bands: int = 5, +) -> Dict[str, Any]: + """ + Compute MSE between real and synthetic time series broken down by amplitude band. + + Band boundaries are quantile-derived from the real data, so each band contains + an equal number of real observations. Per-band MSE is averaged over samples + that have at least one timestep in the band. + + Args: + real_data: Real time series (N, T, D). + syn_data: Synthetic time series (N, T, D). + n_bands: Number of quantile bands (default 5). + + Returns: + Dict with keys "band_1" … "band_N", each containing + {"mean": float, "std": float, "range": [lo, hi]}. 
+ """ + assert real_data.shape == syn_data.shape, "real_data and syn_data must have the same shape" + N, T, D = real_data.shape + + edges = np.percentile(real_data.reshape(-1), np.linspace(0, 100, n_bands + 1)) + edges[-1] += 1e-8 # make upper bound inclusive for max value + + sq_err = (real_data - syn_data) ** 2 # (N, T, D) + results: Dict[str, Any] = {} + + for i in range(n_bands): + lo, hi = edges[i], edges[i + 1] + in_band = (real_data >= lo) & (real_data < hi) # (N, T, D) + + band_mses = [] + for n in range(N): + mask = in_band[n] # (T, D) + if mask.sum() == 0: + continue + band_mses.append(float(sq_err[n][mask].mean())) + + results[f"band_{i + 1}"] = { + "mean": float(np.mean(band_mses)) if band_mses else float("nan"), + "std": float(np.std(band_mses)) if band_mses else float("nan"), + "range": [round(float(lo), 4), round(float(hi), 4)], + } + + return results + + def calculate_period_bound_mse( real_dataframe: pd.DataFrame, synthetic_timeseries: np.ndarray ) -> Tuple[float, float]: @@ -215,54 +280,87 @@ def compute_cfs( x_real: np.ndarray, x_synth: np.ndarray, c: np.ndarray, - n_folds: int = 5, + iterations: int = 2000, + batch_size: int = 128, + test_ratio: float = 0.2, ) -> float: """ Compute Context Faithfulness Score (CFS). - Trains a classifier to distinguish real (x, c) pairs from synthetic (x, c) pairs - using cross-validation, then returns 2 * |AUROC - 0.5|. + Trains a GRU to distinguish real (x, c) pairs from synthetic (x, c) pairs, + then returns 2 * |AUROC - 0.5|. 
Args: - x_real: (N, T, D_x) real time series - x_synth: (N, T, D_x) synthetic time series - c: (N, T, D_c) shared context (same for real and synthetic) - n_folds: number of cross-validation folds + x_real: (N, T, D_x) real time series + x_synth: (N, T, D_x) synthetic time series + c: (N, T, D_c) shared context (same for real and synthetic) + iterations: number of training steps + batch_size: training batch size + test_ratio: fraction of data held out for AUROC evaluation Returns: float: CFS in [0, 1]. 0 = indistinguishable (perfect), 1 = fully separable (failed). """ - N = x_real.shape[0] + N, T, _ = x_real.shape - real_pairs = np.concatenate([x_real, c], axis=-1) # (N, T, D_x+D_c) - synth_pairs = np.concatenate([x_synth, c], axis=-1) # (N, T, D_x+D_c) + # Concatenate signal and context along feature dim: (N, T, D_x+D_c) + real_seq = np.concatenate([x_real, c], axis=-1).astype(np.float32) + synth_seq = np.concatenate([x_synth, c], axis=-1).astype(np.float32) - # Mean pool over time → fixed-size vectors - X_real_enc = real_pairs.mean(axis=1) # (N, D_x+D_c) - X_synth_enc = synth_pairs.mean(axis=1) # (N, D_x+D_c) + # Drop samples with NaN + valid = ~(np.isnan(real_seq).any(axis=(1, 2)) | np.isnan(synth_seq).any(axis=(1, 2))) + real_seq, synth_seq = real_seq[valid], synth_seq[valid] + N = len(real_seq) - X_all = np.concatenate([X_real_enc, X_synth_enc], axis=0) # (2N, D) - y_all = np.concatenate([np.ones(N), np.zeros(N)]) # (2N,) - - # Drop rows with NaN - valid = ~np.isnan(X_all).any(axis=1) - X_all = X_all[valid] - y_all = y_all[valid] - - if len(X_all) < 2 * n_folds or len(np.unique(y_all)) < 2: + if N < 4: warnings.warn("compute_cfs: insufficient valid samples; returning nan.") return float("nan") - skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42) - auroc_scores = [] - for train_idx, val_idx in skf.split(X_all, y_all): - clf = LogisticRegression(max_iter=1000, random_state=42) - clf.fit(X_all[train_idx], y_all[train_idx]) - proba = 
clf.predict_proba(X_all[val_idx])[:, 1] - auroc_scores.append(roc_auc_score(y_all[val_idx], proba)) + # Train/test split + idx = np.random.RandomState(42).permutation(N) + test_n = max(1, int(N * test_ratio)) + train_idx, test_idx = idx[test_n:], idx[:test_n] + + train_real, test_real = real_seq[train_idx], real_seq[test_idx] + train_synth, test_synth = synth_seq[train_idx], synth_seq[test_idx] + + input_dim = real_seq.shape[-1] + hidden_dim = int(input_dim * 2) + model = _CFSDiscriminator(input_dim, hidden_dim).to(_cfs_device) + optimizer = optim.Adam(model.parameters()) + criterion = nn.BCEWithLogitsLoss() + + n_train = len(train_real) + for _ in range(iterations): + idx_r = np.random.randint(0, n_train, batch_size) + idx_s = np.random.randint(0, n_train, batch_size) + + xr = torch.tensor(train_real[idx_r]).to(_cfs_device) + xs = torch.tensor(train_synth[idx_s]).to(_cfs_device) + + optimizer.zero_grad() + logits_r = model(xr) + logits_s = model(xs) + loss = criterion(logits_r, torch.ones_like(logits_r)) + \ + criterion(logits_s, torch.zeros_like(logits_s)) + loss.backward() + optimizer.step() + + with torch.no_grad(): + xr_test = torch.tensor(test_real).to(_cfs_device) + xs_test = torch.tensor(test_synth).to(_cfs_device) + prob_r = torch.sigmoid(model(xr_test)).cpu().numpy().squeeze() + prob_s = torch.sigmoid(model(xs_test)).cpu().numpy().squeeze() + + probs = np.concatenate([prob_r, prob_s]) + labels = np.concatenate([np.ones(len(prob_r)), np.zeros(len(prob_s))]) + + if len(np.unique(labels)) < 2: + warnings.warn("compute_cfs: test set has only one class; returning nan.") + return float("nan") - mean_auroc = float(np.mean(auroc_scores)) - return float(2.0 * abs(mean_auroc - 0.5)) + auroc = float(roc_auc_score(labels, probs)) + return float(2.0 * abs(auroc - 0.5)) def _build_lag_matrix(x: np.ndarray, max_lag: int) -> np.ndarray: diff --git a/cents/eval/eval_utils.py b/cents/eval/eval_utils.py index e082804..5afd759 100644 --- a/cents/eval/eval_utils.py +++ 
b/cents/eval/eval_utils.py @@ -884,123 +884,6 @@ def create_visualizations( visualizations[f"TSNE_Dim_{i}"] = plot -# def evaluate_pv_shift(self, dataset: Any, model: Any): -# avg_shift = dataset.compute_average_pv_shift() -# if avg_shift is None or np.allclose(avg_shift, 0.0): -# return -# test_contexts = dataset.sample_shift_test_contexts() -# n_sampled = len(test_contexts) -# n_pv1_missing = sum(1 for c in test_contexts if c["missing_pv"] == 1) -# n_pv0_missing = sum(1 for c in test_contexts if c["missing_pv"] == 0) - -# print(f"[Shift Contexts] Sampled: {n_sampled}.") -# print(f"[Shift Contexts] PV=1 is missing in {n_pv1_missing} of these contexts.") -# print(f"[Shift Contexts] PV=0 is missing in {n_pv0_missing} of these contexts.") -# if len(test_contexts) == 0: -# return -# present_ctx_list = [] -# missing_ctx_list = [] -# present_pv_values = [] -# for cinfo in test_contexts: -# base_ctx = cinfo["base_context"] -# present_pv = cinfo["present_pv"] -# missing_pv = cinfo["missing_pv"] -# ctx_p = dict(base_ctx) -# ctx_m = dict(base_ctx) -# ctx_p["has_solar"] = present_pv -# ctx_m["has_solar"] = missing_pv -# present_ctx_list.append(ctx_p) -# missing_ctx_list.append(ctx_m) -# present_pv_values.append(present_pv) -# present_ctx_tensors = {} -# missing_ctx_tensors = {} -# all_keys = present_ctx_list[0].keys() -# for k in all_keys: -# present_ctx_tensors[k] = torch.tensor( -# [pc[k] for pc in present_ctx_list], dtype=torch.long, device=self.device -# ) -# missing_ctx_tensors[k] = torch.tensor( -# [mc[k] for mc in missing_ctx_list], dtype=torch.long, device=self.device -# ) -# with torch.no_grad(): -# syn_ts_present = model.generate(present_ctx_tensors) -# syn_ts_missing = model.generate(missing_ctx_tensors) -# syn_ts_present = syn_ts_present.cpu().numpy() -# syn_ts_missing = syn_ts_missing.cpu().numpy() -# if syn_ts_present.ndim == 3 and syn_ts_present.shape[-1] == 1: -# syn_ts_present = syn_ts_present[:, :, 0] -# syn_ts_missing = syn_ts_missing[:, :, 0] -# shifts = 
[] -# for i, pv_val in enumerate(present_pv_values): -# shift_i = syn_ts_missing[i] - syn_ts_present[i] -# if pv_val == 1: -# shift_i = -shift_i -# shifts.append(shift_i) -# shifts = np.array(shifts) -# avg_shift = np.asarray(avg_shift).reshape(-1) -# l2_values = [] -# for i in range(shifts.shape[0]): -# diff = shifts[i] - avg_shift -# l2 = np.sqrt((diff**2).sum()) -# l2_values.append(l2) -# mean_l2 = np.mean(l2_values) -# wandb.log({"Shift_L2": mean_l2}) - -# def find_context_matched_shift(dataset, cinfo): -# base_ctx = cinfo["base_context"] -# city_val = base_ctx.get("city", None) -# btype_val = base_ctx.get("building_type", None) -# df = dataset.data.copy() -# mask = pd.Series([True] * len(df)) -# if city_val is not None and "city" in df.columns: -# mask = mask & (df["city"] == city_val) -# if btype_val is not None and "building_type" in df.columns: -# mask = mask & (df["building_type"] == btype_val) -# df_matched = df[mask] -# if df_matched.empty: -# return None -# df_pv0 = df_matched[df_matched["has_solar"] == 0] -# df_pv1 = df_matched[df_matched["has_solar"] == 1] -# if df_pv0.empty or df_pv1.empty: -# return None -# ts_pv0 = np.stack(df_pv0["timeseries"].values, axis=0) -# ts_pv1 = np.stack(df_pv1["timeseries"].values, axis=0) -# mean_pv0 = ts_pv0.mean(axis=0) -# mean_pv1 = ts_pv1.mean(axis=0) -# mean_pv0_dim0 = mean_pv0[:, 0] -# mean_pv1_dim0 = mean_pv1[:, 0] -# real_shift = mean_pv1_dim0 - mean_pv0_dim0 -# return real_shift - -# matched_shifts = [] -# for cinfo in test_contexts: -# matched = find_context_matched_shift(dataset, cinfo) -# matched_shifts.append(matched) -# n_plots = min(6, shifts.shape[0]) -# for j, idx in enumerate( -# np.random.choice(shifts.shape[0], size=n_plots, replace=False) -# ): -# fig, ax = plt.subplots(figsize=(8, 4)) -# ax.plot(avg_shift, label="Real shift", color="red") -# ax.plot(shifts[idx], label="Synthetic shift", color="blue", linestyle="--") -# matched_s = matched_shifts[idx] -# if matched_s is not None: -# ax.plot( -# 
matched_s, -# label="Context-matched shift", -# color="green", -# linestyle=":", -# ) -# font_size = 12 -# ax.tick_params(axis="both", which="major", labelsize=font_size) -# ax.set_xlabel("Timestep", fontsize=font_size) -# ax.set_ylabel("kWh", fontsize=font_size) -# leg = ax.legend() -# leg.prop.set_size(font_size) -# fig.tight_layout() -# wandb.log({f"ShiftPlot_{j}": wandb.Image(fig)}) -# plt.close(fig) - def flatten_log_dict(d: Dict[str, Any], prefix: str = "") -> Dict[str, float]: """ diff --git a/cents/eval/predictive_score.py b/cents/eval/predictive_score.py index 68b9ae0..657cb5c 100644 --- a/cents/eval/predictive_score.py +++ b/cents/eval/predictive_score.py @@ -42,65 +42,86 @@ def forward(self, x, t): return y_hat -def predictive_score_metrics(ori_data, generated_data): - no, seq_len, dim = ori_data.shape - - ori_time, ori_max_seq_len = extract_time(ori_data) - generated_time, generated_max_seq_len = extract_time(generated_data) - max([ori_max_seq_len, generated_max_seq_len]) - +def _train_predictor(data, time, dim, iterations, batch_size): + """Train a GRU predictor on data and return the model.""" hidden_dim = max(int(dim / 2), 1) - iterations = 5000 - batch_size = 128 - model = Predictor(input_dim=dim, hidden_dim=hidden_dim).to(device) criterion = nn.L1Loss() optimizer = optim.Adam(model.parameters()) - for itt in tqdm( - range(iterations), - desc="[Cents] Training Predictive Score Model", - total=iterations, - ): - idx = np.random.permutation(len(generated_data)) + for _ in tqdm(range(iterations), desc="[Cents] Training Predictive Score Model", total=iterations): + idx = np.random.permutation(len(data)) train_idx = idx[:batch_size] - X_mb = [ - generated_data[i][:-1, :] for i in train_idx - ] # Use all dimensions for input - T_mb = [max(generated_time[i] - 1, 1) for i in train_idx] - Y_mb = [ - generated_data[i][1:, :].reshape(-1, dim) for i in train_idx - ] # Predict all dimensions - - X_mb = torch.tensor(np.array(X_mb), dtype=torch.float32).to(device) 
- T_mb = torch.tensor(np.array(T_mb), dtype=torch.int64).to(device) - Y_mb = torch.tensor(np.array(Y_mb), dtype=torch.float32).to(device) + X_mb = torch.tensor(np.array([data[i][:-1, :] for i in train_idx]), dtype=torch.float32).to(device) + T_mb = torch.tensor(np.array([max(time[i] - 1, 1) for i in train_idx]), dtype=torch.int64).to(device) + Y_mb = torch.tensor(np.array([data[i][1:, :].reshape(-1, dim) for i in train_idx]), dtype=torch.float32).to(device) optimizer.zero_grad() - y_pred = model(X_mb, T_mb) - loss = criterion(y_pred, Y_mb) + loss = criterion(model(X_mb, T_mb), Y_mb) loss.backward() optimizer.step() - X_mb = [ori_data[i][:-1, :] for i in range(no)] - T_mb = [max(ori_time[i] - 1, 1) for i in range(no)] - Y_mb = [ori_data[i][1:, :].reshape(-1, dim) for i in range(no)] + return model - X_mb = torch.tensor(np.array(X_mb), dtype=torch.float32).to(device) - T_mb = torch.tensor(np.array(T_mb), dtype=torch.int64).to(device) - Y_mb = torch.tensor(np.array(Y_mb), dtype=torch.float32).to(device) + +def _eval_mae(model, data, time, dim): + """Evaluate a predictor's MAE on data.""" + no = len(data) + X_mb = torch.tensor(np.array([data[i][:-1, :] for i in range(no)]), dtype=torch.float32).to(device) + T_mb = torch.tensor(np.array([max(time[i] - 1, 1) for i in range(no)]), dtype=torch.int64).to(device) + Y_mb = torch.tensor(np.array([data[i][1:, :].reshape(-1, dim) for i in range(no)]), dtype=torch.float32).to(device) with torch.no_grad(): y_pred = model(X_mb, T_mb) - MAE_temp = 0 - for i in range(no): - MAE_temp += mean_absolute_error(Y_mb[i].cpu().numpy(), y_pred[i].cpu().numpy()) + mae = sum( + mean_absolute_error(Y_mb[i].cpu().numpy(), y_pred[i].cpu().numpy()) + for i in range(no) + ) / no + return mae + + +def predictive_score_metrics(ori_data, generated_data, trtr: bool = False): + """ + Compute predictive score. + + Trains a GRU predictor on synthetic data (TSTR) and evaluates one-step-ahead + MAE on real data. 
When trtr=True, also trains on real data and returns the MAE + delta (TSTR_MAE - TRTR_MAE) alongside both individual MAEs. + + Args: + ori_data: Real time series (N, T, D). + generated_data: Synthetic time series (N, T, D). + trtr: If True, additionally train on real data for TRTR comparison. + + Returns: + float if trtr=False: TSTR MAE. + dict if trtr=True: {"tstr_mae": float, "trtr_mae": float, "delta": float}. + """ + no, seq_len, dim = ori_data.shape + ori_time, _ = extract_time(ori_data) + generated_time, _ = extract_time(generated_data) + + iterations = 5000 + batch_size = 128 + + # Train on synthetic, test on real (TSTR) + synth_model = _train_predictor(generated_data, generated_time, dim, iterations, batch_size) + tstr_mae = _eval_mae(synth_model, ori_data, ori_time, dim) + + if not trtr: + return tstr_mae - predictive_score = MAE_temp / no + # Train on real, test on real (TRTR) + real_model = _train_predictor(ori_data, ori_time, dim, iterations, batch_size) + trtr_mae = _eval_mae(real_model, ori_data, ori_time, dim) - return predictive_score + return { + "tstr_mae": tstr_mae, + "trtr_mae": trtr_mae, + "delta": tstr_mae - trtr_mae, + } def extract_time(data): diff --git a/cents/models/context.py b/cents/models/context.py index 06a0e12..e843024 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -356,176 +356,6 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, return embedding, {**classification_logits, **regression_outputs} -@register_context_module("dynamic_cnn") -class DynamicContextModule_CNN(BaseContextModule): - """ - Context module for processing dynamic (time series) context variables. - Uses 1D convolutions to encode time series sequences into embeddings. - """ - - def __init__( - self, - context_vars: dict[str, int], - embedding_dim: int, - seq_len: int = None, - ): - """ - Initialize DynamicContextModule. 
- - Args: - context_vars: Mapping of variable names to category counts (for categorical time series) - or None (for numeric time series). Format: {name: [type, num_categories]} - embedding_dim: Size of embedding vectors. - seq_len: Sequence length of time series context variables. - """ - super().__init__() - self.embedding_dim = embedding_dim - - # Separate categorical and numeric time series - self.categorical_ts_vars = { - k: v[1] for k, v in context_vars.items() - if v[0] == "time_series" and v[1] is not None - } - self.numeric_ts_vars = [ - k for k, v in context_vars.items() - if v[0] == "time_series" and v[1] is None - ] - - # For categorical time series, use embedding + CNN - self.ts_embeddings = nn.ModuleDict({ - name: nn.Embedding(num_categories, embedding_dim) - for name, num_categories in self.categorical_ts_vars.items() - }) - - # CNN encoders for each time series variable - # For categorical: input is (batch, seq_len) -> embedding -> (batch, seq_len, emb_dim) -> CNN - # For numeric: input is (batch, seq_len) -> CNN - self.ts_encoders = nn.ModuleDict() - - for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars: - # 1D CNN to encode time series: (batch, channels, seq_len) -> (batch, embedding_dim) - encoder = nn.Sequential( - nn.Conv1d(embedding_dim if name in self.categorical_ts_vars else 1, 64, kernel_size=3, padding=1), - nn.ReLU(), - nn.Conv1d(64, 128, kernel_size=3, padding=1), - nn.ReLU(), - nn.AdaptiveAvgPool1d(1), # Global average pooling - nn.Flatten(), - nn.Linear(128, embedding_dim), - ) - self.ts_encoders[name] = encoder - - # Mixing MLP to combine all time series embeddings - total_dim = embedding_dim * (len(self.categorical_ts_vars) + len(self.numeric_ts_vars)) - if total_dim > 0: - self.mixing_mlp = nn.Sequential( - nn.Linear(total_dim, 128), - nn.ReLU(), - nn.Linear(128, embedding_dim), - ) - else: - self.mixing_mlp = nn.Identity() - - # Initialize weights with Kaiming initialization - self._initialize_weights() - - def 
_initialize_weights(self): - """ - Initialize weights using Kaiming (He) initialization for better training with ReLU activations. - This is particularly important for the CNN layers and Linear layers. - """ - for module in self.modules(): - if isinstance(module, nn.Conv1d): - # Kaiming initialization for Conv1d layers (already default for ReLU, but make explicit) - nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') - if module.bias is not None: - nn.init.constant_(module.bias, 0) - elif isinstance(module, nn.Linear): - # Kaiming initialization for Linear layers (better than default Xavier for ReLU) - nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') - if module.bias is not None: - nn.init.constant_(module.bias, 0) - # Note: Embedding layers keep their default initialization (normal with std=1.0) - # which is appropriate for embeddings - - def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: - """ - Process dynamic (time series) context variables. - - Args: - context_vars: Dict mapping variable names to tensors. 
- For categorical TS: (batch, seq_len) with integer values - For numeric TS: (batch, seq_len) with float values - - Returns: - embedding: Combined embedding of shape (batch_size, embedding_dim) - outputs: Empty dict for compatibility - """ - embeddings = [] - - # # Process categorical time series - # for name in self.categorical_ts_vars.keys(): - # if name in context_vars: - # # Input: (batch, seq_len) with integer indices - # ts_data = context_vars[name] # (batch, seq_len) - # # Check for NaN/Inf in input - # if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): - # raise ValueError(f"NaN/Inf detected in categorical time series input '{name}'") - # # Embed: (batch, seq_len) -> (batch, seq_len, embedding_dim) - # embedded = self.ts_embeddings[name](ts_data) - # # Transpose for CNN: (batch, embedding_dim, seq_len) - # embedded = embedded.transpose(1, 2) - # # Check for NaN after embedding - # if torch.isnan(embedded).any() or torch.isinf(embedded).any(): - # raise ValueError(f"NaN/Inf detected after embedding for '{name}'") - # # Encode: (batch, embedding_dim, seq_len) -> (batch, embedding_dim) - # encoded = self.ts_encoders[name](embedded) - # # Check for NaN after encoding - # if torch.isnan(encoded).any() or torch.isinf(encoded).any(): - # raise ValueError(f"NaN/Inf detected after encoding for '{name}'") - # embeddings.append(encoded) - - # Process numeric time series - for name in self.numeric_ts_vars: - if name in context_vars: - # Input: (batch, seq_len) with float values - ts_data = context_vars[name] # (batch, seq_len) - # Ensure numeric time series are float type (not long/int) - if not ts_data.is_floating_point(): - ts_data = ts_data.float() - # Check for NaN/Inf in input - if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): - raise ValueError(f"NaN/Inf detected in numeric time series input '{name}'") - # Replace NaN/Inf with zeros to prevent propagation - ts_data = torch.where(torch.isfinite(ts_data), ts_data, torch.zeros_like(ts_data)) - 
# Add channel dimension: (batch, 1, seq_len) - ts_data = ts_data.unsqueeze(1) - # Encode: (batch, 1, seq_len) -> (batch, embedding_dim) - encoded = self.ts_encoders[name](ts_data) - # Check for NaN after encoding - if torch.isnan(encoded).any() or torch.isinf(encoded).any(): - raise ValueError(f"NaN/Inf detected after encoding numeric TS '{name}'") - embeddings.append(encoded) - - if not embeddings: - # No dynamic context variables, return zero embedding - batch_size = next(iter(context_vars.values())).size(0) if context_vars else 1 - embedding = torch.zeros(batch_size, self.embedding_dim, device=next(iter(context_vars.values())).device if context_vars else None) - return embedding, {} - - # Combine all time series embeddings - combined = torch.cat(embeddings, dim=1) # (batch, total_dim) - # Check for NaN before mixing - if torch.isnan(combined).any() or torch.isinf(combined).any(): - raise ValueError(f"NaN/Inf detected in combined embeddings before mixing MLP") - embedding = self.mixing_mlp(combined) # (batch, embedding_dim) - # Check for NaN after mixing - if torch.isnan(embedding).any() or torch.isinf(embedding).any(): - raise ValueError(f"NaN/Inf detected in final embedding after mixing MLP") - - return embedding, {} - - @register_context_module("dynamic_transformer") class DynamicContextModule_Transformer(BaseContextModule): """ @@ -596,7 +426,6 @@ def __init__( for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars }) else: - # If seq_len not provided, use learnable positional encoding that can adapt self.pos_encodings = None # Transformer encoder for each time series variable @@ -614,19 +443,15 @@ def __init__( for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars }) - # Output projection: sum contributions from all variables, then project - # to embedding_dim so the cross-attention key/value dim is consistent. 
+ n_vars = len(self.categorical_ts_vars) + len(self.numeric_ts_vars) - # Per-variable weight (scalar) for the additive mixture across variables self.var_mix = nn.Linear(n_vars * embedding_dim, embedding_dim) if n_vars > 1 else None - # Per-variable layer norm applied after each transformer encoder output all_var_names = list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars self.post_encoder_norms = nn.ModuleDict({ name: nn.LayerNorm(embedding_dim) for name in all_var_names }) - # Final layer norm applied after var_mix (or single-variable output) self.post_mix_norm = nn.LayerNorm(embedding_dim) # Initialize weights @@ -638,16 +463,12 @@ def _initialize_weights(self): """ for module in self.modules(): if isinstance(module, nn.Linear): - # Xavier initialization for transformer linear layers nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.Parameter): - # Initialize positional encodings if module.dim() == 3: # (1, seq_len, embedding_dim) nn.init.normal_(module, std=0.02) - # Note: Embedding layers keep their default initialization - # Transformer encoder layers use their own initialization def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ @@ -724,11 +545,6 @@ def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, return out, {} - def on_after_backward(self): - unused = [n for n,p in self.named_parameters() if p.requires_grad and p.grad is None] - if unused: - print("UNUSED:", unused[:50]) - @register_context_module("dynamic_joint_transformer") class DynamicContextModule_JointTransformer(BaseContextModule): @@ -773,8 +589,6 @@ def __init__( n_numeric = len(self.numeric_ts_vars) - # Project all numeric channels jointly: (B, T, n_vars) → (B, T, emb_dim) - # This single linear sees every variable at every timestep simultaneously. 
if n_numeric > 0: self.numeric_input_proj = nn.Linear(n_numeric, embedding_dim) else: diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index fcb3495..0717e55 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -15,23 +15,6 @@ from contextlib import contextmanager -def _nan_check(t: Optional[torch.Tensor], name: str, extra: str = "") -> None: - """Print location and stats when tensor contains NaN or Inf (for debugging).""" - if t is None or not isinstance(t, torch.Tensor): - return - if not (torch.isnan(t).any() or torch.isinf(t).any()): - return - nan_c = torch.isnan(t).sum().item() - inf_c = torch.isinf(t).sum().item() - finite = t[~(torch.isnan(t) | torch.isinf(t))] - min_s = finite.min().item() if finite.numel() > 0 else float("nan") - max_s = finite.max().item() if finite.numel() > 0 else float("nan") - mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") - print( - f"[NaN/Inf] {name}: shape={tuple(t.shape)}, nan_count={nan_c}, inf_count={inf_c}, " - f"finite_min={min_s:.6g}, finite_max={max_s:.6g}, finite_mean={mean_s:.6g} {extra}".strip() - ) - from cents.models.base import GenerativeModel from cents.models.model_utils import ( @@ -103,7 +86,6 @@ def blueish_noise_like( n_blue = n_blue.expand(B, L, C).clone() out = n_blue.to(dtype=x.dtype) - _nan_check(out, "blueish_noise_like output") return out @@ -277,9 +259,6 @@ def __init__(self, cfg: DictConfig): gamma=min_snr_gamma, ) self.register_buffer("loss_weight", lw) - _nan_check(self.loss_weight, "init loss_weight") - _nan_check(self.betas, "init betas") - _nan_check(self.sqrt_alphas_cumprod, "init sqrt_alphas_cumprod") # choose reconstruction loss if self.loss_type == "l1": @@ -296,7 +275,8 @@ def __init__(self, cfg: DictConfig): self.categorical_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] def _get_context_embedding( - self, static_context_vars: dict, dynamic_context_vars: dict = None + 
self, static_context_vars: dict, dynamic_context_vars: dict = None, + batch_size: int = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], dict]: """ Get context embeddings from static and/or dynamic context modules. @@ -318,21 +298,7 @@ def _get_context_embedding( k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v for k, v in static_context_vars.items() } - for k, v in static_vars.items(): - if isinstance(v, torch.Tensor) and (torch.isnan(v).any() or torch.isinf(v).any()): - nan_c = torch.isnan(v).sum().item() - inf_c = torch.isinf(v).sum().item() - finite = v[~(torch.isnan(v) | torch.isinf(v))] - min_s = finite.min().item() if finite.numel() > 0 else float("nan") - max_s = finite.max().item() if finite.numel() > 0 else float("nan") - mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") - print( - f"[NaN/Inf] static_var '{k}': shape={tuple(v.shape)}, dtype={v.dtype}, " - f"nan_count={nan_c}, inf_count={inf_c}, finite_min={min_s:.6g}, " - f"finite_max={max_s:.6g}, finite_mean={mean_s:.6g}" - ) static_emb, static_logits = self.static_context_module(static_vars) - _nan_check(static_emb, "_get_context_embedding static_emb") all_logits.update(static_logits) # --- Dynamic context (time series) → (B, T, emb_dim) for cross-attention --- @@ -343,7 +309,6 @@ def _get_context_embedding( for k, v in dynamic_context_vars.items() } dyn_out, dyn_logits = self.dynamic_context_module(dyn_vars) - _nan_check(dyn_out, "_get_context_embedding dyn_out") all_logits.update(dyn_logits) if getattr(self.dynamic_context_module, "returns_sequence", False): @@ -358,21 +323,19 @@ def _get_context_embedding( static_emb = dyn_out if static_emb is None: - raise ValueError("No static context embedding could be produced") + if batch_size is None: + raise ValueError("No context provided and batch_size not given — cannot create unconditional embedding") + device = next(self.parameters()).device + static_emb = torch.zeros(batch_size, 
self.embedding_dim, device=device) if static_emb.is_floating_point(): static_emb = static_emb.float() - _nan_check(static_emb, "_get_context_embedding static_emb (before dropout)") if self.training and self.context_embed_dropout_p > 0: - # Sample-wise mask: zero the entire embedding for ~p fraction of samples. - # Each sample independently gets its context dropped (not individual features), - # which teaches the model to work unconditionally — enabling CFG at inference. mask = torch.bernoulli( torch.full((static_emb.shape[0], 1), 1.0 - self.context_embed_dropout_p, device=static_emb.device, dtype=static_emb.dtype) ) static_emb = static_emb * mask - _nan_check(static_emb, "_get_context_embedding static_emb (final)") return static_emb, dyn_ctx_seq, all_logits def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: @@ -380,7 +343,6 @@ def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: Map backbone output (trend+season) to x0 prediction. Uses single fc or dual fc_a/fc_b when recon_cond_len is set. backbone: (B, L, time_series_dims). 
""" - _nan_check(backbone, "_decode_to_x0 backbone") if self.fc is not None: out = self.fc(backbone) else: @@ -389,7 +351,6 @@ def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: self.fc_a(backbone[:, :cond_len]), self.fc_b(backbone[:, cond_len:]), ], dim=1) - _nan_check(out, "_decode_to_x0 output") return out def predict_noise_from_start( @@ -409,7 +370,6 @@ def predict_noise_from_start( out = ( self.sqrt_recip_alphas_cumprod[t].view(-1, 1, 1) * x_t - x0 ) / self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) - _nan_check(out, "predict_noise_from_start output") return out def predict_start_from_noise( @@ -430,7 +390,6 @@ def predict_start_from_noise( self.sqrt_recip_alphas_cumprod[t].view(-1, 1, 1) * x_t - self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) * noise ) - _nan_check(out, "predict_start_from_noise output") return out def predict_start_from_v( @@ -444,7 +403,6 @@ def predict_start_from_v( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x_t - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * v ) - _nan_check(out, "predict_start_from_v output") return out def predict_noise_from_v( @@ -458,7 +416,6 @@ def predict_noise_from_v( self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_t + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * v ) - _nan_check(out, "predict_noise_from_v output") return out @@ -533,9 +490,6 @@ def q_posterior( ) pv = self.posterior_variance[t].view(-1, 1, 1) plv = self.posterior_log_variance_clipped[t].view(-1, 1, 1) - _nan_check(pm, "q_posterior pm") - _nan_check(pv, "q_posterior pv") - _nan_check(plv, "q_posterior plv") return pm, pv, plv def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_vars: dict = None) -> Tuple[torch.Tensor, dict]: @@ -550,50 +504,27 @@ def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_va rec_loss: Reconstruction loss tensor. cond_logits: Classification logits dict from context module. 
""" - _nan_check(x, "forward input x") - # Log when x is in reasonable range but we still see NaN later (helps distinguish bad input vs numerical instability) - # if isinstance(x, torch.Tensor): - # x_abs_max = x.abs().max().item() - # if x_abs_max > 50.0: - # print( - # f"[forward] input x has large values: min={x.min().item():.6g}, max={x.max().item():.6g}, abs_max={x_abs_max:.6g}" - # ) b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - embedding, dyn_ctx_seq, cond_classification_logits = self._get_context_embedding(static_context_vars, dynamic_context_vars) - _nan_check(embedding, "forward embedding") + embedding, dyn_ctx_seq, cond_classification_logits = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=b) noise = blueish_noise_like( x, power=self.blue_noise_power, correlated=self.correlated_noise ) - _nan_check(noise, "forward noise") x_noisy = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - _nan_check(x_noisy, "forward x_noisy") trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx_seq) - _nan_check(trend, "forward trend") - _nan_check(season, "forward season") x_start_pred = self._decode_to_x0((trend + season).contiguous()) - _nan_check(x_start_pred, "forward x_start_pred") # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) if self.training_objective == "x0": loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") elif self.training_objective == "eps": pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) - _nan_check(pred_noise, "forward pred_noise (eps)") loss_per_elem = self.recon_loss_fn(pred_noise, noise, reduction="none") else: # v - # Compute pred_v directly from x_start_pred and x_noisy, avoiding the - # two-step path through predict_noise_from_start which divides by - # sqrt_recipm1_alphas_cumprod — a 
value near 0 at low t (cosine schedule - # gives ~0.01 at t=0), amplifying prediction errors ~100x into pred_noise - # before they land in pred_v. The algebraic identity: - # v = sqrt(α_bar)*ε - sqrt(1-α_bar)*x0 - # ε = (x_noisy - sqrt(α_bar)*x0) / sqrt(1-α_bar) - # => pred_v = (sqrt(α_bar)*x_noisy - x0) / sqrt(1-α_bar).clamp(min=1e-3) sqrt_ab = self.sqrt_alphas_cumprod[t].view(-1, 1, 1) sqrt_1mab = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1).clamp(min=1e-3) pred_v = (sqrt_ab * x_noisy - x_start_pred) / sqrt_1mab @@ -601,20 +532,10 @@ def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_va self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * noise - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x ) - _nan_check(pred_v, "forward pred_v") - _nan_check(true_v, "forward true_v") loss_per_elem = self.recon_loss_fn(pred_v, true_v, reduction="none") - _nan_check(loss_per_elem, "forward loss_per_elem") rec_loss = ( self.loss_weight[t].view(-1, 1, 1) * loss_per_elem ).mean() - _nan_check(rec_loss, "forward rec_loss") - # When loss is NaN but input x was in reasonable range, point to numerical instability downstream - if (torch.isnan(rec_loss) | torch.isinf(rec_loss)).any(): - print( - f"[forward] rec_loss is NaN/Inf while input x had min={x.min().item():.6g}, max={x.max().item():.6g}, abs_max={x.abs().max().item():.6g}" - ) - fourier_loss = torch.tensor(0.0, device=self.device) if self.use_ff: # FFT is not generally supported in fp16 for non power-of-2 sizes on cuFFT. 
@@ -628,26 +549,11 @@ def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_va mag1 = torch.abs(fft1) mag2 = torch.abs(fft2) - _nan_check(mag1, "forward fourier mag1") - _nan_check(mag2, "forward fourier mag2") - fourier_loss = ( - self.recon_loss_fn(mag1, mag2, reduction="none") - ) - _nan_check(fourier_loss, "forward fourier_loss (per-elem)") - fourier_loss = ( - self.loss_weight[t].view(-1, 1, 1) * fourier_loss - ).mean() - _nan_check(fourier_loss, "forward fourier_loss (scalar)") + fourier_loss = self.recon_loss_fn(mag1, mag2, reduction="none") + fourier_loss = (self.loss_weight[t].view(-1, 1, 1) * fourier_loss).mean() - # fourier_loss = ( - # self.recon_loss_fn(fft1.real, fft2.real, reduction="none") - # + self.recon_loss_fn(fft1.imag, fft2.imag, reduction="none") - # ) - # fourier_loss = ( - # self.loss_weight[t].view(-1, 1, 1) * fourier_loss - # ).mean() return rec_loss, cond_classification_logits, fourier_loss.mean() @@ -663,49 +569,29 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: total_loss: Scalar training loss. 
""" ts_batch, static_context_batch, dynamic_context_batch = batch - # print("BEFORE PRINT I") - # print(ts_batch, static_context_batch, dynamic_context_batch) - # print("AFTER PRINT I") - _nan_check(ts_batch, "training_step ts_batch") rec_loss, cond_class_logits, fourier_loss = self(ts_batch, static_context_batch, dynamic_context_batch) - _nan_check(rec_loss, "training_step rec_loss") - _nan_check(fourier_loss, "training_step fourier_loss") cond_loss = 0.0 for var_name, outputs in cond_class_logits.items(): labels = static_context_batch[var_name] - if isinstance(outputs, torch.Tensor): - _nan_check(outputs, f"training_step cond_logits[{var_name}]") - if isinstance(labels, torch.Tensor): - _nan_check(labels, f"training_step cond_labels[{var_name}]") if var_name in self.continuous_context_vars: loss = F.mse_loss(outputs, labels.float()) elif var_name in self.categorical_context_vars: loss = self.auxiliary_loss(outputs, labels) - _nan_check(loss, f"training_step cond_loss[{var_name}]") cond_loss += loss.mean() - # Normalize by number of context variables so the weight is dataset-independent - # if len(cond_class_logits) > 0: - # cond_loss = cond_loss / len(cond_class_logits) - h, _, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch) - _nan_check(h, "training_step h (for tc)") + h, _, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch, batch_size=ts_batch.shape[0]) tc_term = ( self.cfg.model.tc_loss_weight * total_correlation(h) if self.cfg.model.tc_loss_weight > 0.0 else torch.tensor(0.0, device=self.device) ) - _nan_check(tc_term, "training_step tc_term") total_loss = ( rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term + fourier_loss * self.ff_weight ) - _nan_check(total_loss, f"training_step total_loss batch_idx={batch_idx}") # Skip this batch entirely if loss is bad — avoids corrupting weights before EMA can help - if not torch.isfinite(total_loss): - print(f"[training_step] Non-finite loss 
({total_loss.item()}) at batch {batch_idx}, skipping.") - return None self.log_dict( { @@ -786,39 +672,7 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: if hasattr(self, '_ema') and self._ema: self._ema.update() - # def on_load_checkpoint(self, checkpoint: dict) -> None: - # """ - # Restore EMA weights from checkpoint after loading. - # """ - # super().on_load_checkpoint(checkpoint) - - # # Check if EMA weights exist in checkpoint - # state_dict = checkpoint.get('state_dict', {}) - # ema_keys = [key for key in state_dict.keys() if key.startswith('_ema.')] - - # if ema_keys: - # if not hasattr(self, '_ema') or self._ema is None: - # self._ema = EMA( - # self.model, - # beta=self.cfg.model.ema_decay, - # update_every=self.cfg.model.ema_update_interval, - # ) - - # # Load EMA weights into the EMA helper - # ema_state_dict = {} - # for key, value in state_dict.items(): - # if key.startswith('_ema.ema_model.'): - # # Map '_ema.ema_model.*' -> 'ema_model.*' (remove the _ema prefix) - # ema_key = key.replace('_ema.ema_model.', 'ema_model.') - # ema_state_dict[ema_key] = value - - # if ema_state_dict: - # print(f"Loading {len(ema_state_dict)} EMA weights from checkpoint") - # self._ema.ema_model.load_state_dict(ema_state_dict, strict=False) - # else: - # raise ValueError("No EMA model weights found in checkpoint") - # else: - # raise ValueError("No EMA keys found in checkpoint") + def load_state_dict(self, state_dict, strict=True): # Strip legacy _ema.* keys — EMA is restored separately via on_load_checkpoint. # Old checkpoints have these because _ema was previously a registered submodule. 
@@ -853,7 +707,6 @@ def _predict_x0_from_xt_with_grad( """ trend, season = self.model(x_t, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx) x_start = self._decode_to_x0((trend + season).contiguous()) - _nan_check(x_start, "_predict_x0_from_xt_with_grad x_start") return x_start @torch.no_grad() @@ -871,8 +724,6 @@ def model_predictions( trend, season = self.model(x, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx) x_start = self._decode_to_x0((trend + season).contiguous()) pred_noise = self.predict_noise_from_start(x, t, x_start) - _nan_check(x_start, "model_predictions x_start") - _nan_check(pred_noise, "model_predictions pred_noise") return pred_noise, x_start @staticmethod @@ -900,7 +751,6 @@ def p_mean_variance( """ pred_noise, x_start = self.model_predictions(x, t, embedding, dyn_ctx=dyn_ctx) pm, pv, plv = self.q_posterior(x_start, x, t) - _nan_check(x_start, "p_mean_variance x_start") return pm, pv, plv, x_start @torch.no_grad() @@ -919,7 +769,6 @@ def p_sample( else 0 ) out = pm + (0.5 * plv).exp() * noise - _nan_check(out, "p_sample output") return out def _reconstruction_guided_step_alg1( @@ -941,10 +790,8 @@ def _reconstruction_guided_step_alg1( x_t = x_t.detach().requires_grad_(True) x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding, dyn_ctx=dyn_ctx) - _nan_check(x_start, "_reconstruction_guided_step_alg1 x_start") x_hat_a = x_start[:, :cond_len] L_1 = (x_a - x_hat_a).pow(2).mean() - _nan_check(L_1, "_reconstruction_guided_step_alg1 L_1") pm, pv, plv = self.q_posterior(x_start, x_t, bt) noise = ( @@ -954,15 +801,11 @@ def _reconstruction_guided_step_alg1( ) x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() - _nan_check(L_2, "_reconstruction_guided_step_alg1 L_2") loss = L_1 + gamma * L_2 - _nan_check(loss, "_reconstruction_guided_step_alg1 loss") loss.backward() with torch.no_grad(): x_tilde_0 = x_start.detach() + eta * x_t.grad - _nan_check(x_t.grad, 
"_reconstruction_guided_step_alg1 x_t.grad") - _nan_check(x_tilde_0, "_reconstruction_guided_step_alg1 x_tilde_0") pm_final, pv_final, plv_final = self.q_posterior(x_tilde_0, x_t.detach(), bt) noise_final = ( _randn_like_correlated(x_t, self.correlated_noise) @@ -971,7 +814,6 @@ def _reconstruction_guided_step_alg1( ) x_prev = pm_final + (0.5 * plv_final).exp() * noise_final x_prev = self._replace_conditional(x_a, x_prev, cond_len) - _nan_check(x_prev, "_reconstruction_guided_step_alg1 x_prev") return x_prev def _reconstruction_guided_step_alg2( @@ -997,7 +839,6 @@ def _reconstruction_guided_step_alg2( for _ in range(K): x_t = x_t.requires_grad_(True) x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach, dyn_ctx=dyn_ctx_detach) - _nan_check(x_start, "_reconstruction_guided_step_alg2 x_start (inner)") x_hat_a = x_start[:, :cond_len] L_1 = (x_a - x_hat_a).pow(2).mean() pm, pv, plv = self.q_posterior(x_start, x_t, bt) @@ -1009,16 +850,13 @@ def _reconstruction_guided_step_alg2( x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() loss = L_1 + gamma * L_2 - _nan_check(loss, "_reconstruction_guided_step_alg2 loss (inner)") loss.backward() with torch.no_grad(): - _nan_check(x_t.grad, "_reconstruction_guided_step_alg2 x_t.grad") x_t = x_t + eta * x_t.grad x_t = x_t.detach() with torch.no_grad(): x_start_final = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach, dyn_ctx=dyn_ctx_detach) - _nan_check(x_start_final, "_reconstruction_guided_step_alg2 x_start_final") pm_final, pv_final, plv_final = self.q_posterior(x_start_final, x_t, bt) noise_final = ( blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) @@ -1027,7 +865,6 @@ def _reconstruction_guided_step_alg2( ) x_prev = pm_final + (0.5 * plv_final).exp() * noise_final x_prev = self._replace_conditional(x_a, x_prev, cond_len) - _nan_check(x_prev, "_reconstruction_guided_step_alg2 x_prev") 
return x_prev def sample_reconstruction_guided( @@ -1070,7 +907,7 @@ def sample_reconstruction_guided( x = _randn_shape_correlated( shape, self.device, torch.float32, self.correlated_noise ) - embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=shape[0]) x_a = x_a.to(self.device) for t in reversed(range(self.num_timesteps)): @@ -1105,10 +942,9 @@ def sample(self, shape: Tuple[int, int, int], static_context_vars: dict, dynamic x = _randn_shape_correlated( shape, self.device, torch.float32, self.correlated_noise ) - embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=shape[0]) for t in reversed(range(self.num_timesteps)): x = self.p_sample(x, t, embedding, dyn_ctx=dyn_ctx_seq) - _nan_check(x, "sample() output") return x @torch.no_grad() @@ -1127,7 +963,7 @@ def fast_sample( x = _randn_shape_correlated( shape, self.device, torch.float32, self.correlated_noise ) - embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=shape[0]) use_cfg = cfg_scale > 1.0 if use_cfg: @@ -1150,7 +986,6 @@ def fast_sample( pred_noise, x_start = self.model_predictions(x, bt, embedding, dyn_ctx=dyn_ctx_seq) if time_next < 0: x = x_start - _nan_check(x, "fast_sample x (final step)") continue alpha = self.alphas_cumprod[time] alpha_next = self.alphas_cumprod[time_next] @@ -1161,8 +996,6 @@ def fast_sample( c = (1 - alpha_next - sigma**2).sqrt() noise = _randn_like_correlated(x, self.correlated_noise) x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise - _nan_check(x, "fast_sample x (mid)") - _nan_check(x, 
"fast_sample x (final)") return x @contextmanager @@ -1177,26 +1010,32 @@ def ema_scope(self): else: yield - def generate(self, static_context_vars: dict, dynamic_context_vars: dict = None) -> torch.Tensor: + def generate(self, static_context_vars: dict, dynamic_context_vars: dict = None, n: int = None) -> torch.Tensor: """ Public entry to generate conditioned samples in batches. Args: - static_context_vars: dict of context tensors for each sample. + static_context_vars: dict of context tensors for each sample. Pass {} for unconditional generation. dynamic_context_vars: dict of dynamic context tensors for each sample. + n: number of samples to generate; required when static_context_vars is empty. Returns: Complete generated tensor of shape (N, seq_len, dims). """ bs = self.cfg.model.sampling_batch_size - total = len(next(iter(static_context_vars.values()))) + if static_context_vars: + total = len(next(iter(static_context_vars.values()))) + elif n is not None: + total = n + else: + raise ValueError("Pass n= when generating unconditionally (empty static_context_vars)") generated_samples = [] with self.ema_scope(): for start_idx in tqdm( range(0, total, bs), unit="seq", - desc="[CENTS] Generating samples", + desc="[CENTS] Generating samples", leave=True, ): end_idx = min(start_idx + bs, total) @@ -1206,7 +1045,7 @@ def generate(self, static_context_vars: dict, dynamic_context_vars: dict = None) } batch_dynamic_context_vars = { var_name: var_tensor[start_idx:end_idx] - for var_name, var_tensor in dynamic_context_vars.items() + for var_name, var_tensor in (dynamic_context_vars or {}).items() } current_bs = end_idx - start_idx diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index 5006bfa..be4cf73 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -18,23 +18,6 @@ from torch import nn -def _nan_check(t: Optional[torch.Tensor], name: str, extra: str = "") -> None: - """Print location and stats when tensor contains NaN or Inf 
(for debugging).""" - if t is None or not isinstance(t, torch.Tensor): - return - if not (torch.isnan(t).any() or torch.isinf(t).any()): - return - nan_c = torch.isnan(t).sum().item() - inf_c = torch.isinf(t).sum().item() - finite = t[~(torch.isnan(t) | torch.isinf(t))] - min_s = finite.min().item() if finite.numel() > 0 else float("nan") - max_s = finite.max().item() if finite.numel() > 0 else float("nan") - mean_s = finite.float().mean().item() if finite.numel() > 0 else float("nan") - print( - f"[NaN/Inf] Transformer {name}: shape={tuple(t.shape)}, nan_count={nan_c}, inf_count={inf_c}, " - f"finite_min={min_s:.6g}, finite_max={max_s:.6g}, finite_mean={mean_s:.6g} {extra}".strip() - ) - def linear_beta_schedule(timesteps: int) -> torch.Tensor: """ @@ -70,9 +53,6 @@ def cosine_beta_schedule(timesteps: int, s: float = 0.004) -> torch.Tensor: betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) -import math -import torch - def cosine_beta_schedule_logsnr( timesteps: int, @@ -254,21 +234,10 @@ def __init__(self, in_dim, out_dim, resid_pdrop=0.0): self.drop = nn.Dropout(p=resid_pdrop) def forward(self, x): - # x: (B, T, C) - _nan_check(x, "Conv_MLP forward x (initial)") - # # Print when values are extreme (even if not NaN) to debug downstream NaN - # print( - # f"[Conv_MLP] x (initial): shape={tuple(x.shape)}, min={x.min().item():.6g}, max={x.max().item():.6g}, " - # f"abs_max={x.abs().max().item():.6g}, has_nan={torch.isnan(x).any().item()}, has_inf={torch.isinf(x).any().item()}" - # ) x = x.transpose(1, 2).contiguous() # (B, C, T) contiguous - _nan_check(x, "Conv_MLP forward x (transposed)") x = self.conv(x) - _nan_check(x, "Conv_MLP forward x (conv)") x = self.drop(x) - _nan_check(x, "Conv_MLP forward x (drop)") out = x.transpose(1, 2).contiguous() # back to (B, T, C), contiguous - _nan_check(out, "Conv_MLP forward out (transposed)") return out @@ -717,9 +686,7 @@ def __init__( act = nn.GELU() if activate == "GELU" else GELU2() 
self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act) - # self.decomp = MovingBlock(n_channel) self.seasonal = FourierLayer(d_model=n_embd) - # self.seasonal = SeasonBlock(n_channel, n_channel) self.mlp = nn.Sequential( nn.Linear(n_embd, mlp_hidden_times * n_embd), @@ -921,60 +888,40 @@ def forward(self, input, t, padding_masks=None, return_res=False, cond=None, dyn cond = cond.float() if dyn_ctx is not None and dyn_ctx.is_floating_point(): dyn_ctx = dyn_ctx.float() - _nan_check(input, "forward input") t_emb = self.time_emb(t) - _nan_check(t_emb, "forward t_emb") label_emb = None if (cond is not None) and (self.cond_proj is not None): label_emb = self.cond_proj(cond) # (B, n_embd) - _nan_check(label_emb, "forward label_emb") total_cond_emb = self.cond_mix_mlp(torch.concat([t_emb, label_emb], dim=1)) else: total_cond_emb = t_emb - _nan_check(total_cond_emb, "forward total_cond_emb") # Project dynamic context sequence to n_embd for cross-attention dyn_ctx_emb = None if dyn_ctx is not None and self.dyn_ctx_proj is not None: dyn_ctx_emb = self.dyn_ctx_proj(dyn_ctx) # (B, T, n_embd) - _nan_check(dyn_ctx_emb, "forward dyn_ctx_emb") emb = self.emb(input) - _nan_check(emb, "forward emb") inp_enc = self.pos_enc(emb) - _nan_check(inp_enc, "forward inp_enc") enc_cond = self.encoder(inp_enc, total_cond_emb, padding_masks=padding_masks) - _nan_check(enc_cond, "forward enc_cond") inp_dec = self.pos_dec(emb) - _nan_check(inp_dec, "forward inp_dec") output, mean, trend, season = self.decoder( inp_dec, total_cond_emb, enc_cond, dyn_ctx=dyn_ctx_emb, padding_masks=padding_masks ) - _nan_check(output, "forward decoder output") - _nan_check(mean, "forward decoder mean") - _nan_check(trend, "forward decoder trend") - _nan_check(season, "forward decoder season") res = self.inverse(output) - _nan_check(res, "forward res (inverse output)") res_m = torch.mean(res, dim=1, keepdim=True).contiguous() - _nan_check(res_m, "forward res_m") combine_m_out = 
self.combine_m(mean).contiguous() - _nan_check(combine_m_out, "forward combine_m_out") combine_s_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2).contiguous() - _nan_check(combine_s_out, "forward combine_s_out") season_error = (combine_s_out + res - res_m).contiguous() - _nan_check(season_error, "forward season_error") trend = (combine_m_out + res_m + trend).contiguous() - _nan_check(trend, "forward trend (final)") if return_res: out_res = res - res_m - _nan_check(out_res, "forward return res - res_m") return trend, combine_s_out, out_res return trend, season_error diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 245a569..6f3b5e1 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -12,7 +12,7 @@ from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel -from cents.models.context import MLPContextModule, SepMLPContextModule, DynamicContextModule_CNN, DynamicContextModule_Transformer # Import to trigger registration +from cents.models.context import MLPContextModule, SepMLPContextModule, DynamicContextModule_Transformer # Import to trigger registration from cents.models.context_registry import get_context_module_cls from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model @@ -56,10 +56,6 @@ def _initialize_output_layer(self, init_sigma: float = 1.0): if self.do_scale: # 1. Initialize z_min to -2.0 out_layer.bias[2 * D : 3 * D].fill_(-2.0) - - # 2. Initialize the RAW DELTA to ~4.0 - # Softplus(4.0) is approx 4.018. 
- # z_max = -2.0 + 4.018 = 2.018 (Perfect starting point) out_layer.bias[3 * D : 4 * D].fill_(4.0) @staticmethod @@ -155,10 +151,6 @@ def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_di # Process static context variables if self.static_cond_module is not None: - # static_vars = { - # k: v for k, v in context_vars_dict.items() - # if k not in getattr(self, '_dynamic_var_names', []) - # } if static_context_vars_dict: device = next(self.static_cond_module.parameters()).device static_context_vars_dict = { @@ -170,11 +162,6 @@ def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_di # Process dynamic context variables if self.dynamic_cond_module is not None: - # dynamic_var_names = getattr(self, '_dynamic_var_names', []) - # dynamic_vars = { - # k: v for k, v in context_vars_dict.items() - # if k in dynamic_var_names - # } if dynamic_context_vars_dict: device = next(self.dynamic_cond_module.parameters()).device dynamic_context_vars_dict = { @@ -379,31 +366,7 @@ def setup(self, stage: Optional[str] = None): # Log initial predictions if stage == "fit" or stage is None: - self._log_initial_predictions() - - def _log_initial_predictions(self): - """Log initial model predictions to diagnose initialization issues.""" - # self.eval() - # with torch.no_grad(): - # dataloader = self.train_dataloader() - # batch = next(iter(dataloader)) - # cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - - # device = next(self.parameters()).device - # cat_vars_dict = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in cat_vars_dict.items()} - # mu_t = mu_t.to(device) - # sigma_t = sigma_t.to(device) - - # # Predict (Returns Real Unscaled values via Forward) - # pred_mu, pred_sigma, pred_z_min, pred_z_max, _ = self(cat_vars_dict) - - # print(f"\n[Initial Predictions]") - # print(f" Target mu: mean={mu_t.mean().item():.4f}, std={mu_t.std().item():.4f}") - # print(f" Predicted mu: mean={pred_mu.mean().item():.4f}, 
std={pred_mu.std().item():.4f}") - # print(f" Initial loss_mu: {F.mse_loss(pred_mu, mu_t).item():.6f}") - # print() - - self.train() + self.train() def _raw_mu_to_real(self, pred_mu_raw: torch.Tensor) -> torch.Tensor: """Convert network mu output to real-world mu (handles both global and direct/asinh paths).""" @@ -680,8 +643,23 @@ def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: continuous_vars = set(self.continuous_vars) categorical_ts = getattr(self.dataset, "categorical_time_series", {}) group_edges = getattr(self, "_group_bin_edges", {}) + # Pre-compute global fallback stats (used when row has no context columns) + _global_mu = np.full(self.time_series_dims, self.global_mu_mean.item(), dtype=np.float32) + _global_sigma = np.full( + self.time_series_dims, + max(float(np.exp(self.global_log_sigma_mean.item())), self.min_sigma), + dtype=np.float32, + ) + with torch.no_grad(): for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): + # Unconditional path: no context columns present → use global stats directly + if not any(v in row for v in self.normalizer_static_vars + self.normalizer_dynamic_vars): + for d, col in enumerate(self.time_series_cols): + z = np.asarray(row[col], dtype=np.float32) + df_out.at[i, col] = (z * _global_sigma[d] + _global_mu[d]).tolist() + continue + static_context_vars_dict = {} dynamic_context_vars_dict = {} for v in self.context_vars: @@ -824,7 +802,6 @@ def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optio ) self._group_bin_edges[v] = np.asarray(edges, dtype=np.float64) df[v] = binned - print(group_vars, "group_vars") grouped = df.groupby(list(group_vars), dropna=False) for group_key, gdf in grouped: