diff --git a/.gitignore b/.gitignore index f1316a3..14c1d4e 100644 --- a/.gitignore +++ b/.gitignore @@ -106,8 +106,10 @@ ENV/ .*.swp # Repository Specific +runs/ cents/data/* cents/data/pecanstreet/* +cents/data/commercial/* cents/data/custom/ .DS_Store .ipynb_checkpoints diff --git a/cents/config/config.yaml b/cents/config/config.yaml index eb69f62..f779f20 100644 --- a/cents/config/config.yaml +++ b/cents/config/config.yaml @@ -1,31 +1,5 @@ -defaults: - - model: null - - dataset: pecanstreet - - evaluator: default - - trainer: null - - _self_ - device: auto -job_name: ${model.name}_${dataset.name}_${dataset.user_group} -run_dir: outputs/${job_name}/${now:%Y-%m-%d_%H-%M-%S} model_ckpt: null -hydra: - job_logging: - version: 1 - formatters: - simple: - format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - handlers: - console: - class: logging.StreamHandler - formatter: simple - level: INFO - root: - handlers: [console] - level: INFO - run: - dir: ${run_dir} - wandb: enabled: false diff --git a/cents/config/context/default.yaml b/cents/config/context/default.yaml new file mode 100644 index 0000000..e8fb2a7 --- /dev/null +++ b/cents/config/context/default.yaml @@ -0,0 +1,20 @@ +# Context configuration +# This file defines the context modules used across the codebase + +static_context: + type: mlp # Options: "mlp", "sep_mlp", "transformer" + # TransformerStaticContextModule hyperparameters (ignored by mlp/sep_mlp): + # n_heads: 4 + # n_layers: 2 + # dropout: 0.1 + # dim_feedforward: 256 + +# Normalizer: stats head configuration for the normalizer +normalizer: + stats_head_type: mlp # Stats head type (e.g., "mlp") + n_layers: 5 + # hidden_dim: 512 + +# Dynamic context: context module used by the normalizer for time series context variables +dynamic_context: + type: null # Context module type for dynamic context (e.g., "cnn") diff --git a/cents/config/dataset/airquality.yaml b/cents/config/dataset/airquality.yaml new file mode 100644 index 0000000..3ce51c2 --- /dev/null +++ b/cents/config/dataset/airquality.yaml @@ -0,0 +1,59 @@ +name: airquality +geography: null +normalize: True +scale: False +use_learned_normalizer: True +threshold: 8 +seq_len: 24 +shuffle: True +skip_heavy_processing: False +max_samples: null +path: "./data/airquality" +numeric_context_bins: 1 +reduce_cardinality: False +time_series_dims: 1 +normalizer_stats_mode: group +# Normalizer conditions only on these (e.g. per-station); diffusion still gets full context_vars +normalizer_group_vars: ["station"] + +# Targets (what becomes the merged "timeseries" dims) +# NOTE: use PMcoarse instead of PM10 +time_series_columns: ["PM2.5"] + +# Raw CSV columns to load +# Keep wd/WSPM because we need them to engineer wind_u/wind_v +# Keep PM10 because we need it to engineer PMcoarse +data_columns: + - "No" + - "year" + - "month" + - "day" + - "hour" + - "PM2.5" + - "PM10" + - "SO2" + - "NO2" + - "CO" + - "TEMP" + - "DEWP" + - "PRES" + - "RAIN" + - "WSPM" + - "wd" + - "station" + +context_vars: + # static categorical + year: ["categorical", 5] + month: ["categorical", 12] + weekday: ["categorical", 7] + station: ["categorical", 12] + + # dynamic time-series context + TEMP: ["time_series", null] + DEWP: ["time_series", null] + PRES: ["time_series", null] + RAIN: ["time_series", null] + wind_u: ["time_series", null] + wind_v: ["time_series", null] + wd_valid: ["time_series", null] \ No newline at end of file diff --git a/cents/config/dataset/commercial.yaml b/cents/config/dataset/commercial.yaml new file mode 100644 index 0000000..81fd583 --- /dev/null +++ b/cents/config/dataset/commercial.yaml @@ -0,0 +1,30 @@ +name: commercial +geography: null +user_group: all +normalize: True +scale: False +use_learned_normalizer: True +threshold: 8 +seq_len: 24 +time_series_dims: 1 +shuffle: True +skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) +max_samples: null # Limit dataset size (null = use all data) +path: "./data/commercial/csv" +time_series_columns: "energy_meter" +data_columns: ["dataid","energy_meter","timestamp"] +metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearbuilt", "sub_primaryspaceusage"] +numeric_context_bins: 5 +reduce_cardinality: False +normalizer_stats_mode: group +normalizer_group_vars: null + +context_vars: + year: ["categorical", 2] + month: ["categorical", 12] + weekday: ["categorical", 7] + site_id: ["categorical", 19] + primaryspaceusage: ["categorical", 16] + sqft: ["categorical", null] + yearbuilt: ["categorical", null] + sub_primaryspaceusage: ["categorical", 104] \ No newline at end of file diff --git a/cents/config/dataset/default.yaml b/cents/config/dataset/default.yaml index 13c28be..14f65d8 100644 --- a/cents/config/dataset/default.yaml +++ b/cents/config/dataset/default.yaml @@ -1,13 +1,15 @@ -name: default -normalize: True -scale: True -use_learned_normalizer: True -shuffle: True -threshold: 6 -time_series_dims: 1 -time_series_columns: [] -seq_len: 8 -user_group: null +# name: default +# normalize: True +# scale: True +# use_learned_normalizer: True +# shuffle: True +# threshold: 6 +# time_series_dims: 1 +# time_series_columns: [] +# seq_len: 8 +# user_group: null -numeric_context_bins: 5 -context_vars: {} +# numeric_context_bins: 5 +# context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous) +# continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned) +# stats_head_type: mlp diff --git a/cents/config/dataset/metraq.yaml b/cents/config/dataset/metraq.yaml new file mode 100644 index 0000000..439f762 --- /dev/null +++ b/cents/config/dataset/metraq.yaml @@ -0,0 +1,54 @@ +name: metraq +geography: null +normalize: True +scale: False +use_learned_normalizer: True +threshold: 8 +seq_len: 24 +shuffle: True +skip_heavy_processing: False +max_samples: null +path: "./data/metraq" +numeric_context_bins: 1 +reduce_cardinality: False +time_series_dims: 1 +normalizer_stats_mode: group +# Normalizer conditions only on these (e.g. per-station); diffusion still gets full context_vars +normalizer_group_vars: ["sensor_name"] +max_z_threshold: 15.0 + +# Targets (what becomes the merged "timeseries" dims) +# NOTE: use PMcoarse instead of PM10 +time_series_columns: ["PM2.5"] + +# Raw CSV columns to load +# Keep wd/WSPM because we need them to engineer wind_u/wind_v +# Keep PM10 because we need it to engineer PMcoarse +data_columns: + - "entry_date" + - "magnitude_name" + - "sensor_name" + - "value" + # - "utm_x" + # - "utm_y" + +context_vars: + # static categorical + year: ["categorical", 6] + month: ["categorical", 12] + weekday: ["categorical", 7] + sensor_name: ["categorical", 24] + # utm_x: ["continuous", null] + # utm_y: ["continuous", null] + + # dynamic time-series context + # WS and WD are decomposed into wind_u/wind_v in preprocessing to handle + # the circularity of wind direction (WD=355° ≈ WD=5°, but z-score would give opposite signs). + T: ["time_series", null] + # wind_u: ["time_series", null] + # wind_v: ["time_series", null] + # wd_valid: ["time_series", null] + # RH, AP, R dropped — per-sample correlation with PM2.5 < 0.025 across all stations + # Traffic: TI = vehicles/hour (Kriging interpolation); SP = avg speed km/h + TI: ["time_series", null] + SP: ["time_series", null] \ No newline at end of file diff --git a/cents/config/dataset/pecanstreet.yaml b/cents/config/dataset/pecanstreet.yaml index ab513cf..bc1139a 100644 --- a/cents/config/dataset/pecanstreet.yaml +++ b/cents/config/dataset/pecanstreet.yaml @@ -7,20 +7,24 @@ threshold: 8 seq_len: 96 time_series_dims: 1 shuffle: True +skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP) +max_samples: null # Limit dataset size (null = use all data) path: "./data/pecanstreet/csv" -time_series_columns: ["grid", "solar"] +time_series_columns: ["grid"] data_columns: ["dataid","local_15min","car1","grid","solar"] metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"] user_group: all # non_pv_users, all, pv_users numeric_context_bins: 5 +normalizer_stats_mode: group -context_vars: # for each desired context variable, add the name and number of categories - month: 12 - weekday: 7 - building_type: 3 - has_solar: 2 # note that the metadata csv file column name is 'solar', which is renamed to avoid conflicts with the 'solar' column in the data csv. - car1: 2 - city: 7 - state: 3 - total_square_footage: 5 - house_construction_year: 5 + +context_vars: + month: ["categorical", 12] + weekday: ["categorical", 7] + building_type: ["categorical", 3] + has_solar: ["categorical", 2] + car1: ["categorical", 2] + city: ["categorical", 7] + state: ["categorical", 3] + total_square_footage: ["categorical", null] + house_construction_year: ["categorical", null] \ No newline at end of file diff --git a/cents/config/dataset/walmart.yaml b/cents/config/dataset/walmart.yaml new file mode 100644 index 0000000..41c0a05 --- /dev/null +++ b/cents/config/dataset/walmart.yaml @@ -0,0 +1,37 @@ +name: walmart +geography: null +normalize: True +scale: False +use_learned_normalizer: True +threshold: 8 +seq_len: 28 +shuffle: True +skip_heavy_processing: False +max_samples: null +path: "./data/walmart" +numeric_context_bins: 1 +reduce_cardinality: False +time_series_dims: 1 +normalizer_stats_mode: group +# Normalizer conditions on category × store to capture per-group sales distributions +normalizer_group_vars: ["cat_id", "store_id"] +max_z_threshold: 15.0 + +# Target: daily unit sales +time_series_columns: ["sales"] + +context_vars: + # Static categorical — characterise the window by when it starts + year: ["categorical", 6] # 2011–2016 + month: ["categorical", 12] + # Static categorical — item / store identity + cat_id: ["categorical", 3] # FOODS, HOBBIES, HOUSEHOLD + dept_id: ["categorical", 7] # e.g. FOODS_1 … HOUSEHOLD_2 + store_id: ["categorical", 10] # CA_1 … WI_3 + state_id: ["categorical", 3] # CA, TX, WI + + # Dynamic time-series context (co-occurring with target, length = seq_len) + sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored + snap: ["time_series", null] # binary SNAP eligibility for the item's state + event_binary: ["time_series", null] # 1 if a named calendar event falls on that day + weekday: ["time_series", null] # day of week encoded as 0 (Mon) – 6 (Sun), z-scored diff --git a/cents/config/evaluator/airquality.yaml b/cents/config/evaluator/airquality.yaml new file mode 100644 index 0000000..9b8017e --- /dev/null +++ b/cents/config/evaluator/airquality.yaml @@ -0,0 +1,32 @@ +model: + name: diffusion_ts +dataset: + name: airquality +eval_pv_shift: False +eval_metrics: True +eval_context_sparse: True +save_results: False +eval_disentanglement: True +eval_context_recovery: True +job_name: diffusion_ts_airquality +save_dir: outputs/diffusion_ts_airquality/eval + +# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP). +# Runs only when enabled=True AND either: +# - the generated signal has multiple dimensions (multivariate), OR +# - the dataset uses dynamic (time-series) context variables. +# +# pairs: list of {x, c} dicts specifying which time series to evaluate against each other. +# x — name of a generated output dimension (must match time_series_columns in dataset config) +# c — name of a dynamic context variable (from context_vars with type "time_series") +# OR another generated output dimension (multivariate case, GCP only) +# +# CFS is computed only when c is a dynamic context variable (shared context). +# GCP is computed for all pairs. +# +eval_context_faithfulness: + enabled: true + gcp_max_lag: 5 + pairs: + - {x: "PM2.5", c: "TEMP"} + - {x: "PM2.5", c: "DEWP"} diff --git a/cents/config/evaluator/default.yaml b/cents/config/evaluator/default.yaml index 69ca339..f39eca1 100644 --- a/cents/config/evaluator/default.yaml +++ b/cents/config/evaluator/default.yaml @@ -1,7 +1,40 @@ -model_name: ${model.name} +model: + name: diffusion_ts # Set this to your model name +dataset: + name: commercial # Set this to your dataset name (e.g., "commercial") eval_pv_shift: False eval_metrics: True +pred_score_trtr: True # If True, also trains on real data (TRTR) and reports MAE delta alongside TSTR MAE eval_context_sparse: True save_results: False eval_disentanglement: True -save_dir: ${run_dir}/eval +eval_context_recovery: True +job_name: diffusion_ts_commercial +save_dir: outputs/diffusion_ts_commercial/eval + +# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP). +# Runs only when enabled=True AND either: +# - the generated signal has multiple dimensions (multivariate), OR +# - the dataset uses dynamic (time-series) context variables. +# +# pairs: list of {x, c} dicts specifying which time series to evaluate against each other. +# x — name of a generated output dimension (must match time_series_columns in dataset config) +# c — name of a dynamic context variable (from context_vars with type "time_series") +# OR another generated output dimension (multivariate case, GCP only) +# +# CFS is computed only when c is a dynamic context variable (shared context). +# GCP is computed for all pairs. +# +# Example for airquality dataset (PM2.5 generated, TEMP/DEWP as context): +# pairs: +# - {x: "PM2.5", c: "TEMP"} +# - {x: "PM2.5", c: "DEWP"} +# +eval_context_faithfulness: + enabled: true + gcp_max_lag: 5 + pairs: + - {x: "PM2.5", c: "T"} + - {x: "PM2.5", c: "TI"} + - {x: "PM2.5", c: "SP"} + # - {x: "PM2.5", c: "TEMP"} diff --git a/cents/config/evaluator/walmart.yaml b/cents/config/evaluator/walmart.yaml new file mode 100644 index 0000000..5785d0a --- /dev/null +++ b/cents/config/evaluator/walmart.yaml @@ -0,0 +1,34 @@ +model: + name: diffusion_ts +dataset: + name: walmart +eval_pv_shift: False +eval_metrics: True +eval_context_sparse: True +save_results: False +eval_disentanglement: True +eval_context_recovery: True +job_name: diffusion_ts_walmart +save_dir: outputs/diffusion_ts_walmart/eval + +# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP). +# Runs only when enabled=True AND either: +# - the generated signal has multiple dimensions (multivariate), OR +# - the dataset uses dynamic (time-series) context variables. +# +# pairs: list of {x, c} dicts specifying which time series to evaluate against each other. +# x — name of a generated output dimension (must match time_series_columns in dataset config) +# c — name of a dynamic context variable (from context_vars with type "time_series") +# OR another generated output dimension (multivariate case, GCP only) +# +# CFS is computed only when c is a dynamic context variable (shared context). +# GCP is computed for all pairs. +# +eval_context_faithfulness: + enabled: true + gcp_max_lag: 5 + pairs: + - {x: "sales", c: "sell_price"} + - {x: "sales", c: "snap"} + - {x: "sales", c: "event_binary"} + - {x: "sales", c: "weekday"} diff --git a/cents/config/model/diffusion_ts.yaml b/cents/config/model/diffusion_ts.yaml index 396856d..911005c 100644 --- a/cents/config/model/diffusion_ts.yaml +++ b/cents/config/model/diffusion_ts.yaml @@ -2,16 +2,19 @@ _target_: generator.diffusion_ts.gaussian_diffusion.Diffusion_TS name: diffusion_ts context_reconstruction_loss_weight: 0.1 tc_loss_weight: 0 -noise_dim: 256 -cond_emb_dim: 16 +noise_dim: 128 +cond_emb_dim: 64 n_layer_enc: 4 n_layer_dec: 5 d_model: 128 n_steps: 1000 -sampling_timesteps: 1000 +sampling_timesteps: 200 sampling_batch_size: 4096 loss_type: l1 #l2 -beta_schedule: cosine #linear +training_objective: v +loss_weighting: snr +min_snr_gamma: 5.0 +beta_schedule: cosine #linear diffusion ts paper uses linear schedule n_heads: 4 mlp_hidden_times: 4 eta: 0.0 @@ -22,6 +25,19 @@ padding_size: null use_ff: True reg_weight: null gradient_accumulate_every: 2 -ema_decay: 0.99 -ema_update_interval: 10 +ema_decay: 0.999 +ema_update_interval: 1 use_ema_sampling: False +k_bins: 20 +# Reconstruction-guided sampling (Algorithms 1 & 2) +recon_guide_eta: 0.1 # gradient scale for guidance +recon_guide_gamma: 1.0 # trade-off L1 vs L2 (L1 + gamma*L2) +recon_guide_algorithm: none # none | alg1 | alg2 +recon_guide_K: 3 # inner steps per t for alg2 (int or list for K[t]) +# Optional: dual head for x̂_a / x̂_b (set to cond_len used in recon-guided sampling) +recon_cond_len: null # int or null; if set, use fc_a / fc_b for first vs rest of sequence +# Context embedding dropout (training only) for more robust recon-guided sampling +context_embed_dropout: 0 # probability of zeroing entire context embedding per sample (CFG-compatible) +blue_noise_power: 0.0 # 0.0 = white noise, 1.0 = blue noise, 2.0 = violet noise +# When true and time_series_dims > 1, noise is correlated across dimensions (same draw per timestep). +correlated_noise: False \ No newline at end of file diff --git a/cents/config/trainer/acgan.yaml b/cents/config/trainer/acgan.yaml index 3034158..4795c3e 100644 --- a/cents/config/trainer/acgan.yaml +++ b/cents/config/trainer/acgan.yaml @@ -2,7 +2,7 @@ precision: "16-mixed" accelerator: auto devices: auto strategy: ddp_find_unused_parameters_true -max_epochs: 5000 +max_epochs: 1000 batch_size: 1024 sampling_batch_size: 4096 gradient_accumulate_every: 1 diff --git a/cents/config/trainer/diffusion_ts.yaml b/cents/config/trainer/diffusion_ts.yaml index 4f3e701..1fb90f2 100644 --- a/cents/config/trainer/diffusion_ts.yaml +++ b/cents/config/trainer/diffusion_ts.yaml @@ -2,23 +2,31 @@ precision: "16-mixed" accelerator: auto devices: auto strategy: ddp_find_unused_parameters_true -gradient_accumulate_every: 2 +gradient_accumulate_every: 4 log_every_n_steps: 1 -batch_size: 1024 +batch_size: 512 max_epochs: 5000 base_lr: 1e-4 +warmup_epochs: 100 # linear warmup from 1% to 100% of base_lr over first N epochs eval_after_training: False checkpoint: - save_last: False - save_top_k: 0 + save_last: True # Save final model + save_top_k: 0 # 0 = only periodic saves; use >0 to also save top-k by metric every_n_train_steps: null - every_n_epochs: null + every_n_epochs: 250 # Save a distinct checkpoint every 250 epochs (250, 500, 750, ...) lr_scheduler_params: factor: 0.5 patience: 200 - min_lr: 1.0e-5 + min_lr: 1.0e-6 threshold: 1.0e-1 threshold_mode: rel verbose: false + +intermediate_fid: + enabled: True + every_n_epochs: 20 # check context-FID every N epochs + n_samples: 3500 # number of samples to generate (subsample of dataset) + fast_timesteps: 50 # DDIM steps to use (vs 200 default) for speed + top_k: 3 # keep top-k FID checkpoints + their neighbors diff --git a/cents/config/trainer/normalizer.yaml b/cents/config/trainer/normalizer.yaml index 666ded6..f34db7f 100644 --- a/cents/config/trainer/normalizer.yaml +++ b/cents/config/trainer/normalizer.yaml @@ -1,17 +1,28 @@ -strategy: auto -accelerator: auto -devices: 1 +strategy: ddp_find_unused_parameters_true +accelerator: gpu +devices: 1, log_every_n_steps: 1 hidden_dim: 512 embedding_dim: 256 -n_epochs: 2000 +n_epochs: 500 batch_size: 4096 lr: 3e-4 save_cycle: 5000 eval_after_training: False +loss_type: mse # Options: "mse" or "gaussian_nll" + +# If true (default), targets are scaled by global mu/sigma stats and network predicts in scaled space (then unscaled in forward). +# If false, network predicts mu in asinh space and log(sigma) directly (no global preprocessing); can be more stable for some datasets (e.g. commercial). +use_global_stats_preprocessing: true +# When use_global_stats_preprocessing=false: mu is predicted as asinh(mu); clamp to [-max_asinh_mu, max_asinh_mu]. sinh(10) ~ 11013. +max_asinh_mu: 10 + +# Floor predicted sigma and scale range so (x - mu) / sigma and (z - zmin) / rng never explode +min_sigma: 0.1 # sigma_effective = max(pred_sigma, min_sigma) +min_scale_range: 0.25 # rng_effective = max(zmax - zmin, min_scale_range) checkpoint: save_last: False save_top_k: 0 every_n_train_steps: null - every_n_epochs: null + every_n_epochs: 500 diff --git a/cents/data_generator.py b/cents/data_generator.py index a77aa18..30f1ae8 100644 --- a/cents/data_generator.py +++ b/cents/data_generator.py @@ -7,8 +7,7 @@ import pytorch_lightning as pl import torch from huggingface_hub import hf_hub_download -from hydra import compose, initialize_config_dir -from omegaconf import DictConfig, ListConfig +from omegaconf import DictConfig, ListConfig, OmegaConf import cents.models from cents.datasets.utils import convert_generated_data_to_df @@ -18,7 +17,9 @@ get_device, get_normalizer_training_config, parse_dims_from_name, + get_context_config, ) +from cents.utils.config_loader import load_yaml, apply_overrides PKG_ROOT = Path(__file__).resolve().parent CONF_DIR = PKG_ROOT / "config" @@ -46,16 +47,18 @@ class DataGenerator: def __init__( self, - model_name: str, + model_name: str = None, + model_type: str = None, device: str = None, cfg: DictConfig = None, + dataset = None, model: Optional[pl.LightningModule] = None, normalizer: Optional[Normalizer] = None, ): self.device = get_device(device) self.cfg = cfg self._ctx_buff: Dict[str, torch.Tensor] = {} - + self.dataset = dataset if model is not None: self.model = model.to(self.device).eval() self.normalizer = normalizer @@ -71,28 +74,40 @@ def __init__( self.model = None self.normalizer = None self.load_pretrained(model_name) + elif model_type is not None: + # Init without loading - user will call load_from_checkpoint separately + self.model_type = model_type + self.model = None + self.normalizer = None + if dataset is not None: + self.cfg = cfg or OmegaConf.create({}) + if not hasattr(self.cfg, 'dataset'): + self.cfg.dataset = dataset.cfg + if not hasattr(self.cfg, 'model'): + self.cfg.model = load_yaml(CONF_DIR / "model" / f"{model_type}.yaml") + self.set_dataset_spec( + self.cfg.dataset, self._read_ctx_codes(self.cfg.dataset.name) + ) else: - raise ValueError("Must provide either model_name or model instance.") + raise ValueError("Must provide either model_name, model_type, or model instance.") def _default_cfg(self) -> DictConfig: """ - Load the default Hydra config for this model_name. - - Returns: - Composed DictConfig from 'config/config.yaml'. + Build a minimal default config (model + dataset) without Hydra. """ - # Extract dimensions from model name dims = parse_dims_from_name(self.model_name) time_series_dims = int(dims) - with initialize_config_dir(str(CONF_DIR), version_base=None): - return compose( - config_name="config", - overrides=[ - f"model={self.model_type}", - f"dataset.time_series_dims={time_series_dims}", - ], - ) + model_cfg = load_yaml(CONF_DIR / "model" / f"{self.model_type}.yaml") + dataset_cfg = load_yaml(CONF_DIR / "dataset" / "default.yaml") + dataset_cfg = apply_overrides( + dataset_cfg, [f"time_series_dims={time_series_dims}"] + ) + + cfg = OmegaConf.create({}) + cfg.model = model_cfg + cfg.dataset = dataset_cfg + return cfg def set_dataset_spec( self, dataset_cfg: DictConfig, ctx_codes: Dict[str, Dict[int, str]] @@ -107,13 +122,15 @@ def set_dataset_spec( self.cfg.dataset = dataset_cfg self.ctx_code_book = ctx_codes - def set_context(self, auto_fill_missing: bool = False, **context_vars: int): + def set_context(self, auto_fill_missing: bool = False, **context_vars: Union[int, float]): """ Define a context vector for subsequent generation calls. Args: auto_fill_missing: If True, randomly sample missing context variables. **context_vars: Named codes for each context variable. + For categorical variables: integer codes (int). + For continuous variables: float values (float). Raises: RuntimeError: If dataset spec has not been set. @@ -125,33 +142,46 @@ def set_context(self, auto_fill_missing: bool = False, **context_vars: int): ) required = self.cfg.dataset.context_vars + continuous_vars = getattr(self.cfg.dataset, "continuous_context_vars", None) or [] if auto_fill_missing: for var, n in required.items(): - context_vars.setdefault(var, random.randrange(n)) - else: + if var in continuous_vars: + # For continuous variables, sample from a reasonable range + # This is a simple default - users should provide actual values + context_vars[var] = random.uniform(0.0, 1.0) + else: + context_vars.setdefault(var, random.randrange(n)) + elif context_vars: missing = set(required) - set(context_vars) if missing: raise ValueError(f"Missing context vars: {missing}") + self._ctx_buff = {} for var, code in context_vars.items(): - max_cat = self.cfg.dataset.context_vars[var] - if not (0 <= code < max_cat): - raise ValueError( - f"Context '{var}' must be in [0, {max_cat}); got {code}." - ) - - self._ctx_buff = { - var: torch.tensor(code, device=self.device) - for var, code in context_vars.items() - } + if var in continuous_vars: + # Continuous variables: use float tensor (no validation needed) + self._ctx_buff[var] = torch.tensor(code, dtype=torch.float32, device=self.device) + else: + # Categorical variables: validate and use long tensor + if var in required: + max_cat = required[var] + if not (0 <= code < max_cat): + raise ValueError( + f"Context '{var}' must be in [0, {max_cat}); got {code}." + ) + self._ctx_buff[var] = torch.tensor(code, dtype=torch.long, device=self.device) @torch.no_grad() - def generate(self, n: int = 128) -> "pd.DataFrame": + def generate(self, n: int = 128, stochastic_round: bool = False) -> "pd.DataFrame": """ Produce n synthetic samples under the previously set context. Args: n: Number of samples to generate. + stochastic_round: If True, round output to non-negative integers using + stochastic rounding (floor + bernoulli on fractional part). Applied + after inverse_transform so normalizer stats are unaffected. Useful + for count-valued datasets (e.g. Walmart unit sales). Returns: DataFrame with context columns + 'timeseries'. @@ -163,13 +193,24 @@ def generate(self, n: int = 128) -> "pd.DataFrame": raise RuntimeError( "No model loaded. Call `load_from_checkpoint(...)` first." ) - if not self._ctx_buff: - raise RuntimeError("No context set – call `set_context()` first.") ctx_batch = {k: v.repeat(n) for k, v in self._ctx_buff.items()} - ts = self.model.generate(ctx_batch) + ts = self.model.generate(ctx_batch, n=n) df = convert_generated_data_to_df(ts, self._ctx_buff, decode=False) - return self.normalizer.inverse_transform(df) if self.normalizer else df + df = self.normalizer.inverse_transform(df) if self.normalizer else df + + if stochastic_round: + import numpy as np + + def _stochastic_round(x): + x = np.clip(x, 0, None) + floor = np.floor(x).astype(int) + frac = x - floor + return floor + (np.random.random(x.shape) < frac).astype(int) + + df["timeseries"] = df["timeseries"].apply(_stochastic_round) + + return df def load_from_checkpoint( self, @@ -195,7 +236,9 @@ def load_from_checkpoint( ckpt_path, state = self._resolve_ckpt(model_ckpt) ModelCls = get_model_cls(self.model_type) + if ckpt_path.suffix == ".ckpt": + print(f"[Cents] Loading model from checkpoint: {ckpt_path}") self.model = ( ModelCls.load_from_checkpoint( cfg=self.cfg, @@ -217,7 +260,8 @@ def load_from_checkpoint( self.normalizer = Normalizer( dataset_cfg=self.cfg.dataset, normalizer_training_cfg=get_normalizer_training_config(), - dataset=None, + dataset=self.dataset, + context_cfg=get_context_config(), ) state = torch.load(normalizer_ckpt, map_location=device) sd = state.get("state_dict", state) diff --git a/cents/datasets/airquality.py b/cents/datasets/airquality.py new file mode 100644 index 0000000..0f9f9eb --- /dev/null +++ b/cents/datasets/airquality.py @@ -0,0 +1,257 @@ +from ast import Str +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides +from cents.datasets.utils import is_all_nan, is_any_nan, fill_with_row_mean + +from cents.datasets.timeseries_dataset import TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +# These are warnings for an error that is accounted for in the code +warnings.filterwarnings("ignore", category=RuntimeWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class AirQualityDataset(TimeSeriesDataset): + def __init__(self, cfg: DictConfig = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None): + """ + Initializes the AirQuality Dataset. Available at: + https://doi.org/10.24432/C5RK5G. + Hourly Air Quality at Multiple Sites in China, in many measures. + Accompanying location and weather information. + """ + + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "airquality.yaml")) + + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + if isinstance(cfg.time_series_columns, str): + cfg.time_series_columns = [cfg.time_series_columns] + self.target_time_series_columns = cfg.time_series_columns + self.geography = cfg.geography + self.time_series_dims = cfg.time_series_dims + + self.city_names = ["Aotizhongxin", "Changping", "Dingling", "Dongsi", "Guanyuan", + "Gucheng", "Huairou", "Nongzhanguan", "Shunyi", "Tiantan", "Wanliu", "Wanshouxigong"] + self.context_time_series_columns = {k:v[1] for k,v in self.cfg.context_vars.items() if v[0] == "time_series"} + self.context_series_names = list(self.context_time_series_columns.keys()) + + self.categorical_time_series = { + k: v[1] for k, v in self.cfg.context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + self._load_data() + + super().__init__( + data=self.data, + time_series_column_names=self.target_time_series_columns, + context_var_column_names=list(self.cfg.context_vars.keys()), + seq_len=self.cfg.seq_len, + normalize=self.cfg.normalize, + scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None), + categorical_time_series=self.categorical_time_series, + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, + ) + + def _load_data(self): + """ + Loads in metadata and data for commercial energy dataset. + Categorial CVs: Year, Month, Day, Location + Context Time Series: Temperature, Pressure, Dewpoint, Precipitation, Wind + """ + + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + + if not self.geography: self.geography = self.city_names + + self.geography = [self.geography] if isinstance(self.geography, str) else self.geography + dfs = [] + for name in self.geography: + fname = f"PRSA_Data_{name}_20130301-20170228.csv" + data_path = os.path.join(path, fname) + + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found at {data_path}") + + dfs.append(pd.read_csv(data_path)[self.cfg.data_columns]) + + self.data = pd.concat(dfs, axis=0) + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + data = data.copy() + + # Timestamp + weekday + data["timestamp"] = pd.to_datetime(data[["year", "month", "day", "hour"]]) + data["weekday"] = data["timestamp"].dt.day_name() + + # ------------------------- + # Engineer wind_u / wind_v + # ------------------------- + wd_deg_map = { + "N": 0.0, "NNE": 22.5, "NE": 45.0, "ENE": 67.5, + "E": 90.0, "ESE": 112.5, "SE": 135.0, "SSE": 157.5, + "S": 180.0, "SSW": 202.5, "SW": 225.0, "WSW": 247.5, + "W": 270.0, "WNW": 292.5, "NW": 315.0, "NNW": 337.5, + } + + has_wd = "wd" in data.columns + has_wspm = "WSPM" in data.columns + if has_wd and has_wspm: + wd_clean = data["wd"].astype(str).str.strip().str.upper() + wd_deg = wd_clean.map(wd_deg_map) + + data["wd_valid"] = wd_deg.notna().astype(np.int8) + + theta = np.deg2rad(wd_deg.fillna(0.0).to_numpy(dtype=float)) + wspm = pd.to_numeric(data["WSPM"], errors="coerce").fillna(0.0) + + data["wind_u"] = wspm * np.cos(theta) + data["wind_v"] = wspm * np.sin(theta) + + data.drop(columns=["wd", "WSPM"], inplace=True) + else: + if "wd" in data.columns: + data.drop(columns=["wd"], inplace=True) + if "WSPM" in data.columns: + data.drop(columns=["WSPM"], inplace=True) + + # ------------------------- + # Engineer PMcoarse + # ------------------------- + if "PM10" in data.columns and "PM2.5" in data.columns: + pm10 = pd.to_numeric(data["PM10"], errors="coerce") + pm25 = pd.to_numeric(data["PM2.5"], errors="coerce") + data["PMcoarse"] = (pm10 - pm25).clip(lower=0.0) + + # ------------------------- + # Choose time-series columns + # ------------------------- + ctx_ts = list(self.context_series_names) + tgt_ts = list(self.target_time_series_columns) + + ctx_ts = [c for c in ctx_ts if c not in ("wd", "WSPM")] + for c in ("wind_u", "wind_v"): + if c in data.columns and c not in ctx_ts: + ctx_ts.append(c) + if "wd_valid" in data.columns and "wd_valid" not in ctx_ts: + ctx_ts.append("wd_valid") + + if "PMcoarse" in data.columns: + tgt_ts = ["PMcoarse" if c == "PM10" else c for c in tgt_ts] + tgt_ts = [c for c in tgt_ts if c != "PM10"] + + + + ts_cols = ctx_ts + tgt_ts + + missing = [c for c in ts_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing required time-series columns after preprocessing: {missing}") + + data = data.sort_values(["station", "year", "month", "day", "hour"]) + + months = [ + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December", + ] + data["month"] = data["month"].map(lambda x: months[x - 1]) + + group_keys = ["station", "year", "month", "day", "weekday"] + + grouped = ( + data.groupby(group_keys, as_index=False, sort=False) + .agg({c: list for c in ts_cols}) + ) + + for c in ts_cols: + grouped[c] = grouped[c].map(np.asarray) + + + len_col = tgt_ts[0] if len(tgt_ts) > 0 else ts_cols[0] + grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) + + grouped = self._handle_missing_data(grouped) + + ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] + log1p_channels = {"RAIN"} # add more if needed + + clip_bound = 5.0 + eps = 1e-8 + ctx_stats = {} + for c in ctx_numeric: + # stacked shape: (N, L) + X = np.stack(grouped[c].values).astype(np.float32) + + if c in log1p_channels: + X = np.log1p(np.clip(X, a_min=0.0, a_max=None)) + + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 # avoid divide-by-zero; effectively makes it "center only" + ctx_stats[c] = (mu, sd) + + Xn = (X - mu) / (sd + eps) + Xn = np.clip(Xn, -clip_bound, clip_bound).astype(np.float32) + + grouped[c] = list(Xn) + + self.context_ts_stats_ = ctx_stats + + for c in ts_cols: + grouped[c] = grouped[c].map(tuple) + + return grouped + + + def _handle_missing_data(self, data): + numeric_series = [c for c in self.context_series_names if c not in self.categorical_time_series] + + mask = data[numeric_series].applymap(is_all_nan).any(axis=1) if numeric_series else pd.Series([False] * len(data)) + data = data[~mask] + + for col in numeric_series: + data[col] = data[col].apply(fill_with_row_mean) + + cat_cols = list(self.categorical_time_series.keys()) + if cat_cols: + mask = data[cat_cols].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + for tcol in self.target_time_series_columns: + if tcol in data.columns: + data = data.loc[data[tcol].apply(lambda x: not np.isnan(np.asarray(x, dtype=float)).any())] + + def row_has_low_std(row, cols, thresh=0.01): + for c in cols: + arr = np.asarray(row[c], dtype=np.float32) + if arr.std() < thresh: + return True + return False + + mask = data.apply( + lambda row: row_has_low_std(row, self.target_time_series_columns, thresh=0.01), + axis=1 + ) + + data = data[~mask] + return data + + + diff --git a/cents/datasets/commercial.py b/cents/datasets/commercial.py new file mode 100644 index 0000000..75739b8 --- /dev/null +++ b/cents/datasets/commercial.py @@ -0,0 +1,258 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +# These are warnings for an error that is accounted for in the code +warnings.filterwarnings("ignore", category=RuntimeWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class CommercialDataset(TimeSeriesDataset): + def __init__(self, cfg: DictConfig = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None): + + """ + Initializes the commercial energy dataset. + Original Dataset: https://github.com/buds-lab/building-data-genome-project-2 + Note: This uses the clean version of their data, which already has some preprocessing done. + Args: + cfg (Optional[DictConfig]): Override Hydra config; if None, + load from `config/dataset/commercial.yaml`. + overrides (Optional[List[str]]): Override Hydra config; if None, + load from `config/dataset/commercial.yaml` and apply overrides. + """ + + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "commercial.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + self.cfg.time_series_columns = ["energy_meter"] + self.geography = cfg.geography + + self._load_data() + + self.time_series_column_names = self.cfg.time_series_columns + self.time_series_dims = len(self.time_series_column_names) + + super().__init__( + data=self.data, + time_series_column_names=self.time_series_column_names, + seq_len=self.cfg.seq_len, + context_var_column_names=list(self.cfg.context_vars.keys()), + normalize=self.cfg.normalize, + scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, + ) + + def _load_data(self): + """ + Loads in metadata and data for the commercial energy dataset. + """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + base_path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + + metapath = os.path.join(base_path, "metadata.csv") + if not os.path.exists(metapath): + raise FileNotFoundError(f"Metadata file not found at {metapath}") + metadata = pd.read_csv(metapath, usecols=self.cfg.metadata_columns) + + data_path = os.path.join(base_path, "electricity_cleaned.csv") + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found at {data_path}") + data = pd.read_csv(data_path) + + data = data.melt( + id_vars="timestamp", # keep timestamp as is + var_name="dataid", # old column names (id1, id2, etc.) become values in this column + value_name="energy_meter" # their values go here + ) + + data['site_id'] = data['dataid'].str.split('_').str[0] + + if self.geography: + if self.geography not in metadata["site_id"].unique(): + raise ValueError(f"Geography {self.geography} not found in metadata") + data = data[data["site_id"] == self.geography] + metadata = metadata[metadata["site_id"] == self.geography] + + max_samples = self.cfg.get('max_samples', None) + if max_samples is not None: + unique_ids = data['dataid'].unique() + # Each building produces ~700 daily sequences; 2x buffer accounts for filtering + est_seqs_per_id = 700 + n_ids = min(len(unique_ids), max(1, int(np.ceil(max_samples * 2.0 / est_seqs_per_id)))) + kept_ids = np.random.choice(unique_ids, size=n_ids, replace=False) + data = data[data['dataid'].isin(kept_ids)] + print(f"CommercialDataset: subsampling {n_ids} building IDs (est. ~{n_ids * est_seqs_per_id} sequences) for max_samples={max_samples}") + + self.data = data + self.metadata = metadata + + def _preprocess_data(self, data): + """ + Creates sequences of seq_len and merges metadata. Removes any sequences with missing data. + + Args: + data (pd.DataFrame): Raw DataFrame including 'energy_meter' values. + + Returns: + pd.DataFrame: Metadata columns, datetime, year, month, weekday, date_day, and array-valued 'energy_meter'. + """ + data = data.copy() + + data['datetime'] = pd.to_datetime(data['timestamp']) + # data = data.dropna(subset=['energy_meter']) # any NaN makes the day shorter -> filtered by size later + data = data.sort_values(by=["dataid", "datetime"]) + + # Extract date for grouping (groups all hourly measurements from same calendar day) + data['date'] = data['datetime'].dt.date + + grouped = data.groupby(['dataid', 'date'])['energy_meter'].apply(np.asarray).reset_index() + + # grouped = grouped[grouped["energy_meter"].apply(len) == self.cfg.seq_len].reset_index(drop=True) + grouped = grouped[grouped["energy_meter"].apply(lambda x: not np.isnan(x).any())] + grouped['date'] = pd.to_datetime(grouped['date']) + + + grouped['year'] = grouped['date'].dt.year + grouped["month"] = grouped["date"].dt.month_name() + grouped["weekday"] = grouped["date"].dt.day_name() + grouped["date_day"] = grouped["date"].dt.day + + if grouped["energy_meter"].apply(lambda x: np.isnan(x).any()).any(): + raise ValueError("NaN values remain in grouped energy_meter sequences after filtering.") + merged = pd.merge(grouped, self.metadata, how="left", left_on="dataid", right_on="building_id").drop(columns=["building_id"]) + merged.sort_values(by=["dataid", "date"], inplace=True) + + merged = self._handle_missing_data(merged) + + # Check if any NaN remains + context_cols = [col for col in self.cfg.context_vars.keys() if col in merged.columns] + if merged[context_cols].isna().sum().sum() > 0: + print(f"Warning: {merged[context_cols].isna().sum().sum()} NaN values remain after handling") + + mask = merged["energy_meter"].apply(lambda x: x.std() < 0.01) + + merged = merged[~mask] + + return merged + + def _handle_missing_data(self, merged): + """ + Fill NaNs using hierarchical and context-aware imputation. + + Strategy: + 1. Numeric: Group-based median (by site/building type), fallback to global median + 2. Categorical: Use hierarchical structure (e.g., sub_primaryspaceusage from primaryspaceusage) + 3. Last resort: Mode or 'unknown' category + + Args: + merged (pd.DataFrame): Merged sequence+metadata rows. + + Returns: + pd.DataFrame: Fully imputed DataFrame. + """ + df = merged.copy() + + # Handle numeric columns with group-based imputation + numeric_cols = self.cfg.get('numeric_cols', []) + for col in numeric_cols: + if col in df.columns and df[col].isna().any(): + # Try imputing based on similar buildings (same site_id and primaryspaceusage) + if 'site_id' in df.columns and 'primaryspaceusage' in df.columns: + for (site, usage), group in df.groupby(['site_id', 'primaryspaceusage']): + group_median = group[col].median() + if pd.notna(group_median): + mask = (df['site_id'] == site) & (df['primaryspaceusage'] == usage) & df[col].isna() + df.loc[mask, col] = group_median + + # Fallback to global median for remaining NaNs + if df[col].isna().any(): + global_median = df[col].median() + df[col] = df[col].fillna(global_median if pd.notna(global_median) else 0) + + # Handle hierarchical categorical: sub_primaryspaceusage from primaryspaceusage + if 'sub_primaryspaceusage' in df.columns and 'primaryspaceusage' in df.columns: + if df['sub_primaryspaceusage'].isna().any(): + # For each primaryspaceusage, find most common sub category + for primary in df['primaryspaceusage'].unique(): + mask = (df['primaryspaceusage'] == primary) + mode_sub = df.loc[mask, 'sub_primaryspaceusage'].mode() + if len(mode_sub) > 0: + df.loc[mask & df['sub_primaryspaceusage'].isna(), 'sub_primaryspaceusage'] = mode_sub[0] + + # Handle remaining categorical columns + categorical_cols = [col for col in self.cfg.context_vars.keys() + if col not in numeric_cols and col in df.columns] + for col in categorical_cols: + if col in df.columns and df[col].isna().any(): + mode_val = df[col].mode() + if len(mode_val) > 0: + df[col] = df[col].fillna(mode_val[0]) + else: + # Create 'unknown' category instead of dropping + df[col] = df[col].fillna('unknown') + + return df + + def _reduce_high_cardinality_features(self, df, col, min_samples=50, max_categories=30): + """ + Reduce high-cardinality categorical features by grouping rare categories. + + Strategy: Keep top N categories by frequency, group the rest as 'other_{parent}' + where parent is the primaryspaceusage. + + Args: + df (pd.DataFrame): Input DataFrame + col (str): Column name to reduce + min_samples (int): Minimum samples to keep category separate + max_categories (int): Maximum number of categories to keep + + Returns: + pd.DataFrame: DataFrame with reduced categories + """ + if col not in df.columns: + return df + + df = df.copy() + value_counts = df[col].value_counts() + + # Keep categories with enough samples + keep_categories = value_counts[value_counts >= min_samples].index.tolist() + + # If still too many, keep only top max_categories + if len(keep_categories) > max_categories: + keep_categories = value_counts.nlargest(max_categories).index.tolist() + + # For sub_primaryspaceusage, group by parent category + if col == 'sub_primaryspaceusage' and 'primaryspaceusage' in df.columns: + # Map rare subcategories to 'other_{primary}' + def map_category(row): + if row[col] in keep_categories: + return row[col] + else: + return f"other_{row['primaryspaceusage']}" + df[col] = df.apply(map_category, axis=1) + else: + # Generic grouping + df.loc[~df[col].isin(keep_categories), col] = 'other' + + return df + + \ No newline at end of file diff --git a/cents/datasets/metraq.py b/cents/datasets/metraq.py new file mode 100644 index 0000000..a9bc2f6 --- /dev/null +++ b/cents/datasets/metraq.py @@ -0,0 +1,250 @@ +import os +import warnings +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset +from cents.datasets.utils import is_all_nan, is_any_nan, fill_with_row_mean + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class MetraqDataset(TimeSeriesDataset): + """ + Dataset class for Metraq time series data. + + Handles loading and preprocessing including normalization and context variables. + Data: https://huggingface.co/datasets/dmariaa70/METRAQ-Air-Quality + + Attributes: + cfg (DictConfig): Hydra config for the dataset. + name (str): Dataset name. + geography (str): Geographic region selector. + normalize (bool): Whether to apply normalization. + threshold (Tuple[int, int]): Range filter for grid values. + include_generation (bool): If True, include solar series. + """ + + def __init__( + self, + cfg: Optional[DictConfig] = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, + ): + """ + Initialize and preprocess the PecanStreet dataset. + + Loads metadata and timeseries CSVs, then applies filtering, + grouping, user-subsetting, and calls the base class for + further preprocessing (normalization, merging, rarity flags). + + Args: + cfg (Optional[DictConfig]): Override Hydra config; if None, + load from `config/dataset/pecanstreet.yaml`. + overrides (Optional[List[str]]): Override Hydra config; if None, + load from `config/dataset/pecanstreet.yaml` and apply overrides. + + Raises: + FileNotFoundError: If required CSV files are missing. + """ + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "metraq.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.geography = cfg.geography + self.normalize = cfg.normalize + self.target_time_series_columns = cfg.time_series_columns + + self.threshold = (-1 * int(cfg.threshold), int(cfg.threshold)) + self.time_series_dims = cfg.time_series_dims + + self._load_data() + + ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] + + self.context_time_series_columns = {k:v[1] for k,v in self.cfg.context_vars.items() if v[0] == "time_series"} + self.context_series_names = list(self.context_time_series_columns.keys()) + + super().__init__( + data=self.data, + time_series_column_names=ts_cols, + context_var_column_names=list(self.cfg.context_vars.keys()), + seq_len=self.cfg.seq_len, + normalize=self.cfg.normalize, + scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, + ) + + def _load_data(self) -> None: + """ + Populates self.data DataFrames. + + Raises: + FileNotFoundError: If any required CSV file is missing. + """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + "/metraq_aq_processed.csv" + + data = pd.read_csv(path) + + if self.geography: + data = data.loc[data.sensor_name.isin(self.geography)].copy() + + self.data = data + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Convert timestamps, assemble sequences of length seq_len, and merge metadata. + + Args: + data (pd.DataFrame): Raw concatenated grid (and solar) rows. + + Returns: + pd.DataFrame: One row per sequence, with array-valued 'grid' and + optional 'solar' columns plus context and metadata fields. + """ + data = data.copy() + data["timestamp"] = pd.to_datetime(data["entry_date"], utc=True) + data["year"] = data["timestamp"].dt.year + data["month"] = data["timestamp"].dt.month_name() + data["weekday"] = data["timestamp"].dt.day_name() + data["day"] = data["timestamp"].dt.day + + ctx_ts = list(self.context_series_names) + tgt_ts = list(self.target_time_series_columns) + + if "PM10" in data.columns and "PM2.5" in data.columns: + pm10 = pd.to_numeric(data["PM10"], errors="coerce") + pm25 = pd.to_numeric(data["PM2.5"], errors="coerce") + data["PMcoarse"] = (pm10 - pm25).clip(lower=0.0) + + if "PMcoarse" in data.columns: + tgt_ts = ["PMcoarse" if c == "PM10" else c for c in tgt_ts] + + if "WD" in data.columns and "WS" in data.columns: + wd_deg = pd.to_numeric(data["WD"], errors="coerce") + ws = pd.to_numeric(data["WS"], errors="coerce").clip(lower=0.0) + data["wd_valid"] = wd_deg.notna().astype(np.int8) + wd_deg = wd_deg.fillna(0.0) + ws = ws.fillna(0.0) + wd_rad = np.deg2rad(wd_deg) + data["wind_u"] = ws * np.sin(wd_rad) + data["wind_v"] = ws * np.cos(wd_rad) + ctx_ts = [ + "wind_u" if c == "WS" else "wind_v" if c == "WD" else c + for c in ctx_ts + ] + + ts_cols = ctx_ts + tgt_ts + + missing = [c for c in ts_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing required time-series columns after preprocessing: {missing}") + + data = data.sort_values(["sensor_name", "timestamp"]) + + group_keys = ["sensor_name", "year", "month", "day", "weekday"] + static_continuous_cols = [ + k for k, v in self.cfg.context_vars.items() + if v[0] == "continuous" and k in data.columns + ] + agg_dict = {c: list for c in ts_cols} + agg_dict.update({c: "first" for c in static_continuous_cols}) + + grouped = ( + data.groupby(group_keys, as_index=False, sort=False) + .agg(agg_dict) + ) + + + for c in ts_cols: + grouped[c] = grouped[c].map(np.asarray) + + len_col = tgt_ts[0] if len(tgt_ts) > 0 else ts_cols[0] + grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) + + grouped = self._handle_missing_data(grouped) + + ctx_numeric = [c for c in ctx_ts if c not in self.categorical_time_series] + log1p_channels = {"TI"} + binary_channels = {"wd_valid"} # already in [0, 1] — skip z-scoring + clip_bound = 5.0 + eps = 1e-8 + + ctx_stats = {} # {channel: {sensor_name: (mu, sd)}} + for c in ctx_numeric: + if c in binary_channels: + grouped[c] = list(np.stack(grouped[c].values).astype(np.float32)) + continue + + ctx_stats[c] = {} + col_arrays = grouped[c].map(np.asarray) + + if c in log1p_channels: + col_arrays = col_arrays.map(lambda x: np.log1p(np.clip(x, 0.0, None))) + + normalized = col_arrays.copy() + for stn, idx in grouped.groupby("sensor_name").groups.items(): + X = np.stack(col_arrays.loc[idx].values).astype(np.float32) + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 + ctx_stats[c][stn] = (mu, sd) + Xn = np.clip((X - mu) / (sd + eps), -clip_bound, clip_bound) + for arr_i, row_i in enumerate(idx): + normalized.loc[row_i] = Xn[arr_i] + + grouped[c] = list(normalized) + + self.context_ts_stats_ = ctx_stats + + for c in ts_cols: + grouped[c] = grouped[c].map(tuple) + + return grouped + + def _handle_missing_data(self, data): + numeric_series = [c for c in self.context_series_names if c not in self.categorical_time_series] + + mask = data[numeric_series].applymap(is_all_nan).any(axis=1) if numeric_series else pd.Series([False] * len(data)) + data = data[~mask] + + for col in numeric_series: + data[col] = data[col].apply(fill_with_row_mean) + + cat_cols = list(self.categorical_time_series.keys()) + if cat_cols: + mask = data[cat_cols].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + for tcol in self.target_time_series_columns: + if tcol in data.columns: + data = data.loc[data[tcol].apply(lambda x: not np.isnan(np.asarray(x, dtype=float)).any())] + + def row_has_low_std(row, cols, thresh=0.01): + for c in cols: + arr = np.asarray(row[c], dtype=np.float32) + if arr.std() < thresh: + return True + return False + + mask = data.apply( + lambda row: row_has_low_std(row, self.target_time_series_columns, thresh=0.01), + axis=1 + ) + + data = data[~mask] + return data \ No newline at end of file diff --git a/cents/datasets/pecanstreet.py b/cents/datasets/pecanstreet.py index 866921d..c3e0022 100644 --- a/cents/datasets/pecanstreet.py +++ b/cents/datasets/pecanstreet.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd -from hydra import compose, initialize_config_dir from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides from cents.datasets.timeseries_dataset import TimeSeriesDataset @@ -33,6 +33,8 @@ def __init__( self, cfg: Optional[DictConfig] = None, overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, ): """ Initialize and preprocess the PecanStreet dataset. @@ -51,10 +53,9 @@ def __init__( FileNotFoundError: If required CSV files are missing. """ if cfg is None: - with initialize_config_dir( - config_dir=os.path.join(ROOT_DIR, "config/dataset"), version_base=None - ): - cfg = compose(config_name="pecanstreet", overrides=overrides) + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "pecanstreet.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) self.cfg = cfg self.name = cfg.name @@ -63,8 +64,6 @@ def __init__( self.threshold = (-1 * int(cfg.threshold), int(cfg.threshold)) self.time_series_dims = cfg.time_series_dims - self.cfg.time_series_columns = ["grid", "solar"] - self.include_generation = self.time_series_dims > 1 if self.time_series_dims > 1 and self.cfg.user_group in {"non_pv_users", "all"}: @@ -75,7 +74,7 @@ def __init__( self._load_data() self._set_user_flags() - + ts_cols: List[str] = self.cfg.time_series_columns[: self.time_series_dims] super().__init__( @@ -85,6 +84,10 @@ def __init__( seq_len=self.cfg.seq_len, normalize=self.cfg.normalize, scale=self.cfg.scale, + skip_heavy_processing=cfg.get('skip_heavy_processing', False), + size=cfg.get('max_samples', None), + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, ) def _load_data(self) -> None: diff --git a/cents/datasets/timeseries_dataset.py b/cents/datasets/timeseries_dataset.py index 4f70776..9e6940c 100644 --- a/cents/datasets/timeseries_dataset.py +++ b/cents/datasets/timeseries_dataset.py @@ -8,14 +8,16 @@ import pandas as pd import pytorch_lightning as pl import torch -from hydra import compose, initialize_config_dir from pytorch_lightning.callbacks import ModelCheckpoint from sklearn.cluster import KMeans from torch.utils.data import DataLoader, Dataset +from omegaconf import ListConfig, OmegaConf +import pickle from cents.datasets.utils import encode_context_variables from cents.models.normalizer import Normalizer -from cents.utils.utils import _ckpt_name, get_normalizer_training_config +from cents.utils.config_loader import load_yaml, apply_overrides +from cents.utils.utils import _ckpt_name, get_normalizer_training_config, get_context_config ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -50,62 +52,116 @@ def __init__( normalize: bool = True, scale: bool = True, overrides: Dict[str, Any] = {}, + skip_heavy_processing: bool = False, + size: int = None, + categorical_time_series: Dict[str, int] = None, + force_retrain_normalizer: bool = False, + run_dir: Any = None, ): # Initialize basic attributes - self.time_series_column_names = ( - time_series_column_names - if isinstance(time_series_column_names, list) - else [time_series_column_names] - ) - self.time_series_dims = len(self.time_series_column_names) + # Handle OmegaConf ListConfig objects + if not isinstance(time_series_column_names, list): + time_series_column_names = list(time_series_column_names) + if isinstance(context_var_column_names, ListConfig): + context_var_column_names = list(context_var_column_names) + + self.time_series_column_names = time_series_column_names + self.time_series_dims = self.cfg.time_series_dims self.context_vars = context_var_column_names or [] self.seq_len = seq_len # Load dataset-level config if not already set if not hasattr(self, "cfg"): - with initialize_config_dir( - config_dir=os.path.join(ROOT_DIR, "config", "dataset"), - version_base=None, - ): - overrides = [ - f"seq_len={seq_len}", - f"time_series_dims={len(self.time_series_column_names)}", - ] - cfg = compose(config_name="default", overrides=overrides) - cfg.time_series_columns = self.time_series_column_names - self.numeric_context_bins = cfg.numeric_context_bins - context_vars = self._get_context_var_dict(data) - cfg.context_vars = context_vars - self.cfg = cfg - + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "default.yaml")) + # dynamic overrides for required fields + dyn = [ + f"seq_len={seq_len}", + f"time_series_dims={len(self.time_series_column_names)}", + ] + cfg = apply_overrides(cfg, dyn) + cfg.time_series_columns = self.time_series_column_names + self.numeric_context_bins = cfg.numeric_context_bins + self.cfg = cfg + + self.context_var_dict = self.cfg.context_vars + self.numeric_cols = [k for k, v in self.cfg.context_vars.items() if v[0] == "categorical" and v[1] is None] self.numeric_context_bins = self.cfg.numeric_context_bins - if not hasattr(self, "threshold"): - self.threshold = (-self.cfg.threshold, self.cfg.threshold) + + for k in self.numeric_cols: + self.context_var_dict[k] = ["categorical", self.numeric_context_bins] + if not hasattr(self, "name"): self.name = "custom" + # Split context vars into static and dynamic once (no future re-splits) + self.continuous_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "continuous"] + categorical_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "categorical"] + self.dynamic_context_vars = [k for k, v in self.cfg.context_vars.items() if v[0] == "time_series"] + self.static_context_vars = categorical_vars + self.continuous_vars + self.normalize = normalize self.scale = scale + self.force_retrain_normalizer = force_retrain_normalizer + self.run_dir = Path(run_dir) if run_dir is not None else None - if self.scale: - assert self.normalize, "Normalization must be enabled if scaling is enabled" + # Store categorical time series info + self.categorical_time_series = categorical_time_series or {} # Preprocess and optionally encode context self.data = self._preprocess_data(data) + + if self.continuous_vars: + self._normalize_continuous_vars() + + if size is not None: + self.data = self.data.sample(size) + print(f"Sampled {size} rows from dataset") if self.context_vars: self.data, self.context_var_codes = self._encode_context_vars(self.data) self._save_context_var_codes() + + self.context_cfg = get_context_config() + self.dynamic_module_type = self.context_cfg.dynamic_context.type + self.static_module_type = getattr(self.context_cfg.normalizer, "context_type", "mlp") + self.stats_head_type = self.context_cfg.normalizer.stats_head_type + + is_ddp_subprocess = self._is_ddp_subprocess() if self.normalize: self._init_normalizer() - self.data = self._normalizer.transform(self.data) - + cache_path = self._get_normalization_cache_path() + if cache_path.exists(): + print(f"[{'DDP Subprocess' if is_ddp_subprocess else 'Main Process'}] Loading pre-normalized data from cache", cache_path) + with open(cache_path, 'rb') as f: + self.data = pickle.load(f) + else: + # Normalize (only main process or if cache doesn't exist) + if not is_ddp_subprocess: + print("[Main Process] Normalizing data...") + self.data = self._normalizer.transform(self.data) + # Save to cache for subprocesses (only main process) + if not is_ddp_subprocess: + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, 'wb') as f: + pickle.dump(self.data, f) + print(f"[Main Process] Cached normalized data for subprocesses") self.data = self.merge_timeseries_columns(self.data) self.data = self.data.reset_index() - self.data = self.get_frequency_based_rarity() - self.data = self.get_clustering_based_rarity() - self.data = self.get_combined_rarity() + # Check if we should skip heavy processing for DDP + if is_ddp_subprocess and skip_heavy_processing: + print("skipped rarity computation for DDP compatibility") + cache_path = self._get_rarity_cache_path() + if self._load_rarity_cache(cache_path): + self._rarity_computed = True + else: + print("Computing rarity features...") + self.data = self.get_frequency_based_rarity() + self.data = self.get_clustering_based_rarity() + self.data = self.get_combined_rarity() + rarity_cache_path = self._get_rarity_cache_path() + self._save_rarity_cache(rarity_cache_path) + self._rarity_computed = True @abstractmethod def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: """ @@ -137,20 +193,52 @@ def __getitem__(self, idx: int): idx (int): Sample index. Returns: - Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: - timeseries: Tensor of shape (seq_len, dims). - - context_vars: Dict of context variable tensors. + - static_context_vars: Dict of static context tensors (categorical long, continuous float). + - dynamic_context_vars: Dict of dynamic (time_series) context tensors. """ sample = self.data.iloc[idx] timeseries = torch.tensor(sample["timeseries"], dtype=torch.float32) - context_vars_dict = { - var: torch.tensor(sample[var], dtype=torch.long) - for var in self.context_vars - } - return timeseries, context_vars_dict + + static_context_vars_dict = {} + for var in self.static_context_vars: + if var in self.continuous_vars: + val = sample[var] + if isinstance(val, (float, int)) and (not isinstance(val, bool) and (np.isnan(val) or np.isinf(val))): + raise ValueError( + f"NaN/Inf detected in continuous variable '{var}' in dataset at index {idx}. " + f"Value: {val}. This should not happen if normalization was done correctly." + ) + static_context_vars_dict[var] = torch.tensor(val, dtype=torch.float32) + else: + static_context_vars_dict[var] = torch.tensor(sample[var], dtype=torch.long) + + dynamic_context_vars_dict = {} + for var in self.dynamic_context_vars: + arr = np.asarray(sample[var]) + if var in self.categorical_time_series: + dynamic_context_vars_dict[var] = torch.from_numpy(arr).long() + else: + dynamic_context_vars_dict[var] = torch.from_numpy(arr).float() + + return timeseries, static_context_vars_dict, dynamic_context_vars_dict + + def __getstate__(self): + """ + Make the dataset picklable for DataLoader worker processes by dropping + non-picklable attributes (e.g., attached Lightning normalizer module). + + Note: Training workers only need access to the preprocessed `self.data`. + The normalizer is not used during batching, so it is safe to omit here. + """ + state = self.__dict__.copy() + if state.get("_normalizer", None) is not None: + state["_normalizer"] = None + return state def get_train_dataloader( - self, batch_size: int, shuffle: bool = True, num_workers: int = 4 + self, batch_size: int, shuffle: bool = True, num_workers: int = 8, persistent_workers: bool = True ) -> DataLoader: """ Create a PyTorch DataLoader for training. @@ -159,12 +247,15 @@ def get_train_dataloader( batch_size (int): Batch size. shuffle (bool): Whether to shuffle the data. num_workers (int): Number of worker processes. + persistent_workers (bool): Whether to keep workers alive between epochs. Returns: DataLoader: Configured data loader. """ + + self._normalize_continuous_vars() return DataLoader( - self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers + self, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, persistent_workers=persistent_workers ) def split_timeseries(self, df: pd.DataFrame) -> pd.DataFrame: @@ -220,6 +311,12 @@ def merge_timeseries_columns(self, df: pd.DataFrame) -> pd.DataFrame: raise ValueError("Incorrect array shape.") else: raise ValueError("Array must have 2 dims.") + + # For categorical time series, ensure they remain as integers + for col in self.time_series_column_names: + if col in self.categorical_time_series: + df[col] = df[col].apply(lambda x: x.astype(np.int32)) + df["timeseries"] = df.apply( lambda r: np.hstack([r[c] for c in self.time_series_column_names]), axis=1 ) @@ -244,11 +341,24 @@ def inverse_transform( df = self._normalizer.inverse_transform(df) return self.merge_timeseries_columns(df) if merged else df + def apply_pretrained_normalizer(self) -> None: + """ + Transform self.data with the attached pretrained normalizer (e.g. after + dataset._normalizer = data_generator.normalizer). Use when normalize=False + was set to avoid training but you still want real data in normalized space. + """ + if getattr(self, "_normalizer", None) is None: + return + df_split = self.split_timeseries(self.data.copy()) + self.data = self.merge_timeseries_columns(self._normalizer.transform(df_split)) + self.data = self.data.reset_index(drop=True) + def _encode_context_vars( self, data: pd.DataFrame ) -> Tuple[pd.DataFrame, Dict[str, Any]]: """ Encode and bin numeric or categorical context variables. + Continuous variables are kept as-is. Args: data (pd.DataFrame): Input DataFrame. @@ -256,11 +366,55 @@ def _encode_context_vars( Returns: Tuple of encoded DataFrame and mapping codes. """ - return encode_context_variables( + encoded_data, mapping = encode_context_variables( data=data, columns_to_encode=self.context_vars, bins=self.numeric_context_bins, + numeric_cols=self.numeric_cols, + continuous_vars=self.continuous_vars, + time_series_cols=self.dynamic_context_vars, + categorical_time_series=self.categorical_time_series, ) + + return encoded_data, mapping + + def _normalize_continuous_vars(self): + """ + Normalize continuous context variables in the dataset using z-score normalization. + This is done once during dataset initialization, so models receive pre-normalized values. + """ + if not self.continuous_vars: + return + + # Store stats for potential inverse transform if needed + self.continuous_var_stats = {} + + for var_name in self.continuous_vars: + if var_name in self.data.columns: + values = self.data[var_name] + # Compute mean and std + mean_val = float(values.mean()) + std_val = float(values.std()) + + # Avoid zero std + if std_val < 1e-8: + std_val = 1.0 + print(f"[Dataset] Warning: {var_name} has zero std, using std=1.0") + + # Store stats for reference and for sampling in original range + self.continuous_var_stats[var_name] = { + 'mean': mean_val, + 'std': std_val, + 'min': float(values.min()), + 'max': float(values.max()), + } + + # Normalize the values in-place: (x - mean) / std + self.data[var_name] = (values - mean_val) / std_val + + print(f"[Dataset] Normalized {var_name}: mean={mean_val:.4f}, std={std_val:.4f}") + else: + print(f"[Dataset] Warning: Continuous variable {var_name} not found in data columns") def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: """ @@ -273,8 +427,9 @@ def _get_context_var_dict(self, data: pd.DataFrame) -> Dict[str, int]: dict: {var_name: num_categories} """ context_dict = {} + numeric_cols = getattr(self.cfg, 'numeric_cols', []) for var in self.context_vars: - if pd.api.types.is_numeric_dtype(data[var]): + if var in numeric_cols: binned = pd.cut( data[var], bins=self.numeric_context_bins, include_lowest=True ) @@ -309,8 +464,21 @@ def sample_random_context_vars(self) -> Dict[str, torch.Tensor]: dict: Random context index tensors. """ ctx = {} - for var, n in self._get_context_var_dict(self.data).items(): - ctx[var] = torch.randint(0, n, (), dtype=torch.long) + + for var, info in self.context_var_dict.items(): + if info[0] == "categorical": + ctx[var] = torch.randint(0, info[1], (), dtype=torch.long) + elif info[0] == "continuous": + # Sample from original [min, max] then z-score to match training data + stats = getattr(self, "continuous_var_stats", {}).get(var) + if stats is not None and "min" in stats and "max" in stats: + x_raw = np.random.uniform(stats["min"], stats["max"]) + x_norm = (x_raw - stats["mean"]) / stats["std"] + ctx[var] = torch.tensor(x_norm, dtype=torch.float32) + else: + raise ValueError(f"Continuous variable {var} has no stats") + else: + raise ValueError(f"Invalid context variable type: {info[0]}") return ctx def get_context_var_combination_rarities( @@ -338,12 +506,19 @@ def get_frequency_based_rarity(self) -> pd.DataFrame: Returns: pd.DataFrame: DataFrame with 'is_frequency_rare' column. """ - freq = self.data.groupby(self.context_vars).size().reset_index(name="count") + # Continuous vars are floats (e.g. UTM coordinates) — grouping by them would create + # one group per unique value, making rarity meaningless. Use only discrete vars. + continuous = set(getattr(self, "continuous_vars", [])) + groupby_vars = [v for v in self.context_vars if v not in continuous] + if not groupby_vars: + self.data["is_frequency_rare"] = False + return self.data + freq = self.data.groupby(groupby_vars).size().reset_index(name="count") threshold = freq["count"].quantile(0.1) freq["is_frequency_rare"] = freq["count"] < threshold self.data = self.data.merge( - freq[self.context_vars + ["is_frequency_rare"]], - on=self.context_vars, + freq[groupby_vars + ["is_frequency_rare"]], + on=groupby_vars, how="left", ) return self.data @@ -406,6 +581,94 @@ def get_combined_rarity(self) -> pd.DataFrame: ) return self.data + def _ensure_rarity_computed(self): + """ + Compute rarity features lazily to avoid DDP issues. + """ + if not self._rarity_computed: + cache_path = self._get_rarity_cache_path() + if self._load_rarity_cache(cache_path): + self._rarity_computed = True + else: + self.data = self.get_frequency_based_rarity() + self.data = self.get_clustering_based_rarity() + self.data = self.get_combined_rarity() + self._save_rarity_cache(cache_path) + self._rarity_computed = True + + def _get_rarity_cache_path(self) -> str: + """Get cache file path for rarity features.""" + import hashlib + context_cfg = get_context_config() + context_module_type = context_cfg.static_context.type + cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{str(sorted(self.context_vars))}_{context_module_type or ''}" + cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] + if self.run_dir is not None: + cache_dir = self.run_dir / "cache" / "rarity" + cache_dir.mkdir(parents=True, exist_ok=True) + return str(cache_dir / f"rarity_{cache_hash}.pkl") + cache_dir = os.path.join(ROOT_DIR, "cache", "rarity") + os.makedirs(cache_dir, exist_ok=True) + return os.path.join(cache_dir, f"rarity_{cache_hash}.pkl") + + def _get_normalization_cache_path(self): + """Get cache file path for normalized data.""" + import hashlib + from pathlib import Path + context_cfg = get_context_config() + context_module_type = context_cfg.dynamic_context.type + stats_head_type = context_cfg.normalizer.stats_head_type + cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}_{context_module_type or ''}_{stats_head_type or ''}" + cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8] + if self.run_dir is not None: + cache_dir = self.run_dir / "cache" / "normalized_data" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / f"normalized_{cache_hash}.pkl" + cache_dir = Path(ROOT_DIR) / "cache" / "normalized_data" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / f"normalized_{cache_hash}.pkl" + + def _save_rarity_cache(self, cache_path: str) -> None: + """Save rarity features to cache.""" + import pickle + rarity_data = { + 'is_frequency_rare': self.data.get('is_frequency_rare'), + 'is_pattern_rare': self.data.get('is_pattern_rare'), + 'is_rare': self.data.get('is_rare'), + 'cluster': self.data.get('cluster') + } + with open(cache_path, 'wb') as f: + pickle.dump(rarity_data, f) + print(f"Saved rarity cache to {cache_path}") + + def _load_rarity_cache(self, cache_path: str) -> bool: + """Load rarity features from cache.""" + import pickle + if not os.path.exists(cache_path): + return False + try: + with open(cache_path, 'rb') as f: + rarity_data = pickle.load(f) + for col, data in rarity_data.items(): + if data is not None: + self.data[col] = data + return True + except Exception as e: + print(f"Failed to load rarity cache: {e}") + return False + + def _is_ddp_subprocess(self) -> bool: + """Detect if we're running in a DDP subprocess.""" + import os + # Check for DDP-related environment variables + ddp_indicators = [ + 'LOCAL_RANK' in os.environ, + 'WORLD_SIZE' in os.environ, + 'RANK' in os.environ, + os.environ.get('CUDA_VISIBLE_DEVICES', '').count(',') > 0, # Multiple GPUs + ] + return any(ddp_indicators) + def _init_normalizer(self) -> None: """ Initialize or load a cached Normalizer for this dataset. @@ -413,23 +676,43 @@ def _init_normalizer(self) -> None: On first run, trains a new Normalizer and writes a single state dict to cache. On subsequent runs, loads that file. If loading fails, deletes the corrupted cache and retrains. """ - normalizer_dir = ( - Path.home() / ".cache" / "cents" / "checkpoints" / self.name / "normalizer" - ) + if self.run_dir is not None: + normalizer_dir = self.run_dir / "normalizer" + else: + normalizer_dir = ( + Path.home() / ".cache" / "cents" / "checkpoints" / self.name / "normalizer" + ) normalizer_dir.mkdir(parents=True, exist_ok=True) + + ncfg = get_normalizer_training_config() + if hasattr(self.cfg, "normalizer_use_global_stats_preprocessing"): + ncfg = OmegaConf.merge( + ncfg, + OmegaConf.create({"use_global_stats_preprocessing": self.cfg.normalizer_use_global_stats_preprocessing}), + ) + use_global = ncfg.get("use_global_stats_preprocessing", True) + cache_path = normalizer_dir / _ckpt_name( - self.name, "normalizer", self.time_series_dims + self.name, + "normalizer", + self.time_series_dims, + static_module_type=self.static_module_type, + stats_head_type=self.stats_head_type, + dynamic_module_type=self.dynamic_module_type, + use_global_stats_preprocessing=use_global, ) - ncfg = get_normalizer_training_config() + print(f"[Cents] cache_path: {cache_path}") + self._normalizer = Normalizer( dataset_cfg=self.cfg, normalizer_training_cfg=ncfg, dataset=self, + context_cfg=self.context_cfg, ) - # attempt to load existing state dict - if cache_path.exists(): + # attempt to load existing state dict (unless force_retrain_normalizer is True) + if cache_path.exists() and not self.force_retrain_normalizer: try: state = torch.load(cache_path, map_location="cpu") sd = state.get("state_dict", state) @@ -442,9 +725,12 @@ def _init_normalizer(self) -> None: cache_path.unlink() except OSError: pass + elif self.force_retrain_normalizer and cache_path.exists(): + print(f"[Cents] Force retrain enabled, ignoring cached normalizer at {cache_path}") # train and cache a single state dict print("[Cents] Training normalizer…") + print(f"[Cents] devices: {ncfg.devices}") trainer = pl.Trainer( max_epochs=ncfg.n_epochs, accelerator=ncfg.accelerator, diff --git a/cents/datasets/utils.py b/cents/datasets/utils.py index 96a9183..bf615cb 100644 --- a/cents/datasets/utils.py +++ b/cents/datasets/utils.py @@ -18,7 +18,7 @@ def split_timeseries( if n_dim != len(time_series_columns): raise ValueError("shape mismatch") for i, col in enumerate(time_series_columns): - df[col] = df["timeseries"].apply(lambda x: x[:, i]) + df[col] = df["timeseries"].apply(lambda x: x[:, i].tolist()) return df.drop(columns=["timeseries"]) @@ -108,17 +108,23 @@ def split_dataset(dataset: Dataset, val_split: float = 0.1) -> Tuple[Dataset, Da def encode_context_variables( - data: pd.DataFrame, columns_to_encode: List[str], bins: int + data: pd.DataFrame, columns_to_encode: List[str], bins: int, + numeric_cols: List[str] = None, continuous_vars: List[str] = None, + time_series_cols: List[str] = None, + categorical_time_series: Dict[str, int] = None, ) -> Tuple[pd.DataFrame, Dict[str, Dict[int, Any]]]: """ Encodes specified columns in the DataFrame either by binning numeric columns or by converting categorical/string columns to integer codes. For 'weekday' and 'month' columns, encoding follows chronological order. + Continuous variables are skipped and kept as-is. Args: data (pd.DataFrame): The input DataFrame containing the data. columns_to_encode (List[str]): List of column names to encode. bins (int): Number of bins for numeric columns. + numeric_cols (List[str], optional): Columns to treat as numeric for binning. + continuous_vars (List[str], optional): Columns to keep as continuous (skip encoding). Returns: Tuple[pd.DataFrame, Dict[str, Dict[int, Any]]]: @@ -127,6 +133,7 @@ def encode_context_variables( """ encoded_data = data.copy() mapping: Dict[str, Dict[int, Any]] = {} + continuous_vars = continuous_vars or [] # Define the chronological order for weekdays and months weekdays_order = [ @@ -154,9 +161,25 @@ def encode_context_variables( ] for col in columns_to_encode: - if pd.api.types.is_numeric_dtype(encoded_data[col]): + # Skip continuous variables - they should remain as float values + if col in categorical_time_series: + encoded_data[col], mapping[col] = encode_list_column(encoded_data[col]) + elif col in time_series_cols or col in continuous_vars: + continue + elif numeric_cols and col in numeric_cols: + print("ENCODING NUMERIC COL", col) # Numeric column: Perform binning - binned = pd.cut(encoded_data[col], bins=bins, include_lowest=True) + # Handle NaN values by filling with median before binning + if encoded_data[col].isna().all(): + # Entire column is NaN - fill with 0 and create single bin + print(f" Warning: {col} is entirely NaN, filling with 0") + encoded_data[col] = 0 + elif encoded_data[col].isna().any(): + print(f" Warning: {col} has {encoded_data[col].isna().sum()} NaN values, filling with median") + median_val = encoded_data[col].median() + encoded_data[col] = encoded_data[col].fillna(median_val if pd.notna(median_val) else 0) + + binned = pd.cut(encoded_data[col], bins=bins, include_lowest=True, duplicates='drop') encoded_data[col] = binned.cat.codes # Assign integer codes starting from 0 bin_intervals = binned.cat.categories # Create the mapping from integer code to bin interval @@ -217,19 +240,53 @@ def convert_generated_data_to_df( n_samples = data_np.shape[0] - if decode: - if mapping is None: - raise ValueError("Mapping must be provided when decode=True.") - cond_vars = { - var: mapping[var][code.item()] for var, code in context_vars.items() - } - else: - cond_vars = {var: int(code.item()) for var, code in context_vars.items()} + def _get_code_at(code: Any, i: int) -> Any: + if isinstance(code, torch.Tensor): + if code.dim() == 2 and code.shape[0] == n_samples: + # Dynamic (time-series) context: return as numpy array + return code[i].cpu().numpy() + if code.dim() == 1 and code.shape[0] == n_samples: + return code[i].item() + return code.item() if isinstance(code, torch.Tensor) else code records = [] for i in range(n_samples): - record = cond_vars.copy() + record = {} + for var, code in context_vars.items(): + v = _get_code_at(code, i) + if decode: + if mapping is None: + raise ValueError("Mapping must be provided when decode=True.") + record[var] = mapping[var][v] + elif isinstance(v, np.ndarray): + record[var] = v.tolist() + else: + record[var] = v if isinstance(v, float) else int(v) record["timeseries"] = data_np[i] records.append(record) return pd.DataFrame.from_records(records) + +def encode_list_column(series: pd.Series): + count = 0 + for x in series: + for t in x: + if pd.isna(t): + count += 1 + vocab = sorted({t for x in series for t in x}) + tok2id = {t: i for i, t in enumerate(vocab)} + encoded = series.apply(lambda x: [tok2id[t] for t in x]) + encoded = encoded.map(tuple) + mapping = dict(enumerate(vocab)) # id -> token + return encoded, mapping + +def is_all_nan(arr): + return pd.isna(arr).all() + +def is_any_nan(arr): + return pd.isna(arr).any() + +def fill_with_row_mean(lst): + s = pd.Series(lst, dtype=float) + m = s.mean(skipna=True) + return s.fillna(m).tolist() diff --git a/cents/datasets/walmart.py b/cents/datasets/walmart.py new file mode 100644 index 0000000..99fda3e --- /dev/null +++ b/cents/datasets/walmart.py @@ -0,0 +1,283 @@ +import logging +import os +import warnings +from typing import List, Optional + +import numpy as np +import pandas as pd +from omegaconf import DictConfig +from cents.utils.config_loader import load_yaml, apply_overrides + +from cents.datasets.timeseries_dataset import TimeSeriesDataset +from cents.datasets.utils import is_all_nan, is_any_nan, fill_with_row_mean + +warnings.filterwarnings("ignore", category=pd.errors.SettingWithCopyWarning) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +_MONTHS = [ + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December", +] + + +class WalmartDataset(TimeSeriesDataset): + """ + Dataset class for Walmart M5 daily unit sales time series. + + Data: https://www.kaggle.com/competitions/m5-forecasting-accuracy + + Each sample is the first 28 days of a calendar month for a single + item-store pair (days 29–31 are discarded), filtered to high-velocity + items (top 20% by mean daily sales). The 28-day window aligns exactly + with four weeks, eliminating month-length drift and making month a clean + static context variable. Dynamic context includes sell price, SNAP + eligibility, and a binary calendar-event indicator. + """ + + def __init__( + self, + cfg: Optional[DictConfig] = None, + overrides: Optional[List[str]] = None, + force_retrain_normalizer: bool = False, + run_dir: Optional[str] = None, + ): + if cfg is None: + cfg = load_yaml(os.path.join(ROOT_DIR, "config", "dataset", "walmart.yaml")) + if overrides: + cfg = apply_overrides(cfg, overrides) + + self.cfg = cfg + self.name = cfg.name + self.normalize = cfg.normalize + self.target_time_series_columns = list(cfg.time_series_columns) + self.time_series_dims = cfg.time_series_dims + self.geography = cfg.get("geography", None) + + self.context_time_series_columns = { + k: v[1] for k, v in cfg.context_vars.items() if v[0] == "time_series" + } + self.context_series_names = list(self.context_time_series_columns.keys()) + # No categorical time series for this dataset (all dynamic vars are continuous/binary) + self.categorical_time_series = { + k: v[1] for k, v in cfg.context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + + self._load_data() + + ts_cols: List[str] = self.target_time_series_columns[: self.time_series_dims] + + super().__init__( + data=self.data, + time_series_column_names=ts_cols, + context_var_column_names=list(cfg.context_vars.keys()), + seq_len=cfg.seq_len, + normalize=cfg.normalize, + scale=cfg.scale, + skip_heavy_processing=cfg.get("skip_heavy_processing", False), + size=cfg.get("max_samples", None), + categorical_time_series=self.categorical_time_series, + force_retrain_normalizer=force_retrain_normalizer, + run_dir=run_dir, + ) + + def _load_data(self) -> None: + """ + Loads and joins sales, calendar, and price CSVs; filters to + high-velocity item-store pairs; constructs per-state SNAP flags + and binary event indicators. + """ + module_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.normpath(os.path.join(module_dir, "..", self.cfg.path)) + + # --- Sales: wide → long --- + sales_path = os.path.join(path, "sales_train_evaluation.csv") + if not os.path.exists(sales_path): + raise FileNotFoundError(f"Sales file not found: {sales_path}") + + sales = pd.read_csv(sales_path) + meta_cols = ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"] + day_cols = [c for c in sales.columns if c.startswith("d_")] + + # --- High-velocity filter computed from wide format (before the expensive melt) --- + day_means = sales[day_cols].mean(axis=1) + vel_threshold = day_means.quantile(0.80) + keep_ids = set(sales.loc[day_means >= vel_threshold, "id"].values) + + max_samples = self.cfg.get("max_samples", None) + if max_samples is not None and len(keep_ids) > 0: + est_windows_per_id = 60 + n_ids = min(len(keep_ids), int(np.ceil(max_samples * 2.0 / est_windows_per_id))) + keep_ids = set(np.random.choice(sorted(keep_ids), size=n_ids, replace=False)) + logging.info( + "WalmartDataset: subsampling %d IDs (est. ~%d windows) for max_samples=%d", + n_ids, n_ids * est_windows_per_id, max_samples, + ) + + sales = sales[sales["id"].isin(keep_ids)] + + sales_long = sales[meta_cols + day_cols].melt( + id_vars=meta_cols, var_name="d", value_name="sales" + ) + del sales # free memory before merges + + # --- Calendar --- + calendar = pd.read_csv( + os.path.join(path, "calendar.csv"), + usecols=["d", "date", "wm_yr_wk", "weekday", "month", "year", + "event_name_1", "snap_CA", "snap_TX", "snap_WI"], + ) + calendar["date"] = pd.to_datetime(calendar["date"]) + sales_long = sales_long.merge(calendar, on="d", how="left") + del calendar + + # --- Sell prices (weekly → broadcast to daily via merge, then ffill) --- + prices = pd.read_csv(os.path.join(path, "sell_prices.csv")) + sales_long = sales_long.merge(prices, on=["store_id", "item_id", "wm_yr_wk"], how="left") + del prices + + sales_long = sales_long.sort_values(["id", "date"]) + sales_long["sell_price"] = ( + sales_long.groupby("id")["sell_price"] + .transform(lambda x: x.ffill().bfill()) + ) + + sales_long["snap"] = np.where( + sales_long["state_id"] == "CA", sales_long["snap_CA"], + np.where( + sales_long["state_id"] == "TX", sales_long["snap_TX"], + sales_long["snap_WI"], + ), + ).astype(np.int8) + + sales_long["event_binary"] = sales_long["event_name_1"].notna().astype(np.int8) + + sales_long["month"] = sales_long["month"].map(lambda x: _MONTHS[int(x) - 1]) + + _WEEKDAY_MAP = { + "Monday": 0, "Tuesday": 1, "Wednesday": 2, "Thursday": 3, + "Friday": 4, "Saturday": 5, "Sunday": 6, + } + sales_long["weekday"] = sales_long["weekday"].map(_WEEKDAY_MAP).astype(np.int8) + + sales_long.drop( + columns=["snap_CA", "snap_TX", "snap_WI", "event_name_1", + "wm_yr_wk", "d", "item_id"], + errors="ignore", + inplace=True, + ) + + self.data = sales_long.reset_index(drop=True) + + def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame: + """ + Groups data into calendar-month windows using only the first 28 days + of each month (days 29–31 are discarded). Each (id, year, month) + group therefore has exactly 28 days, giving fixed-length sequences + with no end-of-month drift and month as a clean context variable. + Continuous dynamic context channels are z-score normalised. + """ + data = data.copy() + + ts_ctx = list(self.context_series_names) # ["sell_price", "snap", "event_binary"] + tgt_ts = list(self.target_time_series_columns) # ["sales"] + all_ts = ts_ctx + tgt_ts + + # Keep only the first 28 days of each calendar month + data = data.sort_values(["id", "date"]).reset_index(drop=True) + data = data[data["date"].dt.day <= 28].reset_index(drop=True) + + # Group by calendar month; year and month are already columns + group_keys = ["id", "year", "month"] + static_cols = ["cat_id", "dept_id", "store_id", "state_id", "year", "month"] + + agg_dict = {c: list for c in all_ts} + agg_dict.update({c: "first" for c in static_cols if c not in group_keys}) + + grouped = ( + data.groupby(group_keys, as_index=False, sort=False) + .agg(agg_dict) + ) + + for c in all_ts: + grouped[c] = grouped[c].map(np.asarray) + + # Keep only complete 28-day windows + len_col = tgt_ts[0] + grouped = grouped[grouped[len_col].apply(len) == self.cfg.seq_len].reset_index(drop=True) + + grouped = self._handle_missing_data(grouped) + + # Z-score normalise continuous dynamic context; pass binaries through + binary_channels = {"snap", "event_binary"} + clip_bound = 5.0 + eps = 1e-8 + ctx_stats = {} + + for c in ts_ctx: + X = np.stack(grouped[c].values).astype(np.float32) + + if c in binary_channels: + grouped[c] = list(X) + continue + + mu = float(X.mean()) + sd = float(X.std()) + if sd < 1e-6: + sd = 1.0 + ctx_stats[c] = (mu, sd) + + Xn = np.clip((X - mu) / (sd + eps), -clip_bound, clip_bound).astype(np.float32) + grouped[c] = list(Xn) + + self.context_ts_stats_ = ctx_stats + + # Convert arrays → tuples (hashable, required by base class) + for c in all_ts: + grouped[c] = grouped[c].map(tuple) + + return grouped + + def _handle_missing_data(self, data: pd.DataFrame) -> pd.DataFrame: + numeric_series = [ + c for c in self.context_series_names if c not in self.categorical_time_series + ] + + # Drop windows where any numeric context series is entirely NaN + if numeric_series: + mask = data[numeric_series].applymap(is_all_nan).any(axis=1) + data = data[~mask] + + # Fill isolated NaNs in numeric context series with within-window mean + for col in numeric_series: + data[col] = data[col].apply(fill_with_row_mean) + + # Drop windows with NaN in any categorical time series (none expected here) + cat_cols = list(self.categorical_time_series.keys()) + if cat_cols: + mask = data[cat_cols].applymap(is_any_nan).any(axis=1) + data = data[~mask] + + # Drop windows with any NaN in target + for tcol in self.target_time_series_columns: + if tcol in data.columns: + data = data.loc[ + data[tcol].apply( + lambda x: not np.isnan(np.asarray(x, dtype=float)).any() + ) + ] + + # Drop windows where any target series sums to zero (all-zero sequences) + for tcol in self.target_time_series_columns: + if tcol in data.columns: + data = data.loc[ + data[tcol].apply(lambda x: np.asarray(x, dtype=float).sum() > 0) + ] + + # Drop near-constant windows (std < 0.01 → degenerate for diffusion) + def _low_std(row, cols, thresh=0.01): + return any(np.asarray(row[c], dtype=np.float32).std() < thresh for c in cols) + + mask = data.apply(lambda row: _low_std(row, self.target_time_series_columns), axis=1) + data = data[~mask] + return data diff --git a/cents/eval/eval.py b/cents/eval/eval.py index d3c1a72..5dab1e9 100644 --- a/cents/eval/eval.py +++ b/cents/eval/eval.py @@ -17,7 +17,11 @@ from cents.eval.discriminative_score import discriminative_score_metrics from cents.eval.eval_metrics import ( Context_FID, + calculate_banded_mse, calculate_mmd, + compute_cfs, + compute_context_recovery_score, + compute_gcp, compute_mig, compute_sap, dynamic_time_warping_dist, @@ -106,11 +110,16 @@ def evaluate_model( if data_generator is not None: if data_generator.model is not None: model = data_generator.model + print(f"[Cents] Using pre-trained model from DataGenerator") if data_generator.normalizer is not None: dataset._normalizer = data_generator.normalizer - print("[CENTS] Using pre-trained normalizer from DataGenerator") - elif not model: - model = self.get_trained_model(dataset) + print("[Cents] Using pre-trained normalizer from DataGenerator") + if not getattr(dataset.cfg, "normalize", True): + dataset.apply_pretrained_normalizer() + print("[Cents] Normalized dataset with pretrained normalizer") + else: + if not model: + model = self.get_trained_model(dataset) model.to(self.device) model.eval() @@ -182,19 +191,29 @@ def compute_quality_metrics( syn_data: np.ndarray, real_data_frame: pd.DataFrame, mask: Optional[np.ndarray] = None, + target: Optional[Dict] = None, + log_prefix: str = "", + context_arrays: Optional[Dict[str, np.ndarray]] = None, + ts_column_names: Optional[list] = None, + static_context_arrays: Optional[Dict[str, np.ndarray]] = None, + continuous_vars: Optional[list] = None, ) -> Dict: """ - Compute evaluation metrics and store them in current_results. + Compute evaluation metrics and store them in current_results (or in target if provided). Args: real_data (np.ndarray): Real data array (shape: [N, seq_len, dims]) syn_data (np.ndarray): Synthetic data array (shape: [N, seq_len, dims]) real_data_frame (pd.DataFrame): Real data subset (inverse-transformed) mask (Optional[np.ndarray]): Boolean array indicating which rows are "rare" + target (Optional[Dict]): If set, write metrics into this dict instead of current_results (for normalized_domain). + log_prefix (str): Prefix for log messages (e.g. "[normalized]"). + context_arrays (Optional[Dict[str, np.ndarray]]): Named dynamic context arrays (N, T) each. + ts_column_names (Optional[list]): Names of the time-series output dimensions, aligned with real_data axis-2. """ - logger.info(f"[Cents] --- Starting Full-Subset Metrics ---") + logger.info(f"[Cents] --- {log_prefix}Full-Subset Metrics ---") - metrics = {} + metrics = target if target is not None else {} # Compute and store metrics dtw_mean, dtw_std = dynamic_time_warping_dist(real_data, syn_data) @@ -205,6 +224,10 @@ def compute_quality_metrics( metrics["MMD"] = {"mean": mmd_mean, "std": mmd_std} logger.info(f"[Cents] MMD completed") + banded_mse = calculate_banded_mse(real_data, syn_data) + metrics["Banded_MSE"] = banded_mse + logger.info(f"[Cents] Banded MSE completed") + fid_score = Context_FID(real_data, syn_data) metrics["Context_FID"] = fid_score logger.info(f"[Cents] Context-FID completed") @@ -213,10 +236,27 @@ def compute_quality_metrics( metrics["Disc_Score"] = discr_score logger.info(f"[Cents] Discr Score completed") - pred_score = predictive_score_metrics(real_data, syn_data) + trtr = self.cfg.evaluator.get("pred_score_trtr", False) + pred_score = predictive_score_metrics(real_data, syn_data, trtr=trtr) metrics["Pred_Score"] = pred_score logger.info(f"[Cents] Pred Score completed") + # CFS / GCP — only when dynamic context or multivariate signal is present + cf_cfg = self.cfg.evaluator.get("eval_context_faithfulness", None) + if cf_cfg and cf_cfg.get("enabled", False) and context_arrays is not None: + cf_metrics = self._compute_context_faithfulness_metrics( + real_data, syn_data, context_arrays, ts_column_names, cf_cfg + ) + metrics["context_faithfulness"] = cf_metrics + + # Context Recovery Score — tests whether static context is encoded in generated outputs + if self.cfg.evaluator.get("eval_context_recovery", False) and static_context_arrays: + crs_score, crs_per_var = compute_context_recovery_score( + real_data, syn_data, static_context_arrays, continuous_vars=continuous_vars + ) + metrics["context_recovery"] = {"overall": crs_score, "per_var": crs_per_var} + logger.info(f"[Cents] Context Recovery Score completed: {crs_score:.4f}") + if mask is not None: logger.info("[Cents] Starting Rare-Subset Metrics") rare_metrics = {} @@ -224,34 +264,151 @@ def compute_quality_metrics( rare_syn_data = syn_data[mask] rare_real_df = real_data_frame[mask].reset_index(drop=True) - dtw_mean_r, dtw_std_r = dynamic_time_warping_dist( - rare_real_data, rare_syn_data - ) - rare_metrics["DTW"] = {"mean": dtw_mean_r, "std": dtw_std_r} - logger.info(f"[Cents] DTW completed") - - mmd_mean_r, mmd_std_r = calculate_mmd(rare_real_data, rare_syn_data) - rare_metrics["MMD"] = {"mean": mmd_mean_r, "std": mmd_std_r} - logger.info(f"[Cents] MMD completed") + # Check if rare subset has any valid samples + if len(rare_real_data) == 0: + logger.warning("[Cents] Rare subset is empty - skipping rare subset metrics") + rare_metrics = { + "DTW": {"mean": float('nan'), "std": float('nan')}, + "MMD": {"mean": float('nan'), "std": float('nan')}, + "Context_FID": float('nan'), + "Disc_Score": float('nan'), + "Pred_Score": float('nan') + } + else: + logger.info(f"[Cents] Rare subset contains {len(rare_real_data)} samples") + + dtw_mean_r, dtw_std_r = dynamic_time_warping_dist( + rare_real_data, rare_syn_data + ) + rare_metrics["DTW"] = {"mean": dtw_mean_r, "std": dtw_std_r} + logger.info(f"[Cents] DTW completed") + + mmd_mean_r, mmd_std_r = calculate_mmd(rare_real_data, rare_syn_data) + rare_metrics["MMD"] = {"mean": mmd_mean_r, "std": mmd_std_r} + logger.info(f"[Cents] MMD completed") + + fid_score_r = Context_FID(rare_real_data, rare_syn_data) + rare_metrics["Context_FID"] = fid_score_r + logger.info(f"[Cents] Context-FID completed") + + discr_score_r, _, _ = discriminative_score_metrics( + rare_real_data, rare_syn_data + ) + rare_metrics["Disc_Score"] = discr_score_r + logger.info(f"[Cents] Discr Score completed") + + pred_score_r = predictive_score_metrics(rare_real_data, rare_syn_data) + rare_metrics["Pred_Score"] = pred_score_r + logger.info(f"[Cents] Pred Score completed") - fid_score_r = Context_FID(rare_real_data, rare_syn_data) - rare_metrics["Context_FID"] = fid_score_r - logger.info(f"[Cents] Context-FID completed") + logger.info("[Cents] Done computing Rare-Subset Metrics.") + metrics["rare_subset"] = rare_metrics - discr_score_r, _, _ = discriminative_score_metrics( - rare_real_data, rare_syn_data - ) - rare_metrics["Disc_Score"] = discr_score_r - logger.info(f"[Cents] Discr Score completed") + if target is None: + self.current_results["metrics"] = metrics + return metrics - pred_score_r = predictive_score_metrics(rare_real_data, rare_syn_data) - rare_metrics["Pred_Score"] = pred_score_r - logger.info(f"[Cents] Pred Score completed") + def _compute_context_faithfulness_metrics( + self, + real_data: np.ndarray, + syn_data: np.ndarray, + context_arrays: Dict[str, np.ndarray], + ts_column_names: Optional[list], + cf_cfg, + ) -> Dict: + """ + Compute CFS and GCP for configured (x_dim, c_dim) pairs. - logger.info("[Cents] Done computing Rare-Subset Metrics.") - metrics["rare_subset"] = rare_metrics + Each pair specifies one x dimension (by name from ts_column_names, or by index) + and one c dimension (by name from context_arrays, or by name from ts_column_names + for within-signal multivariate pairs). - self.current_results["metrics"] = metrics + CFS is only computed when c comes from context_arrays (shared context). + GCP is computed for all pairs. + """ + results: Dict = {} + max_lag: int = int(cf_cfg.get("gcp_max_lag", 5)) + pairs_cfg = cf_cfg.get("pairs", []) + + if not pairs_cfg: + logger.warning("[Cents] eval_context_faithfulness.pairs is empty — skipping CFS/GCP.") + return results + + ts_names = list(ts_column_names) if ts_column_names else [] + + def _resolve_x(name) -> Optional[Tuple[np.ndarray, np.ndarray]]: + """Return (x_real_slice, x_synth_slice) of shape (N, T, 1).""" + if isinstance(name, int): + idx = name + elif name in ts_names: + idx = ts_names.index(name) + else: + logger.warning(f"[Cents] CFS/GCP: x dim '{name}' not found in ts_column_names {ts_names}; skipping.") + return None + return real_data[:, :, idx : idx + 1], syn_data[:, :, idx : idx + 1] + + def _resolve_c(name) -> Optional[Tuple[np.ndarray, np.ndarray, bool]]: + """Return (c_real, c_synth, is_shared) of shape (N, T, 1). is_shared=True if dynamic context.""" + if name in context_arrays: + arr = context_arrays[name] # (N, T) + if arr.ndim == 1: + arr = arr[:, np.newaxis] # edge-case: (N,) → (N, 1) + if arr.ndim == 2: + arr = arr[:, :, np.newaxis] # (N, T, 1) + return arr, arr, True # same array for real and synth + elif name in ts_names: + idx = ts_names.index(name) + c_r = real_data[:, :, idx : idx + 1] + c_s = syn_data[:, :, idx : idx + 1] + return c_r, c_s, False + else: + logger.warning(f"[Cents] CFS/GCP: c dim '{name}' not found in context_arrays or ts_column_names; skipping.") + return None + + for pair in pairs_cfg: + x_name = pair.get("x") + c_name = pair.get("c") + if x_name is None or c_name is None: + logger.warning(f"[Cents] CFS/GCP: pair {pair} missing 'x' or 'c' key; skipping.") + continue + + x_resolved = _resolve_x(x_name) + c_resolved = _resolve_c(c_name) + if x_resolved is None or c_resolved is None: + continue + + x_r, x_s = x_resolved + c_r, c_s, is_shared = c_resolved + pair_key = f"{x_name}_vs_{c_name}" + + pair_result: Dict = {} + + # CFS — only meaningful when context is shared (dynamic context) + if is_shared: + try: + cfs = compute_cfs(x_r, x_s, c_r) + pair_result["CFS"] = cfs + logger.info(f"[Cents] CFS completed for {pair_key}: {cfs:.4f}") + except Exception as e: + logger.warning(f"[Cents] CFS failed for {pair_key}: {e}") + pair_result["CFS"] = float("nan") + else: + logger.info(f"[Cents] CFS skipped for {pair_key} (c is a generated dim, not shared context).") + + # GCP — works for both shared and generated-dim context + try: + gcp, diag = compute_gcp(x_r, x_s, c_r, c_s, max_lag=max_lag) + pair_result["GCP"] = gcp + pair_result["GCP_diagnostics"] = diag + logger.info(f"[Cents] GCP completed for {pair_key}: {gcp:.4f}") + except Exception as e: + logger.warning(f"[Cents] GCP failed for {pair_key}: {e}") + pair_result["GCP"] = float("nan") + pair_result["GCP_diagnostics"] = {} + + results[pair_key] = pair_result + + return results def compute_disentanglement_metrics( self, @@ -333,14 +490,25 @@ def evaluate_subset( """ dataset.data = dataset.get_combined_rarity() real_data_subset = dataset.data.iloc[indices].reset_index(drop=True) - context_vars = { - name: torch.tensor( - real_data_subset[name].values, dtype=torch.long, device=self.device - ) - for name in dataset.context_vars - } - - generated_ts = model.generate(context_vars).cpu().numpy() + continuous_vars = getattr(dataset, "continuous_vars", []) + static_context_vars = {} + for name in dataset.static_context_vars: + vals = real_data_subset[name].values + dtype = torch.float32 if name in continuous_vars else torch.long + static_context_vars[name] = torch.tensor(vals, dtype=dtype, device=self.device) + dynamic_context_vars = {} + categorical_ts = getattr(dataset, "categorical_time_series", {}) + for name in dataset.dynamic_context_vars: + vals = real_data_subset[name].values + # Dynamic module expects tensors (training path uses torch.from_numpy in dataset __getitem__) + if len(vals) and hasattr(vals[0], "__len__") and not isinstance(vals[0], (str, bytes)): + arr = np.stack([np.asarray(v, dtype=np.float32 if name not in categorical_ts else np.int64) for v in vals]) + else: + arr = np.asarray(vals, dtype=np.float32 if name not in categorical_ts else np.int64) + dtype = torch.long if name in categorical_ts else torch.float32 + dynamic_context_vars[name] = torch.tensor(arr, dtype=dtype, device=self.device) + + generated_ts = model.generate(static_context_vars, dynamic_context_vars).cpu().numpy() if generated_ts.ndim == 2: generated_ts = generated_ts.reshape( generated_ts.shape[0], -1, generated_ts.shape[1] @@ -349,12 +517,48 @@ def evaluate_subset( syn_data_subset = real_data_subset.copy() syn_data_subset["timeseries"] = list(generated_ts) - real_data_inv = dataset.inverse_transform(real_data_subset) - syn_data_inv = dataset.inverse_transform(syn_data_subset) + # When normalize=False but a pretrained normalizer was applied via apply_pretrained_normalizer, + # dataset.inverse_transform() is a no-op (it checks self.normalize). Do the inverse manually. + normalizer = getattr(dataset, "_normalizer", None) + if not getattr(dataset, "normalize", True) and normalizer is not None: + def _inv(df): + split = dataset.split_timeseries(df.copy()) + split = normalizer.inverse_transform(split) + return dataset.merge_timeseries_columns(split) + real_data_inv = _inv(real_data_subset) + syn_data_inv = _inv(syn_data_subset) + else: + real_data_inv = dataset.inverse_transform(real_data_subset) + syn_data_inv = dataset.inverse_transform(syn_data_subset) real_data_array = np.stack(real_data_inv["timeseries"]) syn_data_array = np.stack(syn_data_inv["timeseries"]) + # Extract dynamic context as (N, T) numpy arrays for CFS/GCP + context_np: Dict[str, np.ndarray] = {} + for name, tensor in dynamic_context_vars.items(): + arr = tensor.cpu().numpy() # (N,) static scalar or (N, T) time series + if arr.ndim == 2: # (N, T) — keep time-series context only + context_np[name] = arr + + ts_col_names = list(getattr(dataset, "time_series_column_names", [])) + + # Decide whether CFS/GCP should run: multivariate signal OR dynamic context present + has_dynamic_context = len(context_np) > 0 + has_multivariate_signal = real_data_array.shape[-1] > 1 + cf_cfg = self.cfg.evaluator.get("eval_context_faithfulness", None) + run_cf = ( + cf_cfg is not None + and cf_cfg.get("enabled", False) + and (has_dynamic_context or has_multivariate_signal) + ) + + # Build static context numpy arrays for Context Recovery Score + static_context_np: Dict[str, np.ndarray] = { + name: tensor.cpu().numpy() for name, tensor in static_context_vars.items() + } + dataset_continuous_vars: list = getattr(dataset, "continuous_vars", []) + if self.cfg.evaluator.eval_metrics: rare_mask = None @@ -364,9 +568,37 @@ def evaluate_subset( ): rare_mask = real_data_subset["is_rare"].values + # Metrics in raw (un-normalized) domain self.compute_quality_metrics( - real_data_array, syn_data_array, real_data_inv, rare_mask + real_data_array, syn_data_array, real_data_inv, rare_mask, + log_prefix="[raw] ", + context_arrays=context_np if run_cf else None, + ts_column_names=ts_col_names if run_cf else None, + static_context_arrays=static_context_np if static_context_np else None, + continuous_vars=dataset_continuous_vars, ) + # Metrics in normalized (z) domain for cross-domain comparability. + # Fires whenever a normalizer is available — whether data was pre-normalized by dataset + # init (normalize=True) or normalized in-place via apply_pretrained_normalizer (normalize=False). + if ( + getattr(dataset, "_normalizer", None) is not None + and "timeseries" in real_data_subset.columns + ): + real_data_norm = np.stack(real_data_subset["timeseries"].values) + syn_data_norm = generated_ts + logger.info("[Cents] Computing metrics in normalized domain (z-space) for cross-domain comparison.") + normalized_metrics = {} + self.compute_quality_metrics( + real_data_norm, syn_data_norm, real_data_inv, rare_mask, + target=normalized_metrics, + log_prefix="[normalized] ", + context_arrays=context_np if run_cf else None, + ts_column_names=ts_col_names if run_cf else None, + static_context_arrays=static_context_np if static_context_np else None, + continuous_vars=dataset_continuous_vars, + ) + self.current_results["metrics"]["normalized_domain"] = normalized_metrics + if self.cfg.evaluator.eval_disentanglement: - self.compute_disentanglement_metrics(context_vars, model) + self.compute_disentanglement_metrics(static_context_vars, model) diff --git a/cents/eval/eval_metrics.py b/cents/eval/eval_metrics.py index c3d18a5..792d4c0 100644 --- a/cents/eval/eval_metrics.py +++ b/cents/eval/eval_metrics.py @@ -1,13 +1,36 @@ +import warnings from functools import partial -from typing import Dict, Tuple +from itertools import product +from typing import Any, Dict, List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy +import torch +import torch.nn as nn +import torch.optim as optim from dtaidistance import dtw -from sklearn.linear_model import Ridge -from sklearn.metrics import mutual_info_score, r2_score +from scipy.stats import f as f_dist +from scipy.stats import wasserstein_distance +from sklearn.linear_model import LogisticRegression, Ridge +from sklearn.metrics import mutual_info_score, r2_score, roc_auc_score +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from statsmodels.tsa.tsatools import lagmat + +_cfs_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class _CFSDiscriminator(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int): + super().__init__() + self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True) + self.fc = nn.Linear(hidden_dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, h_n = self.gru(x) + return self.fc(h_n.squeeze(0)) from cents.eval.eval_utils import ( gaussian_kernel_matrix, @@ -46,6 +69,56 @@ def dynamic_time_warping_dist(X: np.ndarray, Y: np.ndarray) -> Tuple[float, floa return np.mean(dtw_distances), np.std(dtw_distances) +def calculate_banded_mse( + real_data: np.ndarray, + syn_data: np.ndarray, + n_bands: int = 5, +) -> Dict[str, Any]: + """ + Compute MSE between real and synthetic time series broken down by amplitude band. + + Band boundaries are quantile-derived from the real data, so each band contains + an equal number of real observations. Per-band MSE is averaged over samples + that have at least one timestep in the band. + + Args: + real_data: Real time series (N, T, D). + syn_data: Synthetic time series (N, T, D). + n_bands: Number of quantile bands (default 5). + + Returns: + Dict with keys "band_1" … "band_N", each containing + {"mean": float, "std": float, "range": [lo, hi]}. + """ + assert real_data.shape == syn_data.shape, "real_data and syn_data must have the same shape" + N, T, D = real_data.shape + + edges = np.percentile(real_data.reshape(-1), np.linspace(0, 100, n_bands + 1)) + edges[-1] += 1e-8 # make upper bound inclusive for max value + + sq_err = (real_data - syn_data) ** 2 # (N, T, D) + results: Dict[str, Any] = {} + + for i in range(n_bands): + lo, hi = edges[i], edges[i + 1] + in_band = (real_data >= lo) & (real_data < hi) # (N, T, D) + + band_mses = [] + for n in range(N): + mask = in_band[n] # (T, D) + if mask.sum() == 0: + continue + band_mses.append(float(sq_err[n][mask].mean())) + + results[f"band_{i + 1}"] = { + "mean": float(np.mean(band_mses)) if band_mses else float("nan"), + "std": float(np.std(band_mses)) if band_mses else float("nan"), + "range": [round(float(lo), 4), round(float(hi), 4)], + } + + return results + + def calculate_period_bound_mse( real_dataframe: pd.DataFrame, synthetic_timeseries: np.ndarray ) -> Tuple[float, float]: @@ -147,12 +220,37 @@ def Context_FID(ori_data: np.ndarray, generated_data: np.ndarray) -> float: Calculate the FID score between original and generated data representations using TS2Vec embeddings. Args: - ori_data: Original time series data. - generated_data: Generated time series data. + ori_data: Original time series data (N, seq_len) or (N, seq_len, dims). + generated_data: Generated time series data, same shape convention. Returns: float: FID score between the original and generated data representations. """ + ori_data = np.asarray(ori_data, dtype=np.float32) + generated_data = np.asarray(generated_data, dtype=np.float32) + # TS2Vec expects (n_instance, n_timestamps, n_features); ensure 3D + if ori_data.ndim == 2: + ori_data = ori_data[:, :, np.newaxis] + if generated_data.ndim == 2: + generated_data = generated_data[:, :, np.newaxis] + if ori_data.ndim != 3 or generated_data.ndim != 3: + warnings.warn( + f"Context_FID: expected 2D or 3D arrays, got ori_data.ndim={ori_data.ndim}, generated_data.ndim={generated_data.ndim}; returning nan." + ) + return float("nan") + # Require at least one non–all-NaN row so TS2Vec.fit() does not infinite-loop + ori_valid = ~np.isnan(ori_data).all(axis=2).all(axis=1) + n_valid = int(ori_valid.sum()) + if n_valid == 0: + warnings.warn( + "Context_FID: ori_data has no valid (non–all-NaN) rows; returning nan." + ) + return float("nan") + # Allow single-sample (TS2Vec will get 1 batch); only reject when 0 valid + if np.isnan(ori_data).any() or np.isnan(generated_data).any(): + warnings.warn( + "Context_FID: ori_data or generated_data contain NaN; FID may be unreliable." + ) model = TS2Vec( input_dims=ori_data.shape[-1], device=0, @@ -161,14 +259,307 @@ def Context_FID(ori_data: np.ndarray, generated_data: np.ndarray) -> float: output_dims=320, max_train_length=50000, ) + + fit_log = model.fit(ori_data, verbose=False) model.fit(ori_data, verbose=False) - ori_represenation = model.encode(ori_data, encoding_window="full_series") - gen_represenation = model.encode(generated_data, encoding_window="full_series") + + ori_rep = model.encode(ori_data, encoding_window="full_series") + gen_rep = model.encode(generated_data, encoding_window="full_series") + idx = np.random.permutation(ori_data.shape[0]) - ori_represenation = ori_represenation[idx] - gen_represenation = gen_represenation[idx] - results = calculate_fid(ori_represenation, gen_represenation) - return results + ori_rep = ori_rep[idx] + gen_rep = gen_rep[idx] + + if not np.isfinite(ori_rep).all() or not np.isfinite(gen_rep).all(): + return float("nan") + + return calculate_fid(ori_rep, gen_rep) + + +def compute_cfs( + x_real: np.ndarray, + x_synth: np.ndarray, + c: np.ndarray, + iterations: int = 2000, + batch_size: int = 128, + test_ratio: float = 0.2, +) -> float: + """ + Compute Context Faithfulness Score (CFS). + + Trains a GRU to distinguish real (x, c) pairs from synthetic (x, c) pairs, + then returns 2 * |AUROC - 0.5|. + + Args: + x_real: (N, T, D_x) real time series + x_synth: (N, T, D_x) synthetic time series + c: (N, T, D_c) shared context (same for real and synthetic) + iterations: number of training steps + batch_size: training batch size + test_ratio: fraction of data held out for AUROC evaluation + + Returns: + float: CFS in [0, 1]. 0 = indistinguishable (perfect), 1 = fully separable (failed). + """ + N, T, _ = x_real.shape + + # Concatenate signal and context along feature dim: (N, T, D_x+D_c) + real_seq = np.concatenate([x_real, c], axis=-1).astype(np.float32) + synth_seq = np.concatenate([x_synth, c], axis=-1).astype(np.float32) + + # Drop samples with NaN + valid = ~(np.isnan(real_seq).any(axis=(1, 2)) | np.isnan(synth_seq).any(axis=(1, 2))) + real_seq, synth_seq = real_seq[valid], synth_seq[valid] + N = len(real_seq) + + if N < 4: + warnings.warn("compute_cfs: insufficient valid samples; returning nan.") + return float("nan") + + # Train/test split + idx = np.random.RandomState(42).permutation(N) + test_n = max(1, int(N * test_ratio)) + train_idx, test_idx = idx[test_n:], idx[:test_n] + + train_real, test_real = real_seq[train_idx], real_seq[test_idx] + train_synth, test_synth = synth_seq[train_idx], synth_seq[test_idx] + + input_dim = real_seq.shape[-1] + hidden_dim = int(input_dim * 2) + model = _CFSDiscriminator(input_dim, hidden_dim).to(_cfs_device) + optimizer = optim.Adam(model.parameters()) + criterion = nn.BCEWithLogitsLoss() + + n_train = len(train_real) + for _ in range(iterations): + idx_r = np.random.randint(0, n_train, batch_size) + idx_s = np.random.randint(0, n_train, batch_size) + + xr = torch.tensor(train_real[idx_r]).to(_cfs_device) + xs = torch.tensor(train_synth[idx_s]).to(_cfs_device) + + optimizer.zero_grad() + logits_r = model(xr) + logits_s = model(xs) + loss = criterion(logits_r, torch.ones_like(logits_r)) + \ + criterion(logits_s, torch.zeros_like(logits_s)) + loss.backward() + optimizer.step() + + with torch.no_grad(): + xr_test = torch.tensor(test_real).to(_cfs_device) + xs_test = torch.tensor(test_synth).to(_cfs_device) + prob_r = torch.sigmoid(model(xr_test)).cpu().numpy().squeeze() + prob_s = torch.sigmoid(model(xs_test)).cpu().numpy().squeeze() + + probs = np.concatenate([prob_r, prob_s]) + labels = np.concatenate([np.ones(len(prob_r)), np.zeros(len(prob_s))]) + + if len(np.unique(labels)) < 2: + warnings.warn("compute_cfs: test set has only one class; returning nan.") + return float("nan") + + auroc = float(roc_auc_score(labels, probs)) + return float(2.0 * abs(auroc - 0.5)) + + +def _build_lag_matrix(x: np.ndarray, max_lag: int) -> np.ndarray: + """Return (T-max_lag, max_lag) lag matrix with rows [x_{t-1}, ..., x_{t-L}].""" + return lagmat(x, maxlag=max_lag, trim="forward", original="ex")[max_lag:] + + +def _compute_f_stats_batch( + x_arr: np.ndarray, + c_arr: np.ndarray, + max_lag: int, + pairs: List[Tuple[int, int]], +) -> np.ndarray: + """Compute mean F-statistic per sample for the given (dx, dc) pairs.""" + N = x_arr.shape[0] + f_per_sample = [] + for i in range(N): + f_vals = [] + for dx, dc in pairs: + xi = x_arr[i, :, dx] + ci = c_arr[i, :, dc] + + if np.isnan(xi).any() or np.isnan(ci).any(): + continue + + X_x = _build_lag_matrix(xi, max_lag) # (T-max_lag, max_lag) + X_c = _build_lag_matrix(ci, max_lag) # (T-max_lag, max_lag) + y = xi[max_lag:] + + ones = np.ones((len(y), 1)) + X_r = np.hstack([ones, X_x]) # restricted: own lags only + X_u = np.hstack([ones, X_x, X_c]) # unrestricted: own + c lags + + beta_r, _, _, _ = np.linalg.lstsq(X_r, y, rcond=None) + beta_u, _, _, _ = np.linalg.lstsq(X_u, y, rcond=None) + + rss_r = float(np.sum((y - X_r @ beta_r) ** 2)) + rss_u = float(np.sum((y - X_u @ beta_u) ** 2)) + + df1 = max_lag + df2 = len(y) - 2 * max_lag - 1 + + if df2 <= 0 or rss_u < 1e-12: + continue + + F = ((rss_r - rss_u) / df1) / (rss_u / df2) + f_vals.append(F) + + if f_vals: + f_per_sample.append(float(np.mean(f_vals))) + + return np.array(f_per_sample) + + +def compute_gcp( + x_real: np.ndarray, + x_synth: np.ndarray, + c_real: np.ndarray, + c_synth: np.ndarray, + max_lag: int = 5, + alpha: float = 0.05, +) -> Tuple[float, Dict]: + """ + Compute Granger Causality Preservation (GCP). + + Measures how well synthetic data preserves the Granger-causal structure from + context c → signal x, via Wasserstein distance between F-statistic distributions. + + Args: + x_real: (N, T, D_x) real signal + x_synth: (N, T, D_x) synthetic signal + c_real: (N, T, D_c) context for real data + c_synth: (N, T, D_c) context for synthetic data + max_lag: maximum lag order (capped at T // 10 for short series) + alpha: significance threshold for diagnostic sig-rate computation + + Returns: + gcp: float >= 0. 0 = perfect preservation. Higher = more divergence. + diagnostics: dict with sig_rate_real, sig_rate_synth, sig_rate_delta + """ + N, T, D_x = x_real.shape + D_c = c_real.shape[-1] + + # Cap lag for short series + max_lag = min(max_lag, max(1, T // 10)) + + pairs = list(product(range(D_x), range(D_c))) + + F_real = _compute_f_stats_batch(x_real, c_real, max_lag, pairs) + F_synth = _compute_f_stats_batch(x_synth, c_synth, max_lag, pairs) + + if len(F_real) == 0 or len(F_synth) == 0: + warnings.warn("compute_gcp: no valid F-statistics computed; returning nan.") + return float("nan"), {} + + gcp = float(wasserstein_distance(F_real, F_synth)) + + # Diagnostics: significance rates + df2 = T - 2 * max_lag - 1 + if df2 > 0: + pvals_real = f_dist.sf(F_real, max_lag, df2) + pvals_synth = f_dist.sf(F_synth, max_lag, df2) + sig_real = float(np.mean(pvals_real < alpha)) + sig_synth = float(np.mean(pvals_synth < alpha)) + else: + sig_real = sig_synth = float("nan") + + diagnostics = { + "sig_rate_real": sig_real, + "sig_rate_synth": sig_synth, + "sig_rate_delta": float(abs(sig_real - sig_synth)) if not np.isnan(sig_real) else float("nan"), + } + + return gcp, diagnostics + + +def compute_context_recovery_score( + real_data: np.ndarray, + syn_data: np.ndarray, + context_labels: Dict[str, np.ndarray], + continuous_vars: Optional[List[str]] = None, + test_ratio: float = 0.2, +) -> Tuple[float, Dict[str, Dict]]: + """ + Context Recovery Score: measures whether static context is reflected in generated outputs. + + Trains a predictor f: timeseries -> context_label on real data, then evaluates + accuracy on synthetic data conditioned on the same labels. High score means + the model correctly encodes the conditioning variable in the output. + + For categorical variables: classification accuracy (chance = 1/n_classes). + For continuous variables: R² (0 = no recovery, 1 = perfect). + + Args: + real_data: (N, T, D) real time series + syn_data: (N, T, D) synthetic time series (same conditioning as real) + context_labels: {name: (N,) array of integer labels or float values} + continuous_vars: names of continuous context variables; all others treated as categorical + test_ratio: fraction of real data held out for real_baseline evaluation + + Returns: + overall_score: mean synth_score across all context variables + per_var: {name: {"synth_score": float, "real_baseline": float, "type": str}} + """ + if continuous_vars is None: + continuous_vars = [] + + N, T, D = real_data.shape + + def _features(data: np.ndarray) -> np.ndarray: + """Mean + std pool over time → (N, 2*D) fixed-size representation.""" + return np.concatenate([data.mean(axis=1), data.std(axis=1)], axis=-1) + + X_real = _features(real_data) + X_synth = _features(syn_data) + + rng = np.random.RandomState(42) + idx = rng.permutation(N) + test_size = max(1, int(N * test_ratio)) + train_idx = idx[test_size:] + test_idx = idx[:test_size] + + per_var: Dict[str, Dict] = {} + scores: List[float] = [] + + for name, labels in context_labels.items(): + labels_f = labels.astype(float) + if np.isnan(labels_f).any(): + per_var[name] = {"synth_score": float("nan"), "real_baseline": float("nan"), "type": "unknown"} + continue + + if name in continuous_vars: + y = labels_f + clf = Ridge(alpha=1e-3, fit_intercept=True) + clf.fit(X_real[train_idx], y[train_idx]) + real_baseline = float(r2_score(y[test_idx], clf.predict(X_real[test_idx]))) + synth_score = float(r2_score(y, clf.predict(X_synth))) + score_type = "r2" + else: + y = labels.astype(int) + n_classes = len(np.unique(y)) + if n_classes < 2: + per_var[name] = {"synth_score": float("nan"), "real_baseline": float("nan"), "type": "accuracy"} + continue + clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000, random_state=42)) + clf.fit(X_real[train_idx], y[train_idx]) + real_baseline = float(np.mean(clf.predict(X_real[test_idx]) == y[test_idx])) + synth_score = float(np.mean(clf.predict(X_synth) == y)) + score_type = "accuracy" + + per_var[name] = { + "synth_score": synth_score, + "real_baseline": real_baseline, + "type": score_type, + } + scores.append(synth_score) + + overall = float(np.mean(scores)) if scores else float("nan") + return overall, per_var def compute_mig( diff --git a/cents/eval/eval_utils.py b/cents/eval/eval_utils.py index e082804..5afd759 100644 --- a/cents/eval/eval_utils.py +++ b/cents/eval/eval_utils.py @@ -884,123 +884,6 @@ def create_visualizations( visualizations[f"TSNE_Dim_{i}"] = plot -# def evaluate_pv_shift(self, dataset: Any, model: Any): -# avg_shift = dataset.compute_average_pv_shift() -# if avg_shift is None or np.allclose(avg_shift, 0.0): -# return -# test_contexts = dataset.sample_shift_test_contexts() -# n_sampled = len(test_contexts) -# n_pv1_missing = sum(1 for c in test_contexts if c["missing_pv"] == 1) -# n_pv0_missing = sum(1 for c in test_contexts if c["missing_pv"] == 0) - -# print(f"[Shift Contexts] Sampled: {n_sampled}.") -# print(f"[Shift Contexts] PV=1 is missing in {n_pv1_missing} of these contexts.") -# print(f"[Shift Contexts] PV=0 is missing in {n_pv0_missing} of these contexts.") -# if len(test_contexts) == 0: -# return -# present_ctx_list = [] -# missing_ctx_list = [] -# present_pv_values = [] -# for cinfo in test_contexts: -# base_ctx = cinfo["base_context"] -# present_pv = cinfo["present_pv"] -# missing_pv = cinfo["missing_pv"] -# ctx_p = dict(base_ctx) -# ctx_m = dict(base_ctx) -# ctx_p["has_solar"] = present_pv -# ctx_m["has_solar"] = missing_pv -# present_ctx_list.append(ctx_p) -# missing_ctx_list.append(ctx_m) -# present_pv_values.append(present_pv) -# present_ctx_tensors = {} -# missing_ctx_tensors = {} -# all_keys = present_ctx_list[0].keys() -# for k in all_keys: -# present_ctx_tensors[k] = torch.tensor( -# [pc[k] for pc in present_ctx_list], dtype=torch.long, device=self.device -# ) -# missing_ctx_tensors[k] = torch.tensor( -# [mc[k] for mc in missing_ctx_list], dtype=torch.long, device=self.device -# ) -# with torch.no_grad(): -# syn_ts_present = model.generate(present_ctx_tensors) -# syn_ts_missing = model.generate(missing_ctx_tensors) -# syn_ts_present = syn_ts_present.cpu().numpy() -# syn_ts_missing = syn_ts_missing.cpu().numpy() -# if syn_ts_present.ndim == 3 and syn_ts_present.shape[-1] == 1: -# syn_ts_present = syn_ts_present[:, :, 0] -# syn_ts_missing = syn_ts_missing[:, :, 0] -# shifts = [] -# for i, pv_val in enumerate(present_pv_values): -# shift_i = syn_ts_missing[i] - syn_ts_present[i] -# if pv_val == 1: -# shift_i = -shift_i -# shifts.append(shift_i) -# shifts = np.array(shifts) -# avg_shift = np.asarray(avg_shift).reshape(-1) -# l2_values = [] -# for i in range(shifts.shape[0]): -# diff = shifts[i] - avg_shift -# l2 = np.sqrt((diff**2).sum()) -# l2_values.append(l2) -# mean_l2 = np.mean(l2_values) -# wandb.log({"Shift_L2": mean_l2}) - -# def find_context_matched_shift(dataset, cinfo): -# base_ctx = cinfo["base_context"] -# city_val = base_ctx.get("city", None) -# btype_val = base_ctx.get("building_type", None) -# df = dataset.data.copy() -# mask = pd.Series([True] * len(df)) -# if city_val is not None and "city" in df.columns: -# mask = mask & (df["city"] == city_val) -# if btype_val is not None and "building_type" in df.columns: -# mask = mask & (df["building_type"] == btype_val) -# df_matched = df[mask] -# if df_matched.empty: -# return None -# df_pv0 = df_matched[df_matched["has_solar"] == 0] -# df_pv1 = df_matched[df_matched["has_solar"] == 1] -# if df_pv0.empty or df_pv1.empty: -# return None -# ts_pv0 = np.stack(df_pv0["timeseries"].values, axis=0) -# ts_pv1 = np.stack(df_pv1["timeseries"].values, axis=0) -# mean_pv0 = ts_pv0.mean(axis=0) -# mean_pv1 = ts_pv1.mean(axis=0) -# mean_pv0_dim0 = mean_pv0[:, 0] -# mean_pv1_dim0 = mean_pv1[:, 0] -# real_shift = mean_pv1_dim0 - mean_pv0_dim0 -# return real_shift - -# matched_shifts = [] -# for cinfo in test_contexts: -# matched = find_context_matched_shift(dataset, cinfo) -# matched_shifts.append(matched) -# n_plots = min(6, shifts.shape[0]) -# for j, idx in enumerate( -# np.random.choice(shifts.shape[0], size=n_plots, replace=False) -# ): -# fig, ax = plt.subplots(figsize=(8, 4)) -# ax.plot(avg_shift, label="Real shift", color="red") -# ax.plot(shifts[idx], label="Synthetic shift", color="blue", linestyle="--") -# matched_s = matched_shifts[idx] -# if matched_s is not None: -# ax.plot( -# matched_s, -# label="Context-matched shift", -# color="green", -# linestyle=":", -# ) -# font_size = 12 -# ax.tick_params(axis="both", which="major", labelsize=font_size) -# ax.set_xlabel("Timestep", fontsize=font_size) -# ax.set_ylabel("kWh", fontsize=font_size) -# leg = ax.legend() -# leg.prop.set_size(font_size) -# fig.tight_layout() -# wandb.log({f"ShiftPlot_{j}": wandb.Image(fig)}) -# plt.close(fig) - def flatten_log_dict(d: Dict[str, Any], prefix: str = "") -> Dict[str, float]: """ diff --git a/cents/eval/predictive_score.py b/cents/eval/predictive_score.py index e4371c9..657cb5c 100644 --- a/cents/eval/predictive_score.py +++ b/cents/eval/predictive_score.py @@ -42,65 +42,86 @@ def forward(self, x, t): return y_hat -def predictive_score_metrics(ori_data, generated_data): - no, seq_len, dim = ori_data.shape - - ori_time, ori_max_seq_len = extract_time(ori_data) - generated_time, generated_max_seq_len = extract_time(generated_data) - max([ori_max_seq_len, generated_max_seq_len]) - +def _train_predictor(data, time, dim, iterations, batch_size): + """Train a GRU predictor on data and return the model.""" hidden_dim = max(int(dim / 2), 1) - iterations = 5000 - batch_size = 128 - model = Predictor(input_dim=dim, hidden_dim=hidden_dim).to(device) criterion = nn.L1Loss() optimizer = optim.Adam(model.parameters()) - for itt in tqdm( - range(iterations), - desc="[Cents] Training Predictive Score Model", - total=iterations, - ): - idx = np.random.permutation(len(generated_data)) + for _ in tqdm(range(iterations), desc="[Cents] Training Predictive Score Model", total=iterations): + idx = np.random.permutation(len(data)) train_idx = idx[:batch_size] - X_mb = [ - generated_data[i][:-1, :] for i in train_idx - ] # Use all dimensions for input - T_mb = [generated_time[i] - 1 for i in train_idx] - Y_mb = [ - generated_data[i][1:, :].reshape(-1, dim) for i in train_idx - ] # Predict all dimensions - - X_mb = torch.tensor(np.array(X_mb), dtype=torch.float32).to(device) - T_mb = torch.tensor(np.array(T_mb), dtype=torch.int64).to(device) - Y_mb = torch.tensor(np.array(Y_mb), dtype=torch.float32).to(device) + X_mb = torch.tensor(np.array([data[i][:-1, :] for i in train_idx]), dtype=torch.float32).to(device) + T_mb = torch.tensor(np.array([max(time[i] - 1, 1) for i in train_idx]), dtype=torch.int64).to(device) + Y_mb = torch.tensor(np.array([data[i][1:, :].reshape(-1, dim) for i in train_idx]), dtype=torch.float32).to(device) optimizer.zero_grad() - y_pred = model(X_mb, T_mb) - loss = criterion(y_pred, Y_mb) + loss = criterion(model(X_mb, T_mb), Y_mb) loss.backward() optimizer.step() - X_mb = [ori_data[i][:-1, :] for i in range(no)] - T_mb = [ori_time[i] - 1 for i in range(no)] - Y_mb = [ori_data[i][1:, :].reshape(-1, dim) for i in range(no)] + return model - X_mb = torch.tensor(np.array(X_mb), dtype=torch.float32).to(device) - T_mb = torch.tensor(np.array(T_mb), dtype=torch.int64).to(device) - Y_mb = torch.tensor(np.array(Y_mb), dtype=torch.float32).to(device) + +def _eval_mae(model, data, time, dim): + """Evaluate a predictor's MAE on data.""" + no = len(data) + X_mb = torch.tensor(np.array([data[i][:-1, :] for i in range(no)]), dtype=torch.float32).to(device) + T_mb = torch.tensor(np.array([max(time[i] - 1, 1) for i in range(no)]), dtype=torch.int64).to(device) + Y_mb = torch.tensor(np.array([data[i][1:, :].reshape(-1, dim) for i in range(no)]), dtype=torch.float32).to(device) with torch.no_grad(): y_pred = model(X_mb, T_mb) - MAE_temp = 0 - for i in range(no): - MAE_temp += mean_absolute_error(Y_mb[i].cpu().numpy(), y_pred[i].cpu().numpy()) + mae = sum( + mean_absolute_error(Y_mb[i].cpu().numpy(), y_pred[i].cpu().numpy()) + for i in range(no) + ) / no + return mae + + +def predictive_score_metrics(ori_data, generated_data, trtr: bool = False): + """ + Compute predictive score. + + Trains a GRU predictor on synthetic data (TSTR) and evaluates one-step-ahead + MAE on real data. When trtr=True, also trains on real data and returns the MAE + delta (TSTR_MAE - TRTR_MAE) alongside both individual MAEs. + + Args: + ori_data: Real time series (N, T, D). + generated_data: Synthetic time series (N, T, D). + trtr: If True, additionally train on real data for TRTR comparison. + + Returns: + float if trtr=False: TSTR MAE. + dict if trtr=True: {"tstr_mae": float, "trtr_mae": float, "delta": float}. + """ + no, seq_len, dim = ori_data.shape + ori_time, _ = extract_time(ori_data) + generated_time, _ = extract_time(generated_data) + + iterations = 5000 + batch_size = 128 + + # Train on synthetic, test on real (TSTR) + synth_model = _train_predictor(generated_data, generated_time, dim, iterations, batch_size) + tstr_mae = _eval_mae(synth_model, ori_data, ori_time, dim) + + if not trtr: + return tstr_mae - predictive_score = MAE_temp / no + # Train on real, test on real (TRTR) + real_model = _train_predictor(ori_data, ori_time, dim, iterations, batch_size) + trtr_mae = _eval_mae(real_model, ori_data, ori_time, dim) - return predictive_score + return { + "tstr_mae": tstr_mae, + "trtr_mae": trtr_mae, + "delta": tstr_mae - trtr_mae, + } def extract_time(data): diff --git a/cents/eval/t2vec/t2vec.py b/cents/eval/t2vec/t2vec.py index 4c1aafa..b3f4b2e 100644 --- a/cents/eval/t2vec/t2vec.py +++ b/cents/eval/t2vec/t2vec.py @@ -11,6 +11,8 @@ Note: Please ensure compliance with the repository's license and credit the original authors when using or distributing this code. """ +import warnings + import numpy as np import torch import torch.nn.functional as F @@ -113,13 +115,25 @@ def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)] + if len(train_data) == 0: + warnings.warn( + "TS2Vec.fit: no valid samples after dropping all-NaN rows; returning empty loss log." + ) + return [] + train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float)) + batch_size = min(self.batch_size, len(train_dataset)) train_loader = DataLoader( train_dataset, - batch_size=min(self.batch_size, len(train_dataset)), + batch_size=batch_size, shuffle=True, drop_last=True, ) + if len(train_loader) == 0: + warnings.warn( + "TS2Vec.fit: DataLoader has 0 batches (e.g. drop_last=True with too few samples); returning empty loss log." + ) + return [] optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr) diff --git a/cents/models/acgan.py b/cents/models/acgan.py index 2f7f84d..ebc94c7 100644 --- a/cents/models/acgan.py +++ b/cents/models/acgan.py @@ -14,11 +14,13 @@ import pytorch_lightning as pl import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from omegaconf import DictConfig from cents.models.base import GenerativeModel -from cents.models.context import ContextModule +from cents.models.context import MLPContextModule, SepMLPContextModule # Import to trigger registration +from cents.models.context_registry import get_context_module_cls from cents.models.model_utils import total_correlation from cents.models.registry import register_model @@ -43,9 +45,10 @@ def __init__( embedding_dim: int, final_window_length: int, time_series_dims: int, - context_module: ContextModule, + context_module_type: str, context_vars: Optional[dict] = None, base_channels: int = 256, + continuous_vars: Optional[list] = None, ): super().__init__() self.noise_dim = noise_dim @@ -56,7 +59,7 @@ def __init__( self.base_channels = base_channels self.context_vars = context_vars - self.context_module = context_module + self.context_module = get_context_module_cls(context_module_type)(context_vars, embedding_dim, continuous_vars=continuous_vars) in_dim = noise_dim + (embedding_dim if context_vars else 0) self.fc = nn.Linear(in_dim, self.final_window_length * base_channels) @@ -194,21 +197,37 @@ def __init__(self, cfg: DictConfig): # self.context_module = ContextModule( # cfg.dataset.context_vars, cfg.model.cond_emb_dim # ) + continuous_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] + # Get context module type from context config + from cents.utils.utils import get_context_config + context_cfg = get_context_config() + context_module_type = context_cfg.static_context.type + self.generator = Generator( noise_dim=cfg.model.noise_dim, embedding_dim=cfg.model.cond_emb_dim, final_window_length=cfg.dataset.seq_len, time_series_dims=cfg.dataset.time_series_dims, - context_module=self.context_module, + context_module_type=context_module_type, context_vars=cfg.dataset.context_vars, + continuous_vars=continuous_vars, ) + # Filter out continuous variables from context_vars for discriminator (only categorical needed) + categorical_context_vars = {k: v for k, v in cfg.dataset.context_vars.items() if k not in continuous_vars} self.discriminator = Discriminator( window_length=cfg.dataset.seq_len, time_series_dims=cfg.dataset.time_series_dims, - context_var_n_categories=cfg.dataset.context_vars, + context_var_n_categories=categorical_context_vars, ) self.adv_loss = nn.BCEWithLogitsLoss() self.aux_loss = nn.CrossEntropyLoss() + + # Get continuous variables from config to distinguish them in loss computation + self.continuous_context_vars = getattr(cfg.dataset, "continuous_context_vars", None) or [] + if isinstance(self.continuous_context_vars, (list, tuple)): + self.continuous_context_vars = set(self.continuous_context_vars) + else: + self.continuous_context_vars = set([self.continuous_context_vars]) if self.continuous_context_vars else set() def forward(self, noise: torch.Tensor, context_vars: dict): """ @@ -260,7 +279,15 @@ def training_step(self, batch: Any, batch_idx: int) -> None: if self.cfg.model.include_auxiliary_losses else 0.0 ) - g_ctx = sum(self.aux_loss(logits_ctx[v], ctx[v]) for v in logits_ctx) + # Context reconstruction loss: MSE for continuous, CSE for categorical + g_ctx = 0.0 + for v in logits_ctx: + if v in self.continuous_context_vars: + # Continuous variable: use MSE loss + g_ctx += F.mse_loss(logits_ctx[v], ctx[v].float()) + else: + # Categorical variable: use Cross-Entropy loss + g_ctx += self.aux_loss(logits_ctx[v], ctx[v].long()) h, _ = self.context_module(ctx) tc_term = ( diff --git a/cents/models/base.py b/cents/models/base.py index 29a6682..0221f73 100644 --- a/cents/models/base.py +++ b/cents/models/base.py @@ -3,9 +3,12 @@ import pandas as pd import pytorch_lightning as pl import torch +import torch.nn as nn from omegaconf import DictConfig -from cents.models.context import ContextModule +from cents.models.context import MLPContextModule, SepMLPContextModule, TransformerStaticContextModule # Import to trigger registration +from cents.models.context_registry import get_context_module_cls +from cents.utils.utils import get_context_config class BaseModel(pl.LightningModule, ABC): @@ -38,9 +41,94 @@ def __init__(self, cfg: DictConfig = None): if hasattr(cfg.dataset, "context_vars") and cfg.dataset.context_vars: emb_dim = getattr(cfg.model, "cond_emb_dim", 256) - self.context_module = ContextModule(cfg.dataset.context_vars, emb_dim) + # Get context module type from context config + context_cfg = get_context_config() + static_module_type = context_cfg.static_context.type + dynamic_module_type = getattr(context_cfg.dynamic_context, "type", None) + + # Separate static and dynamic context variables + continuous_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] + categorical_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] + dynamic_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "time_series"] + + static_context_vars = categorical_vars + continuous_vars + self.dynamic_context_vars = dynamic_vars + + # Create static context module (for categorical + continuous) + self.static_context_module = None + if static_context_vars: + StaticContextModuleCls = get_context_module_cls(static_module_type) + static_context_vars_dict = { + k: v for k, v in cfg.dataset.context_vars.items() + if k in static_context_vars + } + print(static_context_vars_dict) + static_ctx_kwargs = { + k: getattr(context_cfg.static_context, k) + for k in ("n_heads", "n_layers", "dropout", "dim_feedforward") + if hasattr(context_cfg.static_context, k) + } + self.static_context_module = StaticContextModuleCls( + static_context_vars_dict, + emb_dim, + **static_ctx_kwargs, + ) + + # Create dynamic context module (for time_series) + self.dynamic_context_module = None + if self.dynamic_context_vars and dynamic_module_type is not None: + num_ts_steps = getattr(cfg.dataset, "num_ts_steps", None) + DynamicContextModuleCls = get_context_module_cls("dynamic", dynamic_module_type) + dynamic_context_vars_dict = { + k: v for k, v in cfg.dataset.context_vars.items() + if k in self.dynamic_context_vars + } + seq_len = getattr(cfg.dataset, "seq_len", None) + if seq_len is None: + raise ValueError("seq_len must be specified in cfg.dataset for dynamic context modules") + self.dynamic_context_module = DynamicContextModuleCls( + dynamic_context_vars_dict, + emb_dim, + seq_len=seq_len if num_ts_steps is None else num_ts_steps, + ) + + # Determine embedding dimension + if self.static_context_module is not None: + self.embedding_dim = self.static_context_module.embedding_dim + elif self.dynamic_context_module is not None: + self.embedding_dim = self.dynamic_context_module.embedding_dim + else: + raise ValueError("At least one of static_context_module or dynamic_context_module must be provided") + + # combine_mlp is only needed when the dynamic module returns a pooled + # vector (returns_sequence=False) that must be fused with the static + # embedding before AdaLN conditioning. When returns_sequence=True the + # dynamic context is routed to cross-attention inside the backbone and + # must NOT be merged with the static embedding here. + dyn_returns_seq = getattr(self.dynamic_context_module, "returns_sequence", False) \ + if self.dynamic_context_module is not None else False + + if (self.static_context_module is not None + and self.dynamic_context_module is not None + and not dyn_returns_seq): + combined_dim = self.static_context_module.embedding_dim + self.dynamic_context_module.embedding_dim + self.combine_mlp = nn.Sequential( + nn.Linear(combined_dim, self.embedding_dim), + nn.ReLU(), + ) + else: + self.combine_mlp = None + + # For backward compatibility, expose static_context_module as context_module + # (but subclasses should use static_context_module and dynamic_context_module directly) + self.context_module = self.static_context_module else: + self.static_context_module = None + self.dynamic_context_module = None self.context_module = None + self.combine_mlp = None + self.dynamic_context_vars = [] + self.embedding_dim = getattr(cfg.model, "cond_emb_dim", 256) if cfg is not None else 256 @abstractmethod def forward(self, *args, **kwargs): diff --git a/cents/models/context.py b/cents/models/context.py index 5fc908a..e843024 100644 --- a/cents/models/context.py +++ b/cents/models/context.py @@ -1,8 +1,19 @@ import torch import torch.nn as nn +from abc import abstractmethod +from typing import Optional +from .context_registry import register_context_module +class BaseContextModule(nn.Module): + """ + Base class for context modules. Subclasses must implement the forward method. + """ + @abstractmethod + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + pass -class ContextModule(nn.Module): +@register_context_module("default", "mlp") +class MLPContextModule(BaseContextModule): """ Integrates multiple context variables into a single embedding and provides auxiliary classification logits for each variable. @@ -28,10 +39,9 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int): """ super().__init__() self.embedding_dim = embedding_dim - self.context_embeddings = nn.ModuleDict( { - name: nn.Embedding(num_categories, embedding_dim) + name: nn.Embedding(num_categories[1], embedding_dim) for name, num_categories in context_vars.items() } ) @@ -45,7 +55,7 @@ def __init__(self, context_vars: dict[str, int], embedding_dim: int): self.classification_heads = nn.ModuleDict( { - var_name: nn.Linear(embedding_dim, num_categories) + var_name: nn.Linear(embedding_dim, num_categories[1]) for var_name, num_categories in context_vars.items() } ) @@ -64,9 +74,12 @@ def forward( classification_logits (Dict[str, Tensor]): Logits per variable, each of shape (batch_size, num_categories). """ - embeddings = [ - layer(context_vars[name]) for name, layer in self.context_embeddings.items() - ] + embeddings = [] + for name, layer in self.context_embeddings.items(): + idx = context_vars[name] + if idx.dtype in (torch.long, torch.int, torch.int32, torch.int64): + idx = idx.clamp(0, layer.num_embeddings - 1) + embeddings.append(layer(idx)) context_matrix = torch.cat(embeddings, dim=1) embedding = self.mlp(context_matrix) @@ -77,3 +90,573 @@ def forward( } return embedding, classification_logits + +@register_context_module("default", "sep_mlp") +class SepMLPContextModule(BaseContextModule): + def __init__( + self, + context_vars: dict[str, int], + embedding_dim: int, + init_depth: int = 1, + mixing_depth: int = 1, + ) -> None: + """ + Initialize SepMLPContextModule. + + Args: + context_vars: Mapping of variable names to category counts. + embedding_dim: Size of embedding vectors. + init_depth: Depth of initial MLPs. + mixing_depth: Depth of mixing MLP. + continuous_vars: List of continuous variable names. + """ + super().__init__() + + self.embedding_dim = embedding_dim + self.continuous_vars = [k for k, v in context_vars.items() if v[0] == "continuous"] + self.categorical_vars = {k: v[1] for k, v in context_vars.items() if v[0] == "categorical"} + self.context_embeddings = nn.ModuleDict( + { + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in self.categorical_vars.items() + } + ) + + # For continuous variables, use a simple linear projection + self.continuous_projections = nn.ModuleDict( + { + name: nn.Linear(1, embedding_dim) + for name in self.continuous_vars + } + ) + + self.init_mlps = nn.ModuleDict({ + name: nn.Sequential(*[ + layer + for _ in range(init_depth) + for layer in (nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), nn.Linear(embedding_dim, embedding_dim)) + ]) + for name in self.categorical_vars.keys() + }) + + # Also create init MLPs for continuous variables + self.continuous_init_mlps = nn.ModuleDict({ + name: nn.Sequential(*[ + layer + for _ in range(init_depth) + for layer in (nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), nn.Linear(embedding_dim, embedding_dim)) + ]) + for name in self.continuous_vars + }) + + total_dim = embedding_dim * (len(self.categorical_vars) + len(self.continuous_vars)) + + self.mixing_mlp = nn.Sequential( + nn.Linear(total_dim, 128), + nn.ReLU(), + nn.Linear(128, embedding_dim)) + + self.classification_heads = nn.ModuleDict( + { + var_name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.ReLU(), + nn.Linear(embedding_dim, num_categories) + ) + for var_name, num_categories in self.categorical_vars.items() + } + ) + + # Regression heads for continuous variables (output single value for MSE loss) + self.regression_heads = nn.ModuleDict( + { + var_name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.ReLU(), + nn.Linear(embedding_dim, 1) + ) + for var_name in self.continuous_vars + } + ) + + def forward(self, context_vars): + encodings = {} + + # Process categorical variables (only those present in context_vars) + for name, layer in self.context_embeddings.items(): + if name in context_vars: + idx = context_vars[name] + if idx.dtype in (torch.long, torch.int, torch.int32, torch.int64): + idx = idx.clamp(0, layer.num_embeddings - 1) + encodings[name] = layer(idx) + + # Process continuous variables (only those present in context_vars) + for name, layer in self.continuous_projections.items(): + if name in context_vars: + continuous_val = context_vars[name].float() + # DataLoader stacks 0-dim scalars into (batch,); layer expects (batch, 1) + if continuous_val.dim() == 1: + continuous_val = continuous_val.unsqueeze(-1) + encodings[name] = layer(continuous_val) + + embeddings = [] + # Apply init MLPs to categorical variables + for name, layer in self.init_mlps.items(): + if name in encodings: + embeddings.append(layer(encodings[name])) + + # Apply init MLPs to continuous variables + for name, layer in self.continuous_init_mlps.items(): + if name in encodings: + embedding_output = layer(encodings[name]) + # Check for NaN in embedding output + if torch.isnan(embedding_output).any(): + raise ValueError( + f"NaN detected in embedding output for continuous variable '{name}' " + f"after init MLP. This may indicate numerical instability in the MLP layers." + ) + embeddings.append(embedding_output) + + context_matrix = torch.cat(embeddings, dim=1) + embedding = self.mixing_mlp(context_matrix) + + classification_logits = { + var_name: head(embedding) + for var_name, head in self.classification_heads.items() + if var_name in context_vars + } + + # Regression outputs for continuous variables + regression_outputs = { + var_name: head(embedding).squeeze(-1) # Remove last dim to get (batch_size,) + for var_name, head in self.regression_heads.items() + if var_name in context_vars + } + + all_outputs = {**classification_logits, **regression_outputs} + + return embedding, all_outputs + + +@register_context_module("transformer") +class TransformerStaticContextModule(BaseContextModule): + """ + Transformer-based static context embedder. + + Each context variable is projected to a token of size embedding_dim, + augmented with a per-variable type embedding, then normalised before + being fed into a shared Transformer encoder. Mean-pooling across the + variable tokens produces the final (B, embedding_dim) conditioning vector. + + Compared to MLPContextModule: + - Attention captures interactions between context variables + - pre-LN (norm_first=True) and GELU throughout → more stable gradients + - No hardcoded bottleneck; width is controlled by dim_feedforward + """ + + def __init__( + self, + context_vars: dict[str, list], + embedding_dim: int, + n_heads: int = 4, + n_layers: int = 2, + dropout: float = 0.1, + dim_feedforward: int = 256, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.continuous_vars = [k for k, v in context_vars.items() if v[0] == "continuous"] + self.categorical_vars = {k: v[1] for k, v in context_vars.items() if v[0] == "categorical"} + self.var_names = list(self.categorical_vars.keys()) + self.continuous_vars + + self.cat_embeddings = nn.ModuleDict({ + name: nn.Embedding(n_cats, embedding_dim) + for name, n_cats in self.categorical_vars.items() + }) + self.cont_projections = nn.ModuleDict({ + name: nn.Linear(1, embedding_dim) + for name in self.continuous_vars + }) + # Per-variable learnable offset so attention can distinguish variable identity + self.type_embeddings = nn.Embedding(len(self.var_names), embedding_dim) + self.register_buffer("_var_indices", torch.arange(len(self.var_names))) + + # Normalise tokens before encoder to equalize scales across embed/projection types + self.token_norm = nn.LayerNorm(embedding_dim) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=n_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + self.output_norm = nn.LayerNorm(embedding_dim) + + self.classification_heads = nn.ModuleDict({ + name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.GELU(), + nn.Linear(embedding_dim, n_cats), + ) + for name, n_cats in self.categorical_vars.items() + }) + self.regression_heads = nn.ModuleDict({ + name: nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.GELU(), + nn.Linear(embedding_dim, 1), + ) + for name in self.continuous_vars + }) + + self._init_weights() + + def _init_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + nn.init.normal_(self.type_embeddings.weight, std=0.02) + + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + all_type_embs = self.type_embeddings(self._var_indices) # (N_vars, D) + tokens = [] + for i, name in enumerate(self.var_names): + if name in self.categorical_vars: + idx = context_vars[name] + if idx.dtype in (torch.long, torch.int, torch.int32, torch.int64): + idx = idx.clamp(0, self.cat_embeddings[name].num_embeddings - 1) + tok = self.cat_embeddings[name](idx) + all_type_embs[i] + else: + val = context_vars[name].float() + if val.dim() == 1: + val = val.unsqueeze(-1) + tok = self.cont_projections[name](val) + all_type_embs[i] + tokens.append(self.token_norm(tok)) + + x = torch.stack(tokens, dim=1) # (B, N_vars, D) + x = self.encoder(x) # (B, N_vars, D) + embedding = self.output_norm(x.mean(dim=1)) # (B, D) + + classification_logits = { + name: head(embedding) + for name, head in self.classification_heads.items() + if name in context_vars + } + regression_outputs = { + name: head(embedding).squeeze(-1) + for name, head in self.regression_heads.items() + if name in context_vars + } + return embedding, {**classification_logits, **regression_outputs} + + +@register_context_module("dynamic_transformer") +class DynamicContextModule_Transformer(BaseContextModule): + """ + Context module for processing dynamic (time series) context variables. + Uses Transformer encoder to encode time series sequences into embeddings. + + Returns the full encoded sequence (B, T, embedding_dim) rather than a + pooled vector, so temporal structure is preserved for cross-attention + conditioning in the diffusion backbone. + """ + + # Signals to BaseModel that this module returns (B, T, emb_dim) not (B, emb_dim) + returns_sequence = True + + def __init__( + self, + context_vars: dict[str, int], + embedding_dim: int, + seq_len: int = None, + n_layers: int = 2, + n_heads: int = 4, + dropout: float = 0.1, + dim_feedforward: int = 256, + ): + """ + Initialize DynamicContextModule_Transformer. + + Args: + context_vars: Mapping of variable names to category counts (for categorical time series) + or None (for numeric time series). Format: {name: [type, num_categories]} + embedding_dim: Size of embedding vectors. + seq_len: Sequence length of time series context variables. + n_layers: Number of transformer encoder layers. + n_heads: Number of attention heads. + dropout: Dropout probability. + dim_feedforward: Dimension of feedforward network in transformer. + """ + super().__init__() + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + # Separate categorical and numeric time series + self.categorical_ts_vars = { + k: v[1] for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + self.numeric_ts_vars = [ + k for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is None + ] + + # For categorical time series, use embedding + self.ts_embeddings = nn.ModuleDict({ + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in self.categorical_ts_vars.items() + }) + + # For numeric time series, use linear projection to embedding_dim + self.ts_projections = nn.ModuleDict({ + name: nn.Linear(1, embedding_dim) + for name in self.numeric_ts_vars + }) + + # Positional encoding for transformer + if seq_len is not None: + self.pos_encodings = nn.ParameterDict({ + name: nn.Parameter(torch.zeros(1, seq_len, embedding_dim)) + for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + }) + else: + self.pos_encodings = None + + # Transformer encoder for each time series variable + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=n_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation='gelu', + batch_first=True, + ) + + self.ts_encoders = nn.ModuleDict({ + name: nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + for name in list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + }) + + + n_vars = len(self.categorical_ts_vars) + len(self.numeric_ts_vars) + self.var_mix = nn.Linear(n_vars * embedding_dim, embedding_dim) if n_vars > 1 else None + + all_var_names = list(self.categorical_ts_vars.keys()) + self.numeric_ts_vars + self.post_encoder_norms = nn.ModuleDict({ + name: nn.LayerNorm(embedding_dim) + for name in all_var_names + }) + self.post_mix_norm = nn.LayerNorm(embedding_dim) + + # Initialize weights + self._initialize_weights() + + def _initialize_weights(self): + """ + Initialize weights using appropriate initialization strategies. + """ + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Parameter): + if module.dim() == 3: # (1, seq_len, embedding_dim) + nn.init.normal_(module, std=0.02) + + def forward(self, context_vars: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """ + Process dynamic (time series) context variables using Transformer. + + Args: + context_vars: Dict mapping variable names to tensors. + For categorical TS: (batch, seq_len) integer values. + For numeric TS: (batch, seq_len) float values. + + Returns: + sequence: Combined sequence of shape (batch, seq_len, embedding_dim). + Temporal structure is preserved for downstream cross-attention. + outputs: Empty dict for interface compatibility. + """ + sequences = [] + + # Process categorical time series + for name in self.categorical_ts_vars.keys(): + if name in context_vars: + ts_data = context_vars[name] # (B, T) + if torch.isnan(ts_data).any() or torch.isinf(ts_data).any(): + raise ValueError(f"NaN/Inf in categorical TS input '{name}'") + embedded = self.ts_embeddings[name](ts_data) # (B, T, emb_dim) + if self.pos_encodings is not None and name in self.pos_encodings: + embedded = embedded + self.pos_encodings[name][:, :embedded.size(1)] + encoded = self.ts_encoders[name](embedded) # (B, T, emb_dim) + encoded = self.post_encoder_norms[name](encoded) + if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + raise ValueError(f"NaN/Inf after transformer encoding '{name}'") + sequences.append(encoded) + + # Process numeric time series + for name in self.numeric_ts_vars: + if name in context_vars: + ts_data = context_vars[name] # (B, T) + if not ts_data.is_floating_point(): + ts_data = ts_data.float() + ts_data = torch.where(torch.isfinite(ts_data), ts_data, torch.zeros_like(ts_data)) + # Per-sample z-score normalisation over the time axis + ts_mean = ts_data.mean(dim=1, keepdim=True) + ts_std = ts_data.std(dim=1, keepdim=True) + 1e-8 + ts_data = (ts_data - ts_mean) / ts_std + embedded = self.ts_projections[name](ts_data.unsqueeze(-1)) # (B, T, emb_dim) + if self.pos_encodings is not None and name in self.pos_encodings: + embedded = embedded + self.pos_encodings[name][:, :embedded.size(1)] + if torch.isnan(embedded).any() or torch.isinf(embedded).any(): + raise ValueError(f"NaN/Inf after projection for '{name}'") + encoded = self.ts_encoders[name](embedded) # (B, T, emb_dim) + encoded = self.post_encoder_norms[name](encoded) + if torch.isnan(encoded).any() or torch.isinf(encoded).any(): + raise ValueError(f"NaN/Inf after transformer encoding numeric TS '{name}'") + sequences.append(encoded) + + if not sequences: + device = next(iter(context_vars.values())).device if context_vars else None + batch_size = next(iter(context_vars.values())).size(0) if context_vars else 1 + T = self.seq_len if self.seq_len is not None else 1 + return torch.zeros(batch_size, T, self.embedding_dim, device=device), {} + + if len(sequences) == 1: + out = sequences[0] + elif self.var_mix is not None: + # Learned combination across variables: concat along feature axis, project back + out = self.var_mix(torch.cat(sequences, dim=-1)) # (B, T, emb_dim) + else: + # Single-variable fallback (var_mix is None only when n_vars == 1) + out = sequences[0] + + out = self.post_mix_norm(out) + + if torch.isnan(out).any() or torch.isinf(out).any(): + raise ValueError("NaN/Inf in dynamic context sequence output") + + return out, {} + + +@register_context_module("dynamic_joint_transformer") +class DynamicContextModule_JointTransformer(BaseContextModule): + """ + Joint multi-channel encoder for dynamic context. + + All numeric time-series variables are stacked as channels (B, T, n_vars), + projected jointly to (B, T, embedding_dim), then encoded by a single shared + Transformer. Self-attention operates across time steps while seeing all + variables simultaneously, allowing it to learn non-linear variable + interactions (e.g. high TI + low wind → elevated PM2.5). + + Compare to DynamicContextModule_Transformer which runs one independent + transformer per variable and combines them with a linear mix — that + architecture can only learn additive contributions. + """ + + returns_sequence = True + + def __init__( + self, + context_vars: dict, + embedding_dim: int, + seq_len: int = None, + n_layers: int = 2, + n_heads: int = 4, + dropout: float = 0.1, + dim_feedforward: int = 256, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.seq_len = seq_len + + self.numeric_ts_vars = [ + k for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is None + ] + self.categorical_ts_vars = { + k: v[1] for k, v in context_vars.items() + if v[0] == "time_series" and v[1] is not None + } + + n_numeric = len(self.numeric_ts_vars) + + if n_numeric > 0: + self.numeric_input_proj = nn.Linear(n_numeric, embedding_dim) + else: + self.numeric_input_proj = None + + # Categorical time-series: embed each to emb_dim and add into the joint repr + self.ts_cat_embeddings = nn.ModuleDict({ + name: nn.Embedding(num_categories, embedding_dim) + for name, num_categories in self.categorical_ts_vars.items() + }) + + if seq_len is not None: + self.pos_encoding = nn.Parameter(torch.zeros(1, seq_len, embedding_dim)) + else: + self.pos_encoding = None + + encoder_layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=n_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation='gelu', + batch_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers) + self.post_norm = nn.LayerNorm(embedding_dim) + + self._init_weights() + + def _init_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + if self.pos_encoding is not None: + nn.init.normal_(self.pos_encoding, std=0.02) + + def forward(self, context_vars: dict) -> tuple: + numeric_tensors = [] + for name in self.numeric_ts_vars: + if name in context_vars: + ts = context_vars[name] + if not ts.is_floating_point(): + ts = ts.float() + ts = torch.where(torch.isfinite(ts), ts, torch.zeros_like(ts)) + ts_mean = ts.mean(dim=1, keepdim=True) + ts_std = ts.std(dim=1, keepdim=True) + 1e-8 + ts = (ts - ts_mean) / ts_std + numeric_tensors.append(ts) + + if numeric_tensors and self.numeric_input_proj is not None: + # (B, T, n_vars) → (B, T, emb_dim) + x = self.numeric_input_proj(torch.stack(numeric_tensors, dim=-1)) + else: + device = next(iter(context_vars.values())).device + B = next(iter(context_vars.values())).size(0) + T = self.seq_len or 1 + x = torch.zeros(B, T, self.embedding_dim, device=device) + + for name, emb_layer in self.ts_cat_embeddings.items(): + if name in context_vars: + x = x + emb_layer(context_vars[name]) # (B, T, emb_dim) + + if self.pos_encoding is not None: + x = x + self.pos_encoding[:, :x.size(1)] + + x = self.encoder(x) + x = self.post_norm(x) + + return x, {} \ No newline at end of file diff --git a/cents/models/context_registry.py b/cents/models/context_registry.py new file mode 100644 index 0000000..3743768 --- /dev/null +++ b/cents/models/context_registry.py @@ -0,0 +1,65 @@ +from typing import Dict + +_CONTEXT_MODULE_REGISTRY = {} + + +def register_context_module(*names): + """ + Decorator: registers a context module class under one or more names. + + Args: + *names: One or more names to register the class under. + + Example: + @register_context_module("default", "mlp") + class MLPContextModule(BaseContextModule): + pass + """ + + def decorator(cls): + for name in names: + _CONTEXT_MODULE_REGISTRY[name] = cls + return cls + + return decorator + + +def get_context_module_cls(key: str, subkey: str = None) -> type: + """ + Fetch the context module class for `key` (and optionally `subkey`). Raises if not found. + + Args: + key: The name of the context module to retrieve (e.g., "default", "dynamic"c). + subkey: Optional subkey for two-part registration (e.g., "mlp", "cnn"). + + Returns: + The context module class. + + Raises: + ValueError: If the key is not found in the registry. + """ + # Try two-part key first if subkey is provided + if subkey is not None: + two_part_key = f"{key}_{subkey}" + if two_part_key in _CONTEXT_MODULE_REGISTRY: + return _CONTEXT_MODULE_REGISTRY[two_part_key] + + # Try single key + try: + return _CONTEXT_MODULE_REGISTRY[key] + except KeyError: + available = list(_CONTEXT_MODULE_REGISTRY.keys()) + raise ValueError( + f"Unknown context module '{key}'" + (f" with subkey '{subkey}'" if subkey else "") + + f". Available: {available}" + ) + + +def get_available_context_modules() -> list[str]: + """ + Get a list of all available context module names. + + Returns: + List of available context module names. + """ + return list(_CONTEXT_MODULE_REGISTRY.keys()) diff --git a/cents/models/diffusion_ts.py b/cents/models/diffusion_ts.py index fadf2e8..0717e55 100644 --- a/cents/models/diffusion_ts.py +++ b/cents/models/diffusion_ts.py @@ -2,6 +2,8 @@ import math from typing import Any, Optional, Tuple +from pyparsing import alphas + import pytorch_lightning as pl import torch import torch.nn as nn @@ -10,6 +12,9 @@ from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau from tqdm.auto import tqdm +from contextlib import contextmanager + + from cents.models.base import GenerativeModel from cents.models.model_utils import ( @@ -18,9 +23,71 @@ default, linear_beta_schedule, total_correlation, + cosine_beta_schedule_logsnr, ) from cents.models.registry import register_model +def _randn_like_correlated( + x: torch.Tensor, correlated: bool +) -> torch.Tensor: + """White noise with same shape as x. If correlated and C>1, same noise broadcast across last dim.""" + if not correlated or x.dim() < 3 or x.shape[-1] == 1: + return torch.randn_like(x) + return torch.randn(*x.shape[:-1], 1, device=x.device, dtype=x.dtype).expand_as(x).clone() + + +def _randn_shape_correlated( + shape: tuple, device: torch.device, dtype: torch.dtype, correlated: bool +) -> torch.Tensor: + """Randn with given shape. If correlated and shape[-1]>1, same noise broadcast across last dim.""" + if not correlated or len(shape) < 3 or shape[-1] == 1: + return torch.randn(shape, device=device, dtype=dtype) + B, L, C = shape[0], shape[1], shape[2] + return torch.randn(B, L, 1, device=device, dtype=dtype).expand(shape).clone() + + +def blueish_noise_like( + x: torch.Tensor, power: float = 1.0, eps: float = 1e-6, correlated: bool = False +) -> torch.Tensor: + """ + Generate 'blue-ish' noise: more energy at higher frequencies. + - power = 0.0 -> white noise + - power > 0.0 -> increasingly high-frequency-heavy (blue/violet-ish) + Returns noise with ~unit std per sample/channel so diffusion scaling stays consistent. + - correlated: if True and x has multiple channels (C>1), same noise is used for all channels. + + x: (B, L, C) where L is time dimension. + """ + B, L, C = x.shape + + if power == 0.0: + return _randn_like_correlated(x, correlated) + + # When correlated and C>1, generate (B, L, 1) then expand after FFT shaping + if correlated and C > 1: + n = torch.randn(B, L, 1, device=x.device, dtype=torch.float32) + else: + n = torch.randn(B, L, C, device=x.device, dtype=torch.float32) + + # real FFT over time + N = torch.fft.rfft(n, dim=1) # (B, F, C) or (B, F, 1) + freqs = torch.fft.rfftfreq(L, d=1.0).to(x.device) # (F,) + + # Amplitude shaping: + amp = (freqs.clamp_min(eps) ** (power / 2.0)).view(1, -1, 1) + N = N * amp + + n_blue = torch.fft.irfft(N, n=L, dim=1) # (B, L, C) or (B, L, 1) + + # Re-normalize per (B,C) to unit std across time + n_blue = n_blue / n_blue.std(dim=1, keepdim=True).clamp_min(1e-6) + + if correlated and C > 1: + n_blue = n_blue.expand(B, L, C).clone() + + out = n_blue.to(dtype=x.dtype) + return out + @register_model("diffusion_ts", "Watts_2_1D", "Watts_2_2D") class Diffusion_TS(GenerativeModel): @@ -29,7 +96,21 @@ class Diffusion_TS(GenerativeModel): Uses a Transformer backbone to predict and denoise time series over discrete diffusion timesteps. Supports EMA smoothing and configurable - beta schedules. + beta schedules. Optional reconstruction-guided sampling (Algorithms 1 & 2) + via sample_reconstruction_guided(shape, static_context_vars, dynamic_context_vars, algorithm="alg1"|"alg2") + when conditional observed data x_a is provided. + + Training objective (config model.training_objective): x0, epsilon, or v. + - x0: predict clean sample; loss = L1/L2(model_out, x_clean). + - epsilon: predict noise; loss = L1/L2(pred_epsilon, noise); pred_epsilon derived from model x0. + - v: v-parameterization; loss = L1/L2(pred_v, true_v); pred_v derived from model x0. + The network always outputs x0 (fc layer); sampling uses x0 and q_posterior unchanged. + Variance-debugging reference: + (a) Sampling: model predicts x0; reverse step uses q_posterior(x0,x_t,t); + noise term = sqrt(posterior_variance). Same beta schedule in train and sample. + (b) Normalization: per-group (context) z-score + optional [zmin,zmax] scale; + denorm must use the same normalizer.inverse_transform. Run + scripts/check_diffusion_consistency.py to verify norm/denorm identity. """ def __init__(self, cfg: DictConfig): @@ -53,15 +134,39 @@ def __init__(self, cfg: DictConfig): self.context_reconstruction_loss_weight = ( cfg.model.context_reconstruction_loss_weight ) - _ = self.context_module + # Reconstruction-guided sampling (Algorithms 1 & 2) + self.recon_guide_eta = getattr(cfg.model, "recon_guide_eta", 0.1) + self.recon_guide_gamma = getattr(cfg.model, "recon_guide_gamma", 1.0) + self.recon_guide_algorithm = getattr(cfg.model, "recon_guide_algorithm", "none") + self.recon_guide_K = getattr(cfg.model, "recon_guide_K", 3) + # Verify context modules are initialized (static, dynamic, or both) + if not hasattr(self, 'static_context_module') and not hasattr(self, 'dynamic_context_module'): + raise ValueError("At least one context module (static or dynamic) must be initialized") - # linear layer for denoised output - self.fc = nn.Linear( - self.time_series_dims + self.embedding_dim, self.time_series_dims - ) - # Transformer backbone + # Context embedding dropout (training only): zeros the *entire* embedding for a random + # subset of samples (Bernoulli with prob p). This is CFG-compatible — the model learns + # to denoise both with and without context, enabling guidance-scale inference later. + self.context_embed_dropout_p = getattr(cfg.model, "context_embed_dropout", 0.0) + # Keep the nn.Dropout attribute so old checkpoints that saved it don't break on load. + self.context_embed_dropout = nn.Dropout(p=0.0) + + # Optional dual head for x̂_a / x̂_b: separate output heads for conditional vs rest of sequence + self.recon_cond_len = getattr(cfg.model, "recon_cond_len", None) + if self.recon_cond_len is not None: + cond_len = int(self.recon_cond_len) + assert 0 < cond_len < self.seq_len, "recon_cond_len must be in (0, seq_len)" + self.fc_a = nn.Linear(self.time_series_dims, self.time_series_dims) + self.fc_b = nn.Linear(self.time_series_dims, self.time_series_dims) + self.fc = None + else: + self.fc_a = None + self.fc_b = None + self.fc = nn.Linear( + self.time_series_dims, self.time_series_dims + ) + # Transformer backbone (now uses AdaLN conditioning instead of input concatenation) self.model = Transformer( - n_feat=self.time_series_dims + self.embedding_dim, + n_feat=self.time_series_dims, n_channel=self.seq_len, n_layer_enc=cfg.model.n_layer_enc, n_layer_dec=cfg.model.n_layer_dec, @@ -72,28 +177,44 @@ def __init__(self, cfg: DictConfig): max_len=self.seq_len, n_embd=cfg.model.d_model, conv_params=[cfg.model.kernel_size, cfg.model.padding_size], + cond_dim=self.embedding_dim, + has_dynamic_ctx=self.dynamic_context_module is not None, ) + self.blue_noise_power = cfg.model.blue_noise_power + self.correlated_noise = bool(getattr(cfg.model, "correlated_noise", False)) + # EMA helper will be initialized on train start - self._ema_helper: Optional[EMA] = None + self._ema: Optional[EMA] = None # set up beta schedule if cfg.model.beta_schedule == "linear": betas = linear_beta_schedule(cfg.model.n_steps) elif cfg.model.beta_schedule == "cosine": betas = cosine_beta_schedule(cfg.model.n_steps) + elif cfg.model.beta_schedule == "cosine_logsnr": + betas = cosine_beta_schedule_logsnr(cfg.model.n_steps) else: raise ValueError("Unknown beta schedule") + eps = 1e-5 alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + alphas_cumprod = torch.cumprod(alphas, dim=0).clamp(min=eps, max=1.0 - eps) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0 - eps) + self.num_timesteps = betas.shape[0] self.sampling_timesteps = default( cfg.model.sampling_timesteps, self.num_timesteps ) self.fast_sampling = self.sampling_timesteps < self.num_timesteps self.loss_type = cfg.model.loss_type + self.training_objective = getattr( + cfg.model, "training_objective", "x0" + ).lower() + if self.training_objective not in ("x0", "eps", "v"): + raise ValueError( + f"training_objective must be one of x0, eps, v; got {self.training_objective}" + ) # register buffers for diffusion coefficients self.register_buffer("betas", betas) @@ -125,18 +246,112 @@ def __init__(self, cfg: DictConfig): self.register_buffer("posterior_mean_coef1", pmc1) self.register_buffer("posterior_mean_coef2", pmc2) - lw = torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas / 100 + # Loss weighting: default (legacy), uniform, snr, or min_snr (via compute_snr_weights) + loss_weighting = getattr(cfg.model, "loss_weighting", "default") + min_snr_gamma = getattr(cfg.model, "min_snr_gamma", 5.0) + if loss_weighting == "default": + lw = torch.sqrt(alphas) * torch.sqrt(1.0 - alphas_cumprod) / betas.clamp(min=1e-8) / 100 + else: + lw = self.compute_snr_weights( + alphas_cumprod, + loss_weighting=loss_weighting, + objective=self.training_objective, + gamma=min_snr_gamma, + ) self.register_buffer("loss_weight", lw) # choose reconstruction loss if self.loss_type == "l1": self.recon_loss_fn = F.l1_loss elif self.loss_type == "l2": - self.recon_loss_fn = F.mse_loss + self.recon_loss_fn = F.mse_loss # MSE for continuous RVs else: raise ValueError("Invalid loss type") self.auxiliary_loss = nn.CrossEntropyLoss() + + # Get continuous variables from config to distinguish them in loss computation + self.continuous_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "continuous"] + self.categorical_context_vars = [k for k, v in cfg.dataset.context_vars.items() if v[0] == "categorical"] + + def _get_context_embedding( + self, static_context_vars: dict, dynamic_context_vars: dict = None, + batch_size: int = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], dict]: + """ + Get context embeddings from static and/or dynamic context modules. + + Returns: + static_emb: (B, embedding_dim) — fed into AdaLN via the Transformer's cond path. + dyn_ctx_seq: (B, T, embedding_dim) or None — fed into cross-attention in each + DecoderBlock when the dynamic module returns_sequence=True. + all_logits: Dict of auxiliary classification/regression logits (static only). + """ + all_logits = {} + static_emb = None + dyn_ctx_seq = None + + # --- Static context (categorical + continuous) → (B, emb_dim) for AdaLN --- + if self.static_context_module is not None and static_context_vars: + device = next(self.static_context_module.parameters()).device + static_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in static_context_vars.items() + } + static_emb, static_logits = self.static_context_module(static_vars) + all_logits.update(static_logits) + + # --- Dynamic context (time series) → (B, T, emb_dim) for cross-attention --- + if self.dynamic_context_module is not None and dynamic_context_vars: + device = next(self.dynamic_context_module.parameters()).device + dyn_vars = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in dynamic_context_vars.items() + } + dyn_out, dyn_logits = self.dynamic_context_module(dyn_vars) + all_logits.update(dyn_logits) + + if getattr(self.dynamic_context_module, "returns_sequence", False): + # (B, T, emb_dim) — routed to cross-attention, not AdaLN + dyn_ctx_seq = dyn_out.float() if dyn_out.is_floating_point() else dyn_out + else: + # Legacy pooled vector: fuse with static embedding via combine_mlp + if static_emb is not None and self.combine_mlp is not None: + combined = torch.cat([static_emb, dyn_out], dim=1) + static_emb = self.combine_mlp(combined) + elif static_emb is None: + static_emb = dyn_out + + if static_emb is None: + if batch_size is None: + raise ValueError("No context provided and batch_size not given — cannot create unconditional embedding") + device = next(self.parameters()).device + static_emb = torch.zeros(batch_size, self.embedding_dim, device=device) + + if static_emb.is_floating_point(): + static_emb = static_emb.float() + if self.training and self.context_embed_dropout_p > 0: + mask = torch.bernoulli( + torch.full((static_emb.shape[0], 1), 1.0 - self.context_embed_dropout_p, + device=static_emb.device, dtype=static_emb.dtype) + ) + static_emb = static_emb * mask + return static_emb, dyn_ctx_seq, all_logits + + def _decode_to_x0(self, backbone: torch.Tensor) -> torch.Tensor: + """ + Map backbone output (trend+season) to x0 prediction. Uses single fc or dual fc_a/fc_b when recon_cond_len is set. + backbone: (B, L, time_series_dims). + """ + if self.fc is not None: + out = self.fc(backbone) + else: + cond_len = self.recon_cond_len + out = torch.cat([ + self.fc_a(backbone[:, :cond_len]), + self.fc_b(backbone[:, cond_len:]), + ], dim=1) + return out def predict_noise_from_start( self, x_t: torch.Tensor, t: torch.Tensor, x0: torch.Tensor @@ -152,9 +367,10 @@ def predict_noise_from_start( Returns: Noise prediction tensor same shape as x_t. """ - return ( + out = ( self.sqrt_recip_alphas_cumprod[t].view(-1, 1, 1) * x_t - x0 ) / self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) + return out def predict_start_from_noise( self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor @@ -170,10 +386,86 @@ def predict_start_from_noise( Returns: Reconstructed x0 tensor same shape as x_t. """ - return ( + out = ( self.sqrt_recip_alphas_cumprod[t].view(-1, 1, 1) * x_t - self.sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1) * noise ) + return out + + def predict_start_from_v( + self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + """ + Reconstruct x0 from x_t and v-parameterization. + v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x0 => x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v + """ + out = ( + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x_t + - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * v + ) + return out + + def predict_noise_from_v( + self, x_t: torch.Tensor, t: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + """ + Reconstruct epsilon from x_t and v-parameterization. + v = sqrt(alpha_bar_t) * epsilon - sqrt(1 - alpha_bar_t) * x0 => epsilon = sqrt(1 - alpha_bar_t) * x_t + sqrt(alpha_bar_t) * v + """ + out = ( + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x_t + + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * v + ) + return out + + + def compute_snr_weights( + self, + alphas_cumprod: torch.Tensor, + *, + loss_weighting: str, + objective: str, + gamma: float = 5.0, + ) -> torch.Tensor: + """ + SNR-based loss weighting per timestep. + + Args: + alphas_cumprod: Cumulative product of alphas, shape (n_steps,). + loss_weighting: "uniform" | "snr" | "min_snr". + objective: "eps" | "x0" | "v" — must match training objective. + gamma: Cap for SNR when loss_weighting == "min_snr". + + Returns: + Weight tensor same shape as alphas_cumprod. + """ + snr = alphas_cumprod / (1.0 - alphas_cumprod).clamp(min=1e-8) + + if loss_weighting == "uniform": + return torch.ones_like(snr) + + if loss_weighting == "snr": + if objective == "eps": + return 1.0 / (snr + 1.0) + elif objective == "x0": + return snr / (snr + 1.0) + elif objective == "v": + return 1.0 / torch.sqrt(snr + 1.0) + else: + raise ValueError(objective) + + if loss_weighting == "min_snr": + snr_c = torch.minimum(snr, torch.full_like(snr, gamma)) + if objective == "eps": + return snr_c / snr.clamp(min=1e-8) + elif objective == "x0": + return snr_c + elif objective == "v": + return snr_c / (snr + 1.0) + else: + raise ValueError(objective) + + raise ValueError(loss_weighting) def q_posterior( self, @@ -200,7 +492,7 @@ def q_posterior( plv = self.posterior_log_variance_clipped[t].view(-1, 1, 1) return pm, pv, plv - def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, dict]: + def forward(self, x: torch.Tensor, static_context_vars: dict, dynamic_context_vars: dict = None) -> Tuple[torch.Tensor, dict]: """ Single forward pass: add noise, predict denoised output and compute reconstruction loss. @@ -212,22 +504,58 @@ def forward(self, x: torch.Tensor, context_vars: dict) -> Tuple[torch.Tensor, di rec_loss: Reconstruction loss tensor. cond_logits: Classification logits dict from context module. """ + b = x.shape[0] t = torch.randint(0, self.num_timesteps, (b,), device=self.device) - embedding, cond_classification_logits = self.context_module(context_vars) - noise = torch.randn_like(x) + embedding, dyn_ctx_seq, cond_classification_logits = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=b) + + noise = blueish_noise_like( + x, power=self.blue_noise_power, correlated=self.correlated_noise + ) x_noisy = ( self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * x + self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * noise ) - c = torch.cat( - [x_noisy, embedding.unsqueeze(1).repeat(1, self.seq_len, 1)], - dim=-1, - ) - trend, season = self.model(c, t, padding_masks=None) - x_recon = self.fc(trend + season) - rec_loss = self.recon_loss_fn(x_recon, x) - return rec_loss, cond_classification_logits + trend, season = self.model(x_noisy, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx_seq) + x_start_pred = self._decode_to_x0((trend + season).contiguous()) + # Compute loss based on training objective (network always predicts x0; we derive epsilon/v as needed) + if self.training_objective == "x0": + loss_per_elem = self.recon_loss_fn(x_start_pred, x, reduction="none") + elif self.training_objective == "eps": + pred_noise = self.predict_noise_from_start(x_noisy, t, x_start_pred) + loss_per_elem = self.recon_loss_fn(pred_noise, noise, reduction="none") + else: # v + sqrt_ab = self.sqrt_alphas_cumprod[t].view(-1, 1, 1) + sqrt_1mab = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1).clamp(min=1e-3) + pred_v = (sqrt_ab * x_noisy - x_start_pred) / sqrt_1mab + true_v = ( + self.sqrt_alphas_cumprod[t].view(-1, 1, 1) * noise + - self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1) * x + ) + loss_per_elem = self.recon_loss_fn(pred_v, true_v, reduction="none") + rec_loss = ( + self.loss_weight[t].view(-1, 1, 1) * loss_per_elem + ).mean() + fourier_loss = torch.tensor(0.0, device=self.device) + if self.use_ff: + # FFT is not generally supported in fp16 for non power-of-2 sizes on cuFFT. + # Run FFT in fp32 outside autocast. + with torch.autocast(device_type=x.device.type, enabled=False): + x1 = x_start_pred.transpose(1, 2).float() + x2 = x.transpose(1, 2).float() + + fft1 = torch.fft.fft(x1, norm="forward") + fft2 = torch.fft.fft(x2, norm="forward") + + mag1 = torch.abs(fft1) + mag2 = torch.abs(fft2) + + fourier_loss = self.recon_loss_fn(mag1, mag2, reduction="none") + + fourier_loss = (self.loss_weight[t].view(-1, 1, 1) * fourier_loss).mean() + + + return rec_loss, cond_classification_logits, fourier_loss.mean() def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: """ @@ -240,15 +568,19 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: Returns: total_loss: Scalar training loss. """ - ts_batch, cond_batch = batch - rec_loss, cond_class_logits = self(ts_batch, cond_batch) - cond_loss = 0.0 + ts_batch, static_context_batch, dynamic_context_batch = batch + rec_loss, cond_class_logits, fourier_loss = self(ts_batch, static_context_batch, dynamic_context_batch) - for var_name, logits in cond_class_logits.items(): - labels = cond_batch[var_name] - cond_loss += self.auxiliary_loss(logits, labels) + cond_loss = 0.0 + for var_name, outputs in cond_class_logits.items(): + labels = static_context_batch[var_name] + if var_name in self.continuous_context_vars: + loss = F.mse_loss(outputs, labels.float()) + elif var_name in self.categorical_context_vars: + loss = self.auxiliary_loss(outputs, labels) + cond_loss += loss.mean() - h, _ = self.context_module(cond_batch) + h, _, _ = self._get_context_embedding(static_context_batch, dynamic_context_batch, batch_size=ts_batch.shape[0]) tc_term = ( self.cfg.model.tc_loss_weight * total_correlation(h) if self.cfg.model.tc_loss_weight > 0.0 @@ -256,16 +588,22 @@ def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor: ) total_loss = ( - rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term + rec_loss + self.context_reconstruction_loss_weight * cond_loss + tc_term + fourier_loss * self.ff_weight ) + + # Skip this batch entirely if loss is bad — avoids corrupting weights before EMA can help + self.log_dict( { - "train_loss": total_loss, - "rec_loss": rec_loss, - "cond_loss": cond_loss, + "train_loss": total_loss.item(), + "rec_loss": rec_loss.item(), + "cond_loss": cond_loss.item() if isinstance(cond_loss, torch.Tensor) else float(cond_loss), "tc_loss": tc_term, + "fourier_loss": fourier_loss.item(), }, prog_bar=True, + sync_dist=True, + on_epoch=True, ) return total_loss @@ -282,30 +620,99 @@ def configure_optimizers(self) -> dict: scheduler = ReduceLROnPlateau(optimizer, **self.cfg.trainer.lr_scheduler_params) return { "optimizer": optimizer, - "lr_scheduler": scheduler, - "monitor": "train_loss", + "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"}, } def on_train_start(self) -> None: """ Initialize EMA helper at start of training. """ - self._ema = EMA( - self.model, - beta=self.cfg.model.ema_decay, - update_every=self.cfg.model.ema_update_interval, - ) + if self._ema is None: + object.__setattr__(self, '_ema', EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + )) + # EMA is not a registered submodule so PL won't move it automatically + self._ema.to(self.device) + + def on_train_epoch_start(self) -> None: + """ + Apply linear LR warmup for the first warmup_epochs epochs. + After warmup, ReduceLROnPlateau takes over untouched. + """ + warmup_epochs = self.cfg.trainer.get("warmup_epochs", 0) + if warmup_epochs <= 0: + return + epoch = self.current_epoch + if epoch >= warmup_epochs: + return + target_lr = self.cfg.trainer.base_lr + warmup_lr = target_lr * max(0.01, epoch / warmup_epochs) + for opt in self.trainer.optimizers: + for pg in opt.param_groups: + pg["lr"] = warmup_lr + + def on_before_optimizer_step(self, optimizer) -> None: + """Log global gradient norm each step for observability.""" + total_norm_sq = 0.0 + for p in self.parameters(): + if p.grad is not None: + total_norm_sq += p.grad.detach().float().norm(2).item() ** 2 + grad_norm = total_norm_sq ** 0.5 + self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, prog_bar=False) def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: """ Apply EMA update after each batch end. + Skip if the step was skipped due to NaN loss (outputs is None). + """ + if outputs is None: + return + if hasattr(self, '_ema') and self._ema: + self._ema.update() + + + def load_state_dict(self, state_dict, strict=True): + # Strip legacy _ema.* keys — EMA is restored separately via on_load_checkpoint. + # Old checkpoints have these because _ema was previously a registered submodule. + filtered = {k: v for k, v in state_dict.items() if not k.startswith('_ema.')} + return super().load_state_dict(filtered, strict=strict) + + def on_load_checkpoint(self, checkpoint: dict) -> None: + if 'ema_state_dict' in checkpoint and False: + if self._ema is None: + object.__setattr__(self, '_ema', EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + )) + self._ema.ema_model.load_state_dict(checkpoint['ema_state_dict']) + print(f"[EMA] Restored EMA weights from checkpoint") + else: + print(f"[EMA] No EMA weights in checkpoint, initializing fresh") + + def on_save_checkpoint(self, checkpoint: dict) -> None: + if self._ema is not None: + checkpoint['ema_state_dict'] = self._ema.ema_model.state_dict() + + + def _predict_x0_from_xt_with_grad( + self, x_t: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Predict x0 from x_t with gradients enabled (for reconstruction-guided sampling). + Returns x_start of shape (B, L, C). Call with x_t.requires_grad_(True). """ - if self._ema_helper: - self._ema_helper.update() + trend, season = self.model(x_t, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx) + x_start = self._decode_to_x0((trend + season).contiguous()) + return x_start @torch.no_grad() def model_predictions( - self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor + self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Predict both noise and clean sample from current x. @@ -314,15 +721,27 @@ def model_predictions( pred_noise: predicted noise tensor. x_start: predicted clean sample tensor. """ - c = torch.cat([x, embedding.unsqueeze(1).repeat(1, self.seq_len, 1)], dim=-1) - trend, season = self.model(c, t, padding_masks=None) - x_start = self.fc(trend + season) + trend, season = self.model(x, t, padding_masks=None, cond=embedding, dyn_ctx=dyn_ctx) + x_start = self._decode_to_x0((trend + season).contiguous()) pred_noise = self.predict_noise_from_start(x, t, x_start) return pred_noise, x_start + @staticmethod + def _replace_conditional( + x_a: torch.Tensor, x_prev: torch.Tensor, cond_len: int + ) -> torch.Tensor: + """ + Replace the first cond_len time steps of x_prev with conditional data x_a. + x_a: (B, cond_len, C), x_prev: (B, L, C). Returns (B, L, C). + """ + out = x_prev.clone() + out[:, :cond_len] = x_a + return out + @torch.no_grad() def p_mean_variance( - self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor + self, x: torch.Tensor, t: torch.Tensor, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute mean and variance for p(x_{t-1} | x_t). @@ -330,49 +749,227 @@ def p_mean_variance( Returns: pm, pv, plv, x_start: posterior parameters and predicted x0. """ - pred_noise, x_start = self.model_predictions(x, t, embedding) + pred_noise, x_start = self.model_predictions(x, t, embedding, dyn_ctx=dyn_ctx) pm, pv, plv = self.q_posterior(x_start, x, t) return pm, pv, plv, x_start @torch.no_grad() def p_sample( - self, x: torch.Tensor, t: int, embedding: torch.Tensor + self, x: torch.Tensor, t: int, embedding: torch.Tensor, + dyn_ctx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Sample x_{t-1} from x_t using posterior distribution. """ bt = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long) - pm, pv, plv, _ = self.p_mean_variance(x, bt, embedding) - noise = torch.randn_like(x) if t > 0 else 0 - return pm + (0.5 * plv).exp() * noise + pm, pv, plv, _ = self.p_mean_variance(x, bt, embedding, dyn_ctx=dyn_ctx) + noise = ( + blueish_noise_like(x, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) + out = pm + (0.5 * plv).exp() * noise + return out + + def _reconstruction_guided_step_alg1( + self, + x_t: torch.Tensor, + t: int, + embedding: torch.Tensor, + x_a: torch.Tensor, + cond_len: int, + eta: float, + gamma: float, + dyn_ctx: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step of Algorithm 1: predict x̂_0, compute L_1 + γ*L_2, then + x̃_0 = x̂_0 + η ∇_{x_t}(L_1 + γ*L_2); sample x_{t-1} ~ N(μ(x̃_0, x_t), Σ) and Replace(x_a, x_{t-1}). + """ + bt = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long) + x_t = x_t.detach().requires_grad_(True) + + x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding, dyn_ctx=dyn_ctx) + x_hat_a = x_start[:, :cond_len] + L_1 = (x_a - x_hat_a).pow(2).mean() + + pm, pv, plv = self.q_posterior(x_start, x_t, bt) + noise = ( + blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) + x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() + L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() + + loss = L_1 + gamma * L_2 + loss.backward() + with torch.no_grad(): + x_tilde_0 = x_start.detach() + eta * x_t.grad + pm_final, pv_final, plv_final = self.q_posterior(x_tilde_0, x_t.detach(), bt) + noise_final = ( + _randn_like_correlated(x_t, self.correlated_noise) + if t > 0 + else torch.zeros_like(x_t, device=x_t.device) + ) + x_prev = pm_final + (0.5 * plv_final).exp() * noise_final + x_prev = self._replace_conditional(x_a, x_prev, cond_len) + return x_prev + + def _reconstruction_guided_step_alg2( + self, + x_t: torch.Tensor, + t: int, + embedding: torch.Tensor, + x_a: torch.Tensor, + cond_len: int, + eta: float, + gamma: float, + K: int, + dyn_ctx: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step of Algorithm 2: K inner gradient updates on x_t, then one final sample and Replace. + """ + bt = torch.full((x_t.shape[0],), t, device=self.device, dtype=torch.long) + embedding_detach = embedding.detach() + dyn_ctx_detach = dyn_ctx.detach() if dyn_ctx is not None else None + x_t = x_t.detach().clone() + + for _ in range(K): + x_t = x_t.requires_grad_(True) + x_start = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach, dyn_ctx=dyn_ctx_detach) + x_hat_a = x_start[:, :cond_len] + L_1 = (x_a - x_hat_a).pow(2).mean() + pm, pv, plv = self.q_posterior(x_start, x_t, bt) + noise = ( + blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) + x_prev_initial = (pm + (0.5 * plv).exp() * noise).detach() + L_2 = ((x_prev_initial - pm).pow(2) / pv.clamp(min=1e-8)).mean() + loss = L_1 + gamma * L_2 + loss.backward() + with torch.no_grad(): + x_t = x_t + eta * x_t.grad + x_t = x_t.detach() + + with torch.no_grad(): + x_start_final = self._predict_x0_from_xt_with_grad(x_t, bt, embedding_detach, dyn_ctx=dyn_ctx_detach) + pm_final, pv_final, plv_final = self.q_posterior(x_start_final, x_t, bt) + noise_final = ( + blueish_noise_like(x_t, power=self.blue_noise_power, correlated=self.correlated_noise) + if t > 0 + else 0 + ) + x_prev = pm_final + (0.5 * plv_final).exp() * noise_final + x_prev = self._replace_conditional(x_a, x_prev, cond_len) + return x_prev + + def sample_reconstruction_guided( + self, + x_a: torch.Tensor, + shape: Tuple[int, int, int], + static_context_vars: dict, + dynamic_context_vars: dict = None, + algorithm: str = "alg1", + ) -> torch.Tensor: + """ + Full reverse-pass sampling with reconstruction guidance (Algorithm 1 or 2). + + Args: + shape: (batch_size, seq_len, time_series_dims). + static_context_vars: static context conditioning dict. + dynamic_context_vars: dynamic context conditioning dict. + x_a: Conditional (observed) data, shape (B, cond_len, C). First cond_len + time steps to reconstruct; model output is split as x̂_0 = [x̂_a, x̂_b]. + algorithm: "alg1" (one gradient step per t) or "alg2" (K inner steps per t). + + Returns: + Generated samples (B, L, C) with first cond_len steps equal to x_a. + + Architecture / config notes (optional improvements): + - x̂_0 split: Currently x̂_0 is the single model output (B,L,C); we split by time + so x̂_a = x̂_0[:, :cond_len], x̂_b = x̂_0[:, cond_len:]. Optionally use two output + heads (one for x̂_a, one for x̂_b) if you want different capacities. + - Context dropout: Consider adding dropout on the context embedding (or on + context encoder outputs) during training to improve robustness of + reconstruction-guided sampling at test time. + - Config: recon_guide_eta (gradient scale), recon_guide_gamma (L1 vs L2 trade-off), + recon_guide_K (int or list of length num_timesteps for per-t inner steps in alg2). + """ + cond_len = x_a.shape[1] + assert cond_len < shape[1], "x_a length must be < seq_len" + assert x_a.shape[0] == shape[0] and x_a.shape[2] == shape[2] + eta = self.recon_guide_eta + gamma = self.recon_guide_gamma + x = _randn_shape_correlated( + shape, self.device, torch.float32, self.correlated_noise + ) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=shape[0]) + x_a = x_a.to(self.device) + + for t in reversed(range(self.num_timesteps)): + K_t = ( + self.recon_guide_K[t] if t < len(self.recon_guide_K) else self.recon_guide_K[-1] + if isinstance(self.recon_guide_K, (list, tuple)) + else self.recon_guide_K + ) + if algorithm == "alg1": + x = self._reconstruction_guided_step_alg1( + x, t, embedding, x_a, cond_len, eta, gamma, dyn_ctx=dyn_ctx_seq + ) + else: + x = self._reconstruction_guided_step_alg2( + x, t, embedding, x_a, cond_len, eta, gamma, K_t, dyn_ctx=dyn_ctx_seq + ) + return x @torch.no_grad() - def sample(self, shape: Tuple[int, int, int], context_vars: dict) -> torch.Tensor: + def sample(self, shape: Tuple[int, int, int], static_context_vars: dict, dynamic_context_vars: dict = None) -> torch.Tensor: """ Full reverse-pass sampling over all timesteps. Args: shape: (batch_size, seq_len, dims) - context_vars: context conditioning dict + static_context_vars: static context conditioning dict + dynamic_context_vars: dynamic context conditioning dict Returns: Generated samples tensor. """ - x = torch.randn(shape, device=self.device) - embedding, _ = self.context_module(context_vars) + x = _randn_shape_correlated( + shape, self.device, torch.float32, self.correlated_noise + ) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=shape[0]) for t in reversed(range(self.num_timesteps)): - x = self.p_sample(x, t, embedding) + x = self.p_sample(x, t, embedding, dyn_ctx=dyn_ctx_seq) return x @torch.no_grad() def fast_sample( - self, shape: Tuple[int, int, int], context_vars: dict + self, shape: Tuple[int, int, int], static_context_vars: dict, + dynamic_context_vars: dict = None, cfg_scale: float = 1.0, ) -> torch.Tensor: """ - Faster sampling using a reduced number of timesteps. + DDIM sampling with optional classifier-free guidance. + + cfg_scale=1.0 → standard conditional sampling (no guidance). + cfg_scale>1.0 → CFG: runs both conditional and unconditional passes each step + and blends pred_noise = uncond + scale*(cond - uncond). + Requires the model to have been trained with context_embed_dropout > 0. """ - x = torch.randn(shape, device=self.device) - embedding, _ = self.context_module(context_vars) + x = _randn_shape_correlated( + shape, self.device, torch.float32, self.correlated_noise + ) + embedding, dyn_ctx_seq, _ = self._get_context_embedding(static_context_vars, dynamic_context_vars, batch_size=shape[0]) + + use_cfg = cfg_scale > 1.0 + if use_cfg: + uncond_emb = torch.zeros_like(embedding) + uncond_dyn = torch.zeros_like(dyn_ctx_seq) if dyn_ctx_seq is not None else None + times = torch.linspace( -1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1 ) @@ -380,7 +977,13 @@ def fast_sample( pairs = list(zip(times[:-1], times[1:])) for time, time_next in pairs: bt = torch.full((x.shape[0],), time, device=self.device, dtype=torch.long) - pred_noise, x_start = self.model_predictions(x, bt, embedding) + if use_cfg: + pred_noise_u, _ = self.model_predictions(x, bt, uncond_emb, dyn_ctx=uncond_dyn) + pred_noise_c, _ = self.model_predictions(x, bt, embedding, dyn_ctx=dyn_ctx_seq) + pred_noise = pred_noise_u + cfg_scale * (pred_noise_c - pred_noise_u) + x_start = self.predict_start_from_noise(x, bt, pred_noise) + else: + pred_noise, x_start = self.model_predictions(x, bt, embedding, dyn_ctx=dyn_ctx_seq) if time_next < 0: x = x_start continue @@ -391,91 +994,151 @@ def fast_sample( * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() ) c = (1 - alpha_next - sigma**2).sqrt() - noise = torch.randn_like(x) + noise = _randn_like_correlated(x, self.correlated_noise) x = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise return x - def generate(self, context_vars: dict) -> torch.Tensor: + @contextmanager + def ema_scope(self): + if hasattr(self, "_ema") and self._ema and getattr(self.cfg.model, "use_ema_sampling", False): + self._ema.store(self.model.parameters()) + self._ema.copy_to(self.model.parameters()) + try: + yield + finally: + self._ema.restore(self.model.parameters()) + else: + yield + + def generate(self, static_context_vars: dict, dynamic_context_vars: dict = None, n: int = None) -> torch.Tensor: """ Public entry to generate conditioned samples in batches. Args: - context_vars: dict of context tensors for each sample. + static_context_vars: dict of context tensors for each sample. Pass {} for unconditional generation. + dynamic_context_vars: dict of dynamic context tensors for each sample. + n: number of samples to generate; required when static_context_vars is empty. Returns: Complete generated tensor of shape (N, seq_len, dims). """ bs = self.cfg.model.sampling_batch_size - total = len(next(iter(context_vars.values()))) + if static_context_vars: + total = len(next(iter(static_context_vars.values()))) + elif n is not None: + total = n + else: + raise ValueError("Pass n= when generating unconditionally (empty static_context_vars)") generated_samples = [] - for start_idx in tqdm( - range(0, total, bs), - unit="seq", - desc="[CENTS] Generating samples", - leave=True, - ): - end_idx = min(start_idx + bs, total) - batch_context_vars = { - var_name: var_tensor[start_idx:end_idx] - for var_name, var_tensor in context_vars.items() - } - current_bs = end_idx - start_idx - shape = (current_bs, self.seq_len, self.time_series_dims) + with self.ema_scope(): + for start_idx in tqdm( + range(0, total, bs), + unit="seq", + desc="[CENTS] Generating samples", + leave=True, + ): + end_idx = min(start_idx + bs, total) + batch_static_context_vars = { + var_name: var_tensor[start_idx:end_idx] + for var_name, var_tensor in static_context_vars.items() + } + batch_dynamic_context_vars = { + var_name: var_tensor[start_idx:end_idx] + for var_name, var_tensor in (dynamic_context_vars or {}).items() + } - with torch.no_grad(): - if getattr(self.cfg.model, "use_ema_sampling", False) and hasattr( - self, "_ema_helper" - ): - samples = self._ema_helper.ema_model._generate( - shape, batch_context_vars - ) - else: - samples = ( - self.fast_sample(shape, batch_context_vars) - if self.fast_sampling - else self.sample(shape, batch_context_vars) - ) - - generated_samples.append(samples) + current_bs = end_idx - start_idx + shape = (current_bs, self.seq_len, self.time_series_dims) - return torch.cat(generated_samples, dim=0) + with torch.no_grad(): + cfg_scale = getattr(self, '_cfg_scale', 1.0) + if self.fast_sampling: + samples = self.fast_sample(shape, batch_static_context_vars, batch_dynamic_context_vars, cfg_scale=cfg_scale) + else: + samples = self.sample(shape, batch_static_context_vars, batch_dynamic_context_vars) -class EMA(nn.Module): - """ - Exponential Moving Average (EMA) helper for model parameters. - - Maintains a shadow copy of the model weights that are updated - via EMA every `update_every` steps. - """ + generated_samples.append(samples.cpu()) - def __init__(self, model: nn.Module, beta: float, update_every: int): + return torch.cat(generated_samples, dim=0) + + def _ensure_ema_helper(self) -> None: """ - Args: - model: Base model to copy for EMA tracking. - beta: EMA decay rate (0 < beta < 1). - update_every: Frequency (in steps) to apply EMA update. + Ensure EMA helper is initialized if needed for inference. """ + if not hasattr(self, '_ema') or self._ema is None: + print("Initializing EMA helper for inference...") + object.__setattr__(self, '_ema', EMA( + self.model, + beta=self.cfg.model.ema_decay, + update_every=self.cfg.model.ema_update_interval, + )) + def stratified_timesteps(self, batch_size: int, num_timesteps: int, k_bins: int, device=None) -> torch.Tensor: + device = device or "cpu" + k_bins = min(k_bins, batch_size) + edges = torch.linspace(0, num_timesteps, steps=k_bins + 1, device=device) + + # sample one t per bin + u = torch.rand(k_bins, device=device) + t_bins = (edges[:-1] + u * (edges[1:] - edges[:-1])).floor().clamp_(0, num_timesteps - 1).long() + + # repeat to fill batch, then shuffle + reps = (batch_size + k_bins - 1) // k_bins + t = t_bins.repeat(reps)[:batch_size] + t = t[torch.randperm(batch_size, device=device)] + return t + + +class EMA(nn.Module): + def __init__(self, model, beta, update_every): super().__init__() - self.model = copy.deepcopy(model) - self.ema_model = self.model.eval() self.beta = beta self.update_every = update_every self.step = 0 - for param in self.ema_model.parameters(): - param.requires_grad = False + self.ema_model = copy.deepcopy(model) + self.ema_model.eval() + self.ema_model.requires_grad_(False) + + # Store as plain python attribute, not nn.Module attribute + # This prevents it being registered as a submodule and saved in state_dict + object.__setattr__(self, '_source_model', model) + + self.collected_params = [] - def update(self) -> None: - """ - Perform an EMA update of the shadow model parameters. - Called typically at end of each training batch. - """ + def update(self): self.step += 1 if self.step % self.update_every != 0: return with torch.no_grad(): - for ema_p, model_p in zip( - self.ema_model.parameters(), self.model.parameters() + for ema_p, src_p in zip( + self.ema_model.parameters(), + self._source_model.parameters() ): - ema_p.data.mul_(self.beta).add_(model_p.data, alpha=1.0 - self.beta) + ema_p.data.mul_(self.beta).add_(src_p.data, alpha=1.0 - self.beta) + + def store(self, parameters): + """ + Save the current parameters (of the live model) to a temporary list. + Used by the context manager to back up weights before swapping. + """ + self.collected_params = [param.clone().cpu() for param in parameters] + + def restore(self, parameters): + """ + Restore the saved parameters back to the live model. + """ + if not self.collected_params: + raise RuntimeError("No parameters stored to restore.") + + for param, saved_param in zip(parameters, self.collected_params): + param.data.copy_(saved_param.data.to(param.device)) + + self.collected_params = [] # Clear memory + + def copy_to(self, parameters): + """ + Copy the EMA shadow weights INTO the live model parameters. + """ + for param, ema_param in zip(parameters, self.ema_model.parameters()): + param.data.copy_(ema_param.data.to(param.device)) \ No newline at end of file diff --git a/cents/models/model_utils.py b/cents/models/model_utils.py index 304cca4..be4cf73 100644 --- a/cents/models/model_utils.py +++ b/cents/models/model_utils.py @@ -9,6 +9,7 @@ """ import math +from typing import Optional import numpy as np import torch @@ -17,6 +18,7 @@ from torch import nn + def linear_beta_schedule(timesteps: int) -> torch.Tensor: """ Create a linear schedule of betas for diffusion noise levels. @@ -52,6 +54,20 @@ def cosine_beta_schedule(timesteps: int, s: float = 0.004) -> torch.Tensor: return torch.clip(betas, 0, 0.999) +def cosine_beta_schedule_logsnr( + timesteps: int, + logsnr_min: float = -8.0, + logsnr_max: float = 15.0, + eps: float = 1e-5, +) -> torch.Tensor: + t = torch.linspace(0, 1, timesteps + 1, dtype=torch.float32) + logsnr = logsnr_max + 0.5 * (logsnr_min - logsnr_max) * (1 - torch.cos(math.pi * t)) + + alpha_bar = torch.sigmoid(logsnr).clamp(eps, 1.0 - eps) + + betas = 1 - (alpha_bar[1:] / alpha_bar[:-1]) + return torch.clip(betas, 0, 0.999) + def exists(x): return x is not None @@ -134,7 +150,6 @@ def forward(self, x): x: [batch size, sequence length, embed dim] output: [batch size, sequence length, embed dim] """ - # print(x.shape) x = x + self.pe return self.dropout(x) @@ -213,14 +228,18 @@ def forward(self, x): class Conv_MLP(nn.Module): def __init__(self, in_dim, out_dim, resid_pdrop=0.0): super().__init__() - self.sequential = nn.Sequential( - Transpose(shape=(1, 2)), - nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1), - nn.Dropout(p=resid_pdrop), - ) + self.conv = nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1) + if self.conv.weight.requires_grad: + self.conv.weight.register_hook(lambda grad: grad.contiguous()) + self.drop = nn.Dropout(p=resid_pdrop) def forward(self, x): - return self.sequential(x).transpose(1, 2) + x = x.transpose(1, 2).contiguous() # (B, C, T) contiguous + x = self.conv(x) + x = self.drop(x) + out = x.transpose(1, 2).contiguous() # back to (B, T, C), contiguous + return out + class Transformer_MLP(nn.Module): @@ -265,15 +284,16 @@ def forward(self, x): class AdaLayerNorm(nn.Module): def __init__(self, n_embd): super().__init__() - self.emb = SinusoidalPosEmb(n_embd) + # self.emb = SinusoidalPosEmb(n_embd) self.silu = nn.SiLU() self.linear = nn.Linear(n_embd, n_embd * 2) + + nn.init.zeros_(self.linear.bias) + nn.init.zeros_(self.linear.weight) self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False) - def forward(self, x, timestep, label_emb=None): - emb = self.emb(timestep) - if label_emb is not None: - emb = emb + label_emb + def forward(self, x, emb): + # emb: (B, n_embd) - Pre-computed and combined (time + label) emb = self.linear(self.silu(emb)).unsqueeze(1) scale, shift = torch.chunk(emb, 2, dim=2) x = self.layernorm(x) * (1 + scale) + shift @@ -325,6 +345,8 @@ def __init__(self, in_dim, out_dim, in_feat, out_feat, act): def forward(self, input): b, c, h = input.shape + if not input.is_contiguous(): + input = input.contiguous() x = self.trend(input).transpose(1, 2) trend_vals = torch.matmul(x.transpose(1, 2), self.poly_space.to(x.device)) trend_vals = trend_vals.transpose(1, 2) @@ -553,16 +575,10 @@ def __init__( activate="GELU", ): super().__init__() - self.ln1 = AdaLayerNorm(n_embd) - self.ln2 = nn.LayerNorm(n_embd) - self.attn = FullAttention( - n_embd=n_embd, - n_head=n_head, - attn_pdrop=attn_pdrop, - resid_pdrop=resid_pdrop, - ) - + self.ln2 = AdaLayerNorm(n_embd) + self.attn = FullAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + assert activate in ["GELU", "GELU2"] act = nn.GELU() if activate == "GELU" else GELU2() @@ -573,10 +589,11 @@ def __init__( nn.Dropout(resid_pdrop), ) - def forward(self, x, timestep, mask=None, label_emb=None): - a, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask) + def forward(self, x, cond_emb, mask=None): + # cond_emb is the combined time+label embedding + a, att = self.attn(self.ln1(x, cond_emb), mask=mask) x = x + a - x = x + self.mlp(self.ln2(x)) # only one really use encoder_output + x = x + self.mlp(self.ln2(x, cond_emb)) return x, att @@ -607,10 +624,10 @@ def __init__( ] ) - def forward(self, input, t, padding_masks=None, label_emb=None): + def forward(self, input, cond_emb, padding_masks=None): x = input - for block_idx in range(len(self.blocks)): - x, _ = self.blocks[block_idx](x, t, mask=padding_masks, label_emb=label_emb) + for block in self.blocks: + x, _ = block(x, cond_emb, mask=padding_masks) return x @@ -628,11 +645,12 @@ def __init__( mlp_hidden_times=4, activate="GELU", condition_dim=1024, + has_dynamic_ctx=False, ): super().__init__() self.ln1 = AdaLayerNorm(n_embd) - self.ln2 = nn.LayerNorm(n_embd) + self.ln2 = AdaLayerNorm(n_embd) self.attn1 = FullAttention( n_embd=n_embd, @@ -647,6 +665,20 @@ def __init__( attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, ) + # Cross-attention for temporally-aligned dynamic context — only created when dynamic + # context is actually present (has_dynamic_ctx=True). Saves ~495K params when absent. + if has_dynamic_ctx: + self.ln_dyn = AdaLayerNorm(n_embd) + self.attn_dyn = CrossAttention( + n_embd=n_embd, + condition_embd=condition_dim, + n_head=n_head, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) + else: + self.ln_dyn = None + self.attn_dyn = None self.ln1_1 = AdaLayerNorm(n_embd) @@ -654,9 +686,7 @@ def __init__( act = nn.GELU() if activate == "GELU" else GELU2() self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act) - # self.decomp = MovingBlock(n_channel) self.seasonal = FourierLayer(d_model=n_embd) - # self.seasonal = SeasonBlock(n_channel, n_channel) self.mlp = nn.Sequential( nn.Linear(n_embd, mlp_hidden_times * n_embd), @@ -668,14 +698,27 @@ def __init__( self.proj = nn.Conv1d(n_channel, n_channel * 2, 1) self.linear = nn.Linear(n_embd, n_feat) - def forward(self, x, encoder_output, timestep, mask=None, label_emb=None): - a, att = self.attn1(self.ln1(x, timestep, label_emb), mask=mask) + def forward(self, x, encoder_output, cond_emb, dyn_ctx=None, mask=None): + a, att = self.attn1(self.ln1(x, cond_emb), mask=mask) x = x + a - a, att = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask) + a, att = self.attn2(self.ln1_1(x, cond_emb), encoder_output, mask=mask) x = x + a - x1, x2 = self.proj(x).chunk(2, dim=1) + if dyn_ctx is not None and self.attn_dyn is not None: + a, _ = self.attn_dyn(self.ln_dyn(x, cond_emb), dyn_ctx) + x = x + a + + # FIX: chunk() returns views that are often non-contiguous. + # Since self.proj and self.trend use Conv1d, this causes the DDP stride mismatch. + x_proj = self.proj(x) + x1, x2 = x_proj.chunk(2, dim=1) + + # Make contiguous before passing to specialized blocks + x1 = x1.contiguous() + x2 = x2.contiguous() + trend, season = self.trend(x1), self.seasonal(x2) - x = x + self.mlp(self.ln2(x)) + + x = x + self.mlp(self.ln2(x, cond_emb)) m = torch.mean(x, dim=1, keepdim=True) return x - m, self.linear(m), trend, season @@ -693,6 +736,7 @@ def __init__( mlp_hidden_times=4, block_activate="GELU", condition_dim=512, + has_dynamic_ctx=False, ): super().__init__() self.d_model = n_embd @@ -709,20 +753,24 @@ def __init__( mlp_hidden_times=mlp_hidden_times, activate=block_activate, condition_dim=condition_dim, + has_dynamic_ctx=has_dynamic_ctx, ) - for _ in range(n_layer) + for _ in range(n_layer) ] ) - def forward(self, x, t, enc, padding_masks=None, label_emb=None): + def forward(self, x, cond_emb, enc, dyn_ctx=None, padding_masks=None): b, c, _ = x.shape - # att_weights = [] mean = [] - season = torch.zeros((b, c, self.d_model), device=x.device) - trend = torch.zeros((b, c, self.n_feat), device=x.device) - for block_idx in range(len(self.blocks)): - x, residual_mean, residual_trend, residual_season = self.blocks[block_idx]( - x, enc, t, mask=padding_masks, label_emb=label_emb + # Initialize accumulating tensors on the correct device + season = torch.zeros((b, c, x.shape[-1]), device=x.device) + # Note: FourierLayer returns same dim as input x (n_embd) + + trend = torch.zeros((b, c, self.blocks[0].linear.out_features), device=x.device) + + for block in self.blocks: + x, residual_mean, residual_trend, residual_season = block( + x, enc, cond_emb, dyn_ctx=dyn_ctx, mask=padding_masks ) season += residual_season trend += residual_trend @@ -747,12 +795,32 @@ def __init__( block_activate="GELU", max_len=2048, conv_params=None, + cond_dim=None, + has_dynamic_ctx=False, **kwargs ): super().__init__() self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop) self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop) + self.cond_dim = cond_dim + if cond_dim is not None: + # Map static context embedding (B, cond_dim) -> (B, n_embd) + self.cond_proj = nn.Linear(cond_dim, n_embd) + # Map dynamic context sequence (B, T, cond_dim) -> (B, T, n_embd); only when needed + self.dyn_ctx_proj = nn.Linear(cond_dim, n_embd) if has_dynamic_ctx else None + else: + self.cond_proj = None + self.dyn_ctx_proj = None + + self.time_emb = SinusoidalPosEmb(n_embd) + + self.cond_mix_mlp = nn.Sequential( + nn.Linear(n_embd * 2, n_embd), + nn.SiLU(), + nn.Linear(n_embd, n_embd), + ) + if conv_params is None or conv_params[0] is None: if n_feat < 32 and n_channel < 64: kernel_size, padding = 1, 0 @@ -804,34 +872,57 @@ def __init__( mlp_hidden_times, block_activate, condition_dim=n_embd, + has_dynamic_ctx=has_dynamic_ctx, ) self.pos_dec = LearnablePositionalEncoding( n_embd, dropout=resid_pdrop, max_len=max_len ) - def forward(self, input, t, padding_masks=None, return_res=False): + def forward(self, input, t, padding_masks=None, return_res=False, cond=None, dyn_ctx=None): + # cond: (B, cond_dim) static context or None + # dyn_ctx: (B, T, cond_dim) dynamic context sequence or None + # Ensure float32 so conv/linear (float32 params) never see double input + if input.is_floating_point(): + input = input.float() + if cond is not None and cond.is_floating_point(): + cond = cond.float() + if dyn_ctx is not None and dyn_ctx.is_floating_point(): + dyn_ctx = dyn_ctx.float() + t_emb = self.time_emb(t) + + label_emb = None + if (cond is not None) and (self.cond_proj is not None): + label_emb = self.cond_proj(cond) # (B, n_embd) + total_cond_emb = self.cond_mix_mlp(torch.concat([t_emb, label_emb], dim=1)) + else: + total_cond_emb = t_emb + + # Project dynamic context sequence to n_embd for cross-attention + dyn_ctx_emb = None + if dyn_ctx is not None and self.dyn_ctx_proj is not None: + dyn_ctx_emb = self.dyn_ctx_proj(dyn_ctx) # (B, T, n_embd) + emb = self.emb(input) inp_enc = self.pos_enc(emb) - enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks) + + enc_cond = self.encoder(inp_enc, total_cond_emb, padding_masks=padding_masks) inp_dec = self.pos_dec(emb) output, mean, trend, season = self.decoder( - inp_dec, t, enc_cond, padding_masks=padding_masks + inp_dec, total_cond_emb, enc_cond, dyn_ctx=dyn_ctx_emb, padding_masks=padding_masks ) res = self.inverse(output) - res_m = torch.mean(res, dim=1, keepdim=True) - season_error = ( - self.combine_s(season.transpose(1, 2)).transpose(1, 2) + res - res_m - ) - trend = self.combine_m(mean) + res_m + trend + + res_m = torch.mean(res, dim=1, keepdim=True).contiguous() + combine_m_out = self.combine_m(mean).contiguous() + combine_s_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2).contiguous() + season_error = (combine_s_out + res - res_m).contiguous() + trend = (combine_m_out + res_m + trend).contiguous() if return_res: - return ( - trend, - self.combine_s(season.transpose(1, 2)).transpose(1, 2), - res - res_m, - ) + out_res = res - res_m + return trend, combine_s_out, out_res return trend, season_error diff --git a/cents/models/normalizer.py b/cents/models/normalizer.py index 39afd4b..6f3b5e1 100644 --- a/cents/models/normalizer.py +++ b/cents/models/normalizer.py @@ -6,18 +6,21 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from omegaconf import ListConfig + from cents.datasets.utils import split_timeseries from cents.models.base import NormalizerModel -from cents.models.context import ContextModule +from cents.models.context import MLPContextModule, SepMLPContextModule, DynamicContextModule_Transformer # Import to trigger registration +from cents.models.context_registry import get_context_module_cls +from cents.models.stats_head_registry import register_stats_head, get_stats_head_cls from cents.models.registry import register_model +from cents.utils.utils import get_context_config -class _StatsHead(nn.Module): - """ - Head module predicting summary statistics (mean, std, and optionally min/max z-scores) from context embedding. - """ - +@register_stats_head("default", "mlp") +class MLPStatsHead(nn.Module): def __init__( self, embedding_dim: int, @@ -26,16 +29,6 @@ def __init__( do_scale: bool, n_layers: int = 3, ): - """ - Initializes the statistics head network. - - Args: - embedding_dim: Dimensionality of the input context embedding. - hidden_dim: Number of units in each hidden layer. - time_series_dims: Number of dimensions in the original time series. - do_scale: Whether to predict scaling min/max parameters. - n_layers: Number of hidden linear layers before the output. - """ super().__init__() self.time_series_dims = time_series_dims self.do_scale = do_scale @@ -48,36 +41,63 @@ def __init__( in_dim = hidden_dim layers.append(nn.Linear(in_dim, out_dim)) self.net = nn.Sequential(*layers) + + self._initialize_output_layer() + + def _initialize_output_layer(self, init_sigma: float = 1.0): + D = self.time_series_dims + K = 4 if self.do_scale else 2 + out_layer = self.net[-1] - def forward(self, z: torch.Tensor): - """ - Forward pass to compute predicted statistics. + with torch.no_grad(): + nn.init.xavier_uniform_(out_layer.weight, gain=0.01) + nn.init.zeros_(out_layer.bias) + + if self.do_scale: + # 1. Initialize z_min to -2.0 + out_layer.bias[2 * D : 3 * D].fill_(-2.0) + out_layer.bias[3 * D : 4 * D].fill_(4.0) - Args: - z: Context embedding tensor of shape (batch_size, embedding_dim). + @staticmethod + def _soft_clamp_tanh(x: torch.Tensor, bound: float) -> torch.Tensor: + if bound <= 0: + raise ValueError("bound must be > 0") + return bound * torch.tanh(x / bound) - Returns: - pred_mu: Predicted means, shape (batch_size, time_series_dims). - pred_sigma: Predicted standard deviations, shape (batch_size, time_series_dims). - pred_z_min: Predicted min z-scores, or None if do_scale=False. - pred_z_max: Predicted max z-scores, or None if do_scale=False. - """ + def forward(self, z: torch.Tensor): out = self.net(z) batch_size = out.size(0) + if self.do_scale: out = out.view(batch_size, 4, self.time_series_dims) pred_mu = out[:, 0, :] pred_log_sigma = out[:, 1, :] + + # z_min is predicted normally pred_z_min = out[:, 2, :] - pred_z_max = out[:, 3, :] + + # The 4th output is now the "raw delta" + raw_delta = out[:, 3, :] + + # Ensure the range is strictly positive (minimum 1e-4) + # F.softplus(x) = log(1 + exp(x)) + actual_range = F.softplus(raw_delta) + 1e-4 + + # Structurally guarantee z_max > z_min + pred_z_max = pred_z_min + actual_range + else: out = out.view(batch_size, 2, self.time_series_dims) pred_mu = out[:, 0, :] pred_log_sigma = out[:, 1, :] pred_z_min = None pred_z_max = None - pred_sigma = torch.exp(pred_log_sigma) - return pred_mu, pred_sigma, pred_z_min, pred_z_max + + pred_log_sigma_unclamped = pred_log_sigma + pred_log_sigma_clamped = self._soft_clamp_tanh(pred_log_sigma, bound=10.0) + pred_sigma = torch.exp(pred_log_sigma_clamped) + + return pred_mu, pred_sigma, pred_z_min, pred_z_max, pred_log_sigma_unclamped class _NormalizerModule(nn.Module): @@ -87,39 +107,81 @@ class _NormalizerModule(nn.Module): def __init__( self, - cond_module: nn.Module, + static_cond_module: nn.Module = None, + dynamic_cond_module: nn.Module = None, hidden_dim: int = 512, time_series_dims: int = 2, do_scale: bool = True, + stats_head_type: str = "mlp", + n_layers: int = 3, ): - """ - Args: - cond_module: ContextModule instance for embedding context variables. - hidden_dim: Hidden dimension size for the stats head. - time_series_dims: Number of time series dimensions. - do_scale: Whether to include scaling predictions. - """ super().__init__() - self.cond_module = cond_module - self.embedding_dim = cond_module.embedding_dim - self.stats_head = _StatsHead( + self.static_cond_module = static_cond_module + self.dynamic_cond_module = dynamic_cond_module + + # Determine embedding dimension from available modules + if static_cond_module is not None: + self.embedding_dim = static_cond_module.embedding_dim + elif dynamic_cond_module is not None: + self.embedding_dim = dynamic_cond_module.embedding_dim + else: + raise ValueError("At least one of static_cond_module or dynamic_cond_module must be provided") + + # If both modules exist, combine their embeddings + if static_cond_module is not None and dynamic_cond_module is not None: + combined_dim = static_cond_module.embedding_dim + dynamic_cond_module.embedding_dim + self.combine_mlp = nn.Sequential( + nn.Linear(combined_dim, self.embedding_dim), + nn.ReLU(), + ) + else: + self.combine_mlp = None + + StatsHeadCls = get_stats_head_cls(stats_head_type) + self.stats_head = StatsHeadCls( embedding_dim=self.embedding_dim, hidden_dim=hidden_dim, time_series_dims=time_series_dims, do_scale=do_scale, + n_layers=n_layers, ) - def forward(self, cat_vars_dict: dict): - """ - Compute normalization parameters from categorical context. - - Args: - cat_vars_dict: Mapping of context variable names to label tensors. - - Returns: - Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max). - """ - embedding, _ = self.cond_module(cat_vars_dict) + def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_dict: dict = None): + embeddings = [] + + # Process static context variables + if self.static_cond_module is not None: + if static_context_vars_dict: + device = next(self.static_cond_module.parameters()).device + static_context_vars_dict = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in static_context_vars_dict.items() + } + static_embedding, _ = self.static_cond_module(static_context_vars_dict) + embeddings.append(static_embedding) + + # Process dynamic context variables + if self.dynamic_cond_module is not None: + if dynamic_context_vars_dict: + device = next(self.dynamic_cond_module.parameters()).device + dynamic_context_vars_dict = { + k: v.to(device, non_blocking=False) if isinstance(v, torch.Tensor) else v + for k, v in dynamic_context_vars_dict.items() + } + dynamic_embedding, _ = self.dynamic_cond_module(dynamic_context_vars_dict) + if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any(): + raise ValueError(f"NaN/Inf detected in dynamic embedding.") + embeddings.append(dynamic_embedding) + + # Combine embeddings + if len(embeddings) == 2: + combined = torch.cat(embeddings, dim=1) + embedding = self.combine_mlp(combined) + elif len(embeddings) == 1: + embedding = embeddings[0] + else: + raise ValueError("No context variables provided") + return self.stats_head(embedding) @@ -134,15 +196,8 @@ def __init__( dataset_cfg, normalizer_training_cfg, dataset, + context_cfg, ): - """ - Initializes the Normalizer training module. - - Args: - dataset_cfg: OmegaConf dataset config (provides context_vars, columns). - normalizer_training_cfg: Config for normalizer training (lr, batch_size). - dataset: Instance of TimeSeriesDataset containing data DataFrame. - """ super().__init__() self.save_hyperparameters(ignore=["dataset"]) @@ -150,281 +205,649 @@ def __init__( self.normalizer_training_cfg = normalizer_training_cfg self.dataset = dataset - self.context_vars = list(dataset_cfg.context_vars.keys()) - self.time_series_cols = dataset_cfg.time_series_columns[ - : dataset_cfg.time_series_dims - ] + # Get continuous variables from config if specified + self.continuous_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "continuous"] + self.categorical_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "categorical"] + self.dynamic_context_vars = [k for k, v in self.dataset_cfg.context_vars.items() if v[0] == "time_series"] + + self.static_context_vars = self.categorical_vars + self.continuous_vars + self.context_vars = self.static_context_vars + self.dynamic_context_vars + + # Normalizer-specific conditioning: subset of static vars for grouping / stats (e.g. per-station) + self.normalizer_group_vars = getattr(self.dataset_cfg, "normalizer_group_vars", None) + self.normalizer_static_vars = ( + list(self.normalizer_group_vars) if self.normalizer_group_vars is not None + else self.static_context_vars + ) + if self.normalizer_static_vars: + bad = [v for v in self.normalizer_static_vars if v not in self.static_context_vars] + if bad: + raise ValueError(f"normalizer_group_vars {self.normalizer_group_vars} contains vars not in static_context_vars: {bad}") + # When normalizer_group_vars is set it is always static-only (dynamic vars are rejected in + # _build_training_samples), so exclude dynamic vars from normalizer conditioning entirely. + self.normalizer_dynamic_vars = [] if self.normalizer_group_vars is not None else self.dynamic_context_vars + self._group_bin_edges = {} # filled in setup() when grouping by continuous vars + + self.time_series_cols = dataset_cfg.time_series_columns[: dataset_cfg.time_series_dims] self.time_series_dims = dataset_cfg.time_series_dims self.do_scale = dataset_cfg.scale - - self.context_module = ContextModule( - dataset_cfg.context_vars, - 256, + self.seq_len = dataset_cfg.seq_len + self.num_ts_steps = getattr(dataset_cfg, "num_ts_steps", None) + + self.static_module_type = self.dataset.static_module_type + self.dynamic_module_type = self.dataset.dynamic_module_type + self.stats_head_type = self.dataset.stats_head_type + self.loss_type = getattr(self.normalizer_training_cfg, "loss_type", "mse") + self.use_global_stats_preprocessing = bool( + getattr(self.normalizer_training_cfg, "use_global_stats_preprocessing", True) ) + if hasattr(self.dataset_cfg, "normalizer_use_global_stats_preprocessing"): + self.use_global_stats_preprocessing = bool( + self.dataset_cfg.normalizer_use_global_stats_preprocessing + ) + # When not using global preprocessing: mu is predicted in asinh space; clamp to avoid sinh explosion + self.max_asinh_mu = float(getattr(self.normalizer_training_cfg, "max_asinh_mu", 10.0)) + # Floor sigma and scale range so normalized values (z) cannot explode + self.min_sigma = float(getattr(self.normalizer_training_cfg, "min_sigma", 1e-3)) + self.min_scale_range = float(getattr(self.normalizer_training_cfg, "min_scale_range", 0.25)) + + self.register_buffer("global_mu_mean", torch.tensor(0.0)) + self.register_buffer("global_mu_std", torch.tensor(1.0)) + self.register_buffer("global_log_sigma_mean", torch.tensor(0.0)) + + # Create static context module (only normalizer_static_vars so it matches group conditioning) + self.static_context_module = None + if self.normalizer_static_vars: + StaticContextModuleCls = get_context_module_cls(self.static_module_type) + n_bins = getattr(self.dataset_cfg, "numeric_context_bins", 5) + self.static_context_vars_dict = {} + for k in self.normalizer_static_vars: + if k in self.continuous_vars: + # Binned continuous: treat as categorical with n_bins for normalizer conditioning + self.static_context_vars_dict[k] = ["categorical", n_bins] + else: + self.static_context_vars_dict[k] = self.dataset.context_var_dict[k] + self.static_context_module = StaticContextModuleCls( + self.static_context_vars_dict, + 256, + ) + # Create dynamic context module only for vars the normalizer should condition on. + # Filter to vars that are both in normalizer_dynamic_vars and are actual dynamic (time_series) vars. + # When normalizer_group_vars is set, normalizer_dynamic_vars is empty so the dict is empty + # and no dynamic module is created. + self.dynamic_context_module = None + _normalizer_dynamic_vars_dict = { + k: v for k, v in self.dataset_cfg.context_vars.items() + if k in self.normalizer_dynamic_vars and k in self.dynamic_context_vars + } + if _normalizer_dynamic_vars_dict and self.dynamic_module_type is not None: + DynamicContextModuleCls = get_context_module_cls("dynamic", self.dynamic_module_type) + dynamic_context_vars_dict = _normalizer_dynamic_vars_dict + dynamic_seq_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len + self.dynamic_context_module = DynamicContextModuleCls( + dynamic_context_vars_dict, + 256, + seq_len=dynamic_seq_len, + ) + self.normalizer_model = _NormalizerModule( - cond_module=self.context_module, + static_cond_module=self.static_context_module, + dynamic_cond_module=self.dynamic_context_module, hidden_dim=512, time_series_dims=self.time_series_dims, do_scale=self.do_scale, + stats_head_type=self.stats_head_type, + n_layers=context_cfg.normalizer.n_layers, ) # Will be populated in setup() - self.group_stats = {} + self.sample_stats = [] + self._verify_parameters() + + def _verify_parameters(self): + all_param_names = [name for name, _ in self.named_parameters()] + context_param_names = [name for name in all_param_names if 'cond_module' in name or 'context_module' in name] + stats_head_param_names = [name for name in all_param_names if 'stats_head' in name] + + if not context_param_names: + raise RuntimeError( + "Context module parameters not found! " + f"Found parameter names: {all_param_names[:10]}..." + ) + + print(f"[Normalizer] Found {len(context_param_names)} context module parameters") + print(f"[Normalizer] Found {len(stats_head_param_names)} stats head parameters") + print(f"[Normalizer] Total trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}") def setup(self, stage: Optional[str] = None): """ - Lightning hook: compute group statistics before training. + Lightning hook: prepare training data before training. """ - self.group_stats = self._compute_group_stats() + # Compute per-sample statistics + # Note: Using robust quantile scaling for targets to avoid outlier instability + mode = getattr(self.dataset_cfg, "normalizer_stats_mode", "sample") + group_vars = getattr(self.dataset_cfg, "normalizer_group_vars", None) + self.sample_stats = self._build_training_samples(mode, use_quantile_scale=True, group_vars=group_vars) + + if self.use_global_stats_preprocessing: + # --- COMPUTE GLOBAL TARGET STATS FOR SCALING --- + # 1. Global Mu Stats (for Z-score scaling) + all_mus = np.concatenate([s[2] for s in self.sample_stats]) + self.target_mu_mean = torch.tensor(all_mus.mean(), dtype=torch.float32) + self.target_mu_std = torch.tensor(all_mus.std() + 1e-8, dtype=torch.float32) + + # 2. Global Sigma Stats (for Log-Space Centering) + all_sigmas_concat = np.concatenate([s[3] for s in self.sample_stats]) + self.target_log_sigma_mean = torch.tensor( + np.log(all_sigmas_concat + 1e-8).mean(), dtype=torch.float32 + ) - def forward(self, cat_vars_dict: dict): - """ - Predict normalization parameters for a batch of categorical contexts. + self.global_mu_mean.fill_(self.target_mu_mean) + self.global_mu_std.fill_(self.target_mu_std) + self.global_log_sigma_mean.fill_(self.target_log_sigma_mean) + + print(f"Global Target Stats: Mu Mean={self.target_mu_mean:.4f}, Mu Std={self.target_mu_std:.4f}") + print(f"Global Target Log Sigma Mean: {self.target_log_sigma_mean:.4f}") + else: + # No global preprocessing: mu in asinh space, log(sigma) direct; identity for sigma centering + self.global_mu_mean.zero_() + self.global_mu_std.fill_(1.0) + self.global_log_sigma_mean.zero_() + print(f"Normalizer: use_global_stats_preprocessing=False — predicting asinh(mu) (clamp ±{self.max_asinh_mu}) and log(sigma) directly.") - Args: - cat_vars_dict: Mapping of context variable names to label tensors. + # Global range mean for safeguard: floor rng to 0.01 * global_rng_mean when do_scale + if self.do_scale: + all_rngs = [] + for s in self.sample_stats: + zlow, zhigh = s[4], s[5] + if zlow is not None and zhigh is not None: + all_rngs.extend((np.asarray(zhigh) - np.asarray(zlow)).flatten().tolist()) + + # Log initial predictions + if stage == "fit" or stage is None: + self.train() + + def _raw_mu_to_real(self, pred_mu_raw: torch.Tensor) -> torch.Tensor: + """Convert network mu output to real-world mu (handles both global and direct/asinh paths).""" + if self.use_global_stats_preprocessing: + return (pred_mu_raw * self.global_mu_std) + self.global_mu_mean + # Direct path: network predicts asinh(mu); invert with sinh and clamp for stability + return torch.sinh(torch.clamp(pred_mu_raw, -self.max_asinh_mu, self.max_asinh_mu)) + + def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_dict: dict = None): + """ + Predict normalization parameters. + Applies UNSCALING logic to convert network outputs back to real-world range. Returns: - Tuple of (pred_mu, pred_sigma, pred_z_min, pred_z_max). + Tuple of (pred_mu_real, pred_sigma_real, pred_z_min, pred_z_max, pred_log_sigma_raw). """ - return self.normalizer_model(cat_vars_dict) + pred_mu_raw, pred_sigma, pred_zmin, pred_zmax, pred_log_sigma_raw = self.normalizer_model(static_context_vars_dict, dynamic_context_vars_dict) - def training_step(self, batch, batch_idx: int): - """ - Training step: regress predicted stats against true group stats. + pred_mu_real = self._raw_mu_to_real(pred_mu_raw) - Args: - batch: Tuple of (cat_vars_dict, mu, sigma, zmin, zmax). - batch_idx: Batch index. + # Unscale Sigma: exp(NetworkLogOutput + GlobalLogMean) or exp(NetworkLogOutput) when no global + pred_log_sigma_real = pred_log_sigma_raw + self.global_log_sigma_mean + pred_sigma_real = torch.exp(pred_log_sigma_real).clamp(min=self.min_sigma) + # Safeguard: sigma must be at least 0.01 * global sigma (exp(global_log_sigma_mean)) + sigma_global = torch.exp(self.global_log_sigma_mean) + sigma_floor = max(self.min_sigma, (0.01 * sigma_global).item()) + pred_sigma_real = torch.clamp(pred_sigma_real, min=sigma_floor) - Returns: - loss tensor. + return pred_mu_real, pred_sigma_real, pred_zmin, pred_zmax, pred_log_sigma_raw + + def _compute_loss_mse(self, pred_mu_raw, pred_log_sigma_raw, mu_t_scaled, target_log_sigma_centered): + """ + Compute MSE loss in the SCALED space. """ - cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch - pred_mu, pred_sigma, pred_z_min, pred_z_max = self(cat_vars_dict) + loss_mu = F.mse_loss(pred_mu_raw, mu_t_scaled) + + # Use Huber Loss for Log Sigma to be robust against outliers + # Clamp targets slightly to prevent infinite loss if data is broken + target_log_sigma_centered = torch.clamp(target_log_sigma_centered, min=-10.0, max=10.0) + loss_sigma = F.smooth_l1_loss(pred_log_sigma_raw, target_log_sigma_centered, beta=1.0) + + return loss_mu, loss_sigma - loss_mu = F.mse_loss(pred_mu, mu_t) - loss_sigma = F.mse_loss(pred_sigma, sigma_t) - total_loss = loss_mu + loss_sigma + def training_step(self, batch, batch_idx: int): + static_context_vars_dict, dynamic_context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t = batch + pred_mu_raw, _, pred_z_min, pred_z_max, pred_log_sigma_raw = self.normalizer_model(static_context_vars_dict, dynamic_context_vars_dict) - if self.do_scale: - total_loss += F.mse_loss(pred_z_min, zmin_t) + F.mse_loss( - pred_z_max, zmax_t + # 2. Targets: scaled (global) or asinh(mu) + log(sigma) (direct) + if self.use_global_stats_preprocessing: + mu_t_scaled = (mu_t - self.global_mu_mean) / self.global_mu_std + else: + # asinh(mu) compresses full range to ~[-8, 8]; clamp so target matches forward clamp + mu_t_scaled = torch.clamp( + torch.asinh(mu_t), + -self.max_asinh_mu, + self.max_asinh_mu, ) + target_log_sigma_centered = torch.log(sigma_t + 1e-8) - self.global_log_sigma_mean - self.log("train_loss", total_loss, prog_bar=True) + # 3. Compute Loss + loss_mu, loss_sigma = self._compute_loss_mse( + pred_mu_raw, pred_log_sigma_raw, mu_t_scaled, target_log_sigma_centered + ) + + total_loss = loss_mu + loss_sigma + + # 4. Scaling parameters (z_min/z_max) - These are already naturally roughly scaled (-2 to 2) + # We use SmoothL1Loss to be robust to outliers + if self.do_scale: + if torch.isnan(pred_z_min).any() or torch.isnan(pred_z_max).any(): + raise ValueError(f"NaN detected in scale predictions at batch {batch_idx}") + loss_zmin = F.smooth_l1_loss(pred_z_min, zmin_t, beta=1.0) + loss_zmax = F.smooth_l1_loss(pred_z_max, zmax_t, beta=1.0) + total_loss += loss_zmin + loss_zmax + else: + loss_zmin = torch.tensor(0.0, device=total_loss.device) + loss_zmax = torch.tensor(0.0, device=total_loss.device) + + if torch.isnan(total_loss) or torch.isinf(total_loss): + raise ValueError(f"NaN/Inf loss detected.") + + # Log metrics + self.log("train_loss", total_loss, prog_bar=True, sync_dist=True) + self.log("loss_mu", loss_mu, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) + self.log("loss_sigma", loss_sigma, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True) + + if batch_idx % 100 == 0: + with torch.no_grad(): + pred_mu_real = self._raw_mu_to_real(pred_mu_raw) + self.log("pred_mu_mean_real", pred_mu_real.mean(), on_step=True, on_epoch=False) + self.log("target_mu_mean_real", mu_t.mean(), on_step=True, on_epoch=False) + return total_loss def configure_optimizers(self): - """ - Configure optimizer for normalizer training. - - Returns: - Adam optimizer instance. - """ - return torch.optim.Adam(self.parameters(), lr=self.normalizer_training_cfg.lr) + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.normalizer_training_cfg.lr, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0 + ) + return optimizer + + def on_train_batch_end(self, outputs, batch, batch_idx): + if batch_idx % 100 == 0: + total_norm = 0.0 + for p in self.parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** (1. / 2) + self.log("grad_norm", total_norm, on_step=True, on_epoch=False) def train_dataloader(self): - """ - Returns a DataLoader over per-group statistics samples. - """ ds = self._create_training_dataset() return DataLoader( ds, batch_size=self.normalizer_training_cfg.batch_size, shuffle=True, - num_workers=1, + num_workers=4, + persistent_workers=True, + pin_memory=torch.cuda.is_available(), + prefetch_factor=2, ) - def _compute_group_stats(self) -> dict: - """ - Compute per-group (context combination) statistics from raw data. - - Returns: - Mapping from context tuple to (mu_array, std_array, zmin_array, zmax_array). - """ + def _compute_per_sample_stats(self) -> list: + # Same implementation as before df = self.dataset.data.copy() - grouped_stats = {} - for group_vals, group_df in df.groupby(self.context_vars): - dimension_points = [[] for _ in range(self.time_series_dims)] - for _, row in group_df.iterrows(): - for d, col_name in enumerate(self.time_series_cols): - arr = np.array(row[col_name], dtype=np.float32).flatten() - dimension_points[d].append(arr) - dimension_points = [np.concatenate(d, axis=0) for d in dimension_points] - mu_array = np.array( - [pts.mean() for pts in dimension_points], dtype=np.float32 - ) - std_array = np.array( - [pts.std() + 1e-8 for pts in dimension_points], dtype=np.float32 - ) + sample_stats = [] + continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + + for idx, row in df.iterrows(): + context_vars_dict = {} + for var_name in self.static_context_vars: + if var_name in row: + if var_name in continuous_vars: + # Normalize inputs using min-max if available, otherwise just cast + context_vars_dict[var_name] = torch.tensor(row[var_name], dtype=torch.float32) + else: + context_vars_dict[var_name] = torch.tensor(row[var_name], dtype=torch.long) + + dynamic_ctx_dict = {} + for var_name in self.dynamic_context_vars: + ts_data = row.get(var_name) + if ts_data is not None: + if isinstance(ts_data, (np.ndarray, list)): + dynamic_ctx_dict[var_name] = np.array(ts_data) + else: + context_len = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len + dynamic_ctx_dict[var_name] = np.full(context_len, ts_data) + + dimension_points = [] + for col_name in self.time_series_cols: + arr = np.array(row[col_name], dtype=np.float32).flatten() + dimension_points.append(arr) + + mu_array = np.array([pts.mean() for pts in dimension_points], dtype=np.float32) + std_array = np.array([pts.std() + 1e-8 for pts in dimension_points], dtype=np.float32) if self.do_scale: - z_min_array = np.array( - [ - (pts - mu).min() / std - for pts, mu, std in zip(dimension_points, mu_array, std_array) - ], - dtype=np.float32, - ) - z_max_array = np.array( - [ - (pts - mu).max() / std - for pts, mu, std in zip(dimension_points, mu_array, std_array) - ], - dtype=np.float32, - ) + z_min_array = np.array([(pts - mu).min() / std for pts, mu, std in zip(dimension_points, mu_array, std_array)], dtype=np.float32) + z_max_array = np.array([(pts - mu).max() / std for pts, mu, std in zip(dimension_points, mu_array, std_array)], dtype=np.float32) else: z_min_array = z_max_array = None - - grouped_stats[tuple(group_vals)] = ( - mu_array, - std_array, - z_min_array, - z_max_array, - ) - return grouped_stats + + sample_stats.append((context_vars_dict, dynamic_ctx_dict, mu_array, std_array, z_min_array, z_max_array)) + + return sample_stats def _create_training_dataset(self) -> Dataset: - """ - Build an internal Dataset yielding true stats for each context group. - - Returns: - PyTorch Dataset of samples (cat_vars_dict, mu, sigma, zmin, zmax). - """ - data_tuples = [ - (ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr) - for ctx_tuple, ( - mu_arr, - sigma_arr, - zmin_arr, - zmax_arr, - ) in self.group_stats.items() - ] - class _TrainSet(Dataset): - """ - Adapter Dataset to wrap group_stats tuples for DataLoader. - """ - - def __init__(self, samples, context_vars, do_scale): + def __init__(self, samples, dynamic_context_vars, do_scale, dataset_cfg): self.samples = samples - self.context_vars = context_vars + self.dynamic_context_vars = dynamic_context_vars self.do_scale = do_scale + self.dataset_cfg = dataset_cfg def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int): - """ - Returns one training sample. - - Args: - idx: Index of the sample. - - Returns: - cat_vars_dict: Tensor dict of context labels. - mu_t: True mean tensor. - sigma_t: True std tensor. - zmin_t: True min z-score tensor or None. - zmax_t: True max z-score tensor or None. - """ - ctx_tuple, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] - cat_vars_dict = { - var_name: torch.tensor(ctx_tuple[i], dtype=torch.long) - for i, var_name in enumerate(self.context_vars) - } + static_context_vars_dict, dynamic_context_vars_dict, mu_arr, sigma_arr, zmin_arr, zmax_arr = self.samples[idx] + for var_name in self.dynamic_context_vars: + if var_name in dynamic_context_vars_dict: + ts_data = np.array(dynamic_context_vars_dict[var_name]) + var_info = self.dataset_cfg.context_vars.get(var_name, None) + if var_info and var_info[1] is not None: + dynamic_context_vars_dict[var_name] = torch.from_numpy(ts_data).long() + else: + dynamic_context_vars_dict[var_name] = torch.from_numpy(ts_data).float() + mu_t = torch.from_numpy(mu_arr).float() sigma_t = torch.from_numpy(sigma_arr).float() - zmin_t = torch.from_numpy(zmin_arr).float() if self.do_scale else None - zmax_t = torch.from_numpy(zmax_arr).float() if self.do_scale else None - return cat_vars_dict, mu_t, sigma_t, zmin_t, zmax_t + if self.do_scale: + zmin_t = torch.from_numpy(zmin_arr).float() + zmax_t = torch.from_numpy(zmax_arr).float() + else: + # Return dummy tensors so DataLoader collate does not see None + zmin_t = torch.zeros_like(mu_t) + zmax_t = torch.zeros_like(mu_t) + return static_context_vars_dict, dynamic_context_vars_dict, mu_t, sigma_t, zmin_t, zmax_t - return _TrainSet(data_tuples, self.context_vars, self.do_scale) + return _TrainSet(self.sample_stats, self.dynamic_context_vars, self.do_scale, self.dataset_cfg) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Normalize a DataFrame of time series using learned parameters. - - Pads or splits if needed, then applies z-score and min-max scaling. - - Args: - df: Input DataFrame with raw time series columns. - - Returns: - DataFrame with normalized series in same columns. - """ missing = [c for c in self.time_series_cols if c not in df.columns] - if missing: df = split_timeseries(df, self.time_series_cols) - missing = [c for c in self.time_series_cols if c not in df.columns] - - assert not missing, ( - "Normalizer.transform expects data in split format with columns " - f"{self.time_series_cols}." - ) - + df_out = df.copy() self.eval() + continuous_vars = set(self.continuous_vars) + categorical_ts = getattr(self.dataset, 'categorical_time_series', {}) + group_edges = getattr(self, "_group_bin_edges", {}) + with torch.no_grad(): - for i, row in df_out.iterrows(): - ctx = { - v: torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - for v in self.context_vars - } - mu, sigma, zmin, zmax = self(ctx) - mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() + for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Normalizing"): + static_context_vars_dict = {} + dynamic_context_vars_dict = {} + for v in self.context_vars: + if v in self.normalizer_static_vars: + if v in continuous_vars and v in group_edges: + edges = group_edges[v] + bin_idx = np.digitize(np.asarray([float(row[v])]), edges[1:-1], right=False) + bin_idx = np.clip(bin_idx, 0, len(edges) - 2).item() + static_context_vars_dict[v] = torch.tensor(bin_idx, dtype=torch.long).unsqueeze(0) + elif v in continuous_vars: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + else: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + elif v in self.static_context_vars: + continue + elif v in self.normalizer_dynamic_vars: + dynamic_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + elif v in self.dynamic_context_vars: + continue # dynamic var excluded from normalizer conditioning + else: + raise ValueError(f"Variable {v} not found in context_vars") + + # self(ctx) calls forward, which automatically UNSCALES predictions + pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(static_context_vars_dict, dynamic_context_vars_dict) + mu, sigma = pred_mu[0].cpu().numpy(), pred_sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): + if col in categorical_ts: + df_out.at[i, col] = np.asarray(row[col]).astype(np.int32) + continue + arr = np.asarray(row[col], dtype=np.float32) - z = (arr - mu[d]) / (sigma[d] + 1e-8) + + sigma_floor = max(self.min_sigma, 0.01 * np.exp(self.global_log_sigma_mean.cpu().item())) + sigma_eff = max(float(sigma[d]), sigma_floor) + z = (arr - mu[d]) / sigma_eff if self.do_scale: - zmin_, zmax_ = zmin[0, d].item(), zmax[0, d].item() + zmin_, zmax_ = pred_zmin[0, d].item(), pred_zmax[0, d].item() rng = (zmax_ - zmin_) + 1e-8 - z = (z - zmin_) / rng + rng_floor = max(self.min_scale_range, .25) + rng_eff = max(rng, rng_floor) + z = (z - zmin_) / rng_eff + df_out.at[i, col] = z return df_out def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Denormalize a DataFrame of z-scored series back to original scale. - - Args: - df: DataFrame with normalized series columns. - - Returns: - DataFrame with denormalized series. - """ missing = [c for c in self.time_series_cols if c not in df.columns] - if missing: df = split_timeseries(df, self.time_series_cols) - missing = [c for c in self.time_series_cols if c not in df.columns] - - assert not missing, ( - "Normalizer.inverse_transform expects split format with columns " - f"{self.time_series_cols}." - ) df_out = df.copy() self.eval() + # continuous_vars = getattr(self.dataset_cfg, "continuous_context_vars", None) or [] + + continuous_vars = set(self.continuous_vars) + categorical_ts = getattr(self.dataset, "categorical_time_series", {}) + group_edges = getattr(self, "_group_bin_edges", {}) + # Pre-compute global fallback stats (used when row has no context columns) + _global_mu = np.full(self.time_series_dims, self.global_mu_mean.item(), dtype=np.float32) + _global_sigma = np.full( + self.time_series_dims, + max(float(np.exp(self.global_log_sigma_mean.item())), self.min_sigma), + dtype=np.float32, + ) + with torch.no_grad(): - for i, row in df_out.iterrows(): - ctx = { - v: torch.tensor(row[v], dtype=torch.long).unsqueeze(0) - for v in self.context_vars - } - mu, sigma, zmin, zmax = self(ctx) - mu, sigma = mu[0].cpu().numpy(), sigma[0].cpu().numpy() + for i, row in tqdm(df_out.iterrows(), total=len(df_out), desc="Inverse normalizing"): + # Unconditional path: no context columns present → use global stats directly + if not any(v in row for v in self.normalizer_static_vars + self.normalizer_dynamic_vars): + for d, col in enumerate(self.time_series_cols): + z = np.asarray(row[col], dtype=np.float32) + df_out.at[i, col] = (z * _global_sigma[d] + _global_mu[d]).tolist() + continue + + static_context_vars_dict = {} + dynamic_context_vars_dict = {} + for v in self.context_vars: + if v in self.normalizer_static_vars: + # Normalizer only conditions on normalizer_static_vars (e.g. station) + if v in continuous_vars and v in group_edges: + edges = group_edges[v] + bin_idx = np.digitize(np.asarray([float(row[v])]), edges[1:-1], right=False) + bin_idx = np.clip(bin_idx, 0, len(edges) - 2).item() + static_context_vars_dict[v] = torch.tensor(bin_idx, dtype=torch.long).unsqueeze(0) + elif v in continuous_vars: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32).unsqueeze(0) + else: + static_context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long).unsqueeze(0) + elif v in self.static_context_vars: + continue # not used by normalizer + elif v in self.normalizer_dynamic_vars: + # Dynamic var included in normalizer conditioning + val = row[v] + dtype_np = np.int64 if v in categorical_ts else np.float32 + arr = np.asarray(val, dtype=dtype_np) + if arr.ndim == 0: + arr = arr[np.newaxis, np.newaxis] + elif arr.ndim == 1: + arr = arr[np.newaxis, :] + dtype = torch.long if v in categorical_ts else torch.float32 + dynamic_context_vars_dict[v] = torch.tensor(arr, dtype=dtype) + elif v in self.dynamic_context_vars: + continue # dynamic var excluded from normalizer conditioning + + # self(ctx) calls forward, which automatically UNSCALES predictions + pred_mu, pred_sigma, pred_zmin, pred_zmax, _ = self(static_context_vars_dict, dynamic_context_vars_dict) + mu, sigma = pred_mu[0].cpu().numpy(), pred_sigma[0].cpu().numpy() for d, col in enumerate(self.time_series_cols): z = np.asarray(row[col], dtype=np.float32) + sigma_floor = max(self.min_sigma, 0.01 * np.exp(self.global_log_sigma_mean.cpu().item())) + sigma_eff = max(float(sigma[d]), sigma_floor) if self.do_scale: - zmin_, zmax_ = zmin[0, d].item(), zmax[0, d].item() + zmin_, zmax_ = pred_zmin[0, d].item(), pred_zmax[0, d].item() rng = (zmax_ - zmin_) + 1e-8 - z = z * rng + zmin_ - arr = z * (sigma[d] + 1e-8) + mu[d] - df_out.at[i, col] = arr + rng_eff = max(rng, self.min_scale_range) + z = z * rng_eff + zmin_ + arr = z * sigma_eff + mu[d] + df_out.at[i, col] = arr.tolist() return df_out + + def _build_training_samples( + self, + mode: str = "sample", + group_vars: Optional[list[str]] = None, + use_quantile_scale: bool = False, + q_low: float = 0.02, + q_high: float = 0.98, + ) -> list: + """ + Build training samples. + Note: Quantile scaling (q_low/q_high) only affects z_min/z_max calculation, + not the mean/std targets. + """ + assert mode in {"sample", "group"}, f"mode must be 'sample' or 'group', got {mode}" + + df = self.dataset.data.copy() + + continuous_vars = set(self.continuous_vars) + dynamic_vars = set(self.dynamic_context_vars) + static_vars = [v for v in self.static_context_vars] + + if group_vars is None: + # Default: group by all static categorical (continuous in normalizer_static_vars get binned in group mode) + group_vars = [v for v in self.normalizer_static_vars if (v not in continuous_vars and v not in dynamic_vars)] + + bad = [v for v in group_vars if v in dynamic_vars] + if bad: + raise ValueError(f"group_vars contains dynamic vars {bad}") + + def _row_stats(row) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]: + dim_points = [] + for d, col_name in enumerate(self.time_series_cols): + arr = np.asarray(row[col_name], dtype=np.float32).reshape(-1) + dim_points.append(arr) + + mu = np.array([x.mean() for x in dim_points], dtype=np.float32) + std = np.array([x.std() + 1e-8 for x in dim_points], dtype=np.float32) + + if not self.do_scale: + return mu, std, None, None + + zlow = np.zeros(self.time_series_dims, dtype=np.float32) + zhigh = np.zeros(self.time_series_dims, dtype=np.float32) + + for i, (x, m, s) in enumerate(zip(dim_points, mu, std)): + z = (x - m) / s + if use_quantile_scale: + zlow[i] = np.quantile(z, q_low).astype(np.float32) + zhigh[i] = np.quantile(z, q_high).astype(np.float32) + else: + zlow[i] = z.min().astype(np.float32) + zhigh[i] = z.max().astype(np.float32) + + return mu, std, zlow, zhigh + + samples = [] + + if mode == "sample": + for _, row in df.iterrows(): + context_vars_dict = {} + for v in self.normalizer_static_vars: + if v not in row: + continue + if v in continuous_vars: + context_vars_dict[v] = torch.tensor(row[v], dtype=torch.float32) + else: + context_vars_dict[v] = torch.tensor(row[v], dtype=torch.long) + + dynamic_ctx_dict = {} + for v in self.normalizer_dynamic_vars: + if v not in row: continue + ts_data = row[v] + if isinstance(ts_data, (np.ndarray, list)): + dynamic_ctx_dict[v] = np.array(ts_data) + else: + L = self.num_ts_steps if self.num_ts_steps is not None else self.seq_len + dynamic_ctx_dict[v] = np.full(L, ts_data) + + mu, std, zlow, zhigh = _row_stats(row) + samples.append((context_vars_dict, dynamic_ctx_dict, mu, std, zlow, zhigh)) + + return samples + + # mode == "group" + # Pre-process continuous vars for grouping by binning; store edges for inverse_transform + self._group_bin_edges = {} + for v in group_vars: + if v in continuous_vars: + n_bins = getattr(self.dataset_cfg, "numeric_context_bins", 5) + binned, edges = pd.cut( + df[v], bins=n_bins, labels=False, include_lowest=True, + duplicates="drop", retbins=True + ) + self._group_bin_edges[v] = np.asarray(edges, dtype=np.float64) + df[v] = binned + grouped = df.groupby(list(group_vars), dropna=False) + + for group_key, gdf in grouped: + if len(group_vars) == 1: + group_key = (group_key,) + + context_vars_dict = {} + for i, v in enumerate(group_vars): + val = group_key[i] + if v in continuous_vars: + if pd.isna(val): val = 0 + context_vars_dict[v] = torch.tensor(int(val), dtype=torch.long) + else: + context_vars_dict[v] = torch.tensor(val, dtype=torch.long) + + dynamic_ctx_dict = {} + + dim_points = [[] for _ in range(self.time_series_dims)] + for _, row in gdf.iterrows(): + for d, col_name in enumerate(self.time_series_cols): + arr = np.asarray(row[col_name], dtype=np.float32).reshape(-1) + dim_points[d].append(arr) + + dim_points = [np.concatenate(xs, axis=0) if len(xs) else np.zeros((0,), dtype=np.float32) + for xs in dim_points] + + mu = np.array([x.mean() if x.size else 0.0 for x in dim_points], dtype=np.float32) + std = np.array([x.std() + 1e-8 if x.size else 1.0 for x in dim_points], dtype=np.float32) + + if self.do_scale: + zlow = np.zeros(self.time_series_dims, dtype=np.float32) + zhigh = np.zeros(self.time_series_dims, dtype=np.float32) + for i, (x, m, s) in enumerate(zip(dim_points, mu, std)): + if x.size == 0: + zlow[i], zhigh[i] = -2.0, 2.0 + continue + z = (x - m) / s + if use_quantile_scale: + zlow[i] = np.quantile(z, q_low).astype(np.float32) + zhigh[i] = np.quantile(z, q_high).astype(np.float32) + else: + zlow[i] = z.min().astype(np.float32) + zhigh[i] = z.max().astype(np.float32) + else: + zlow = zhigh = None + + samples.append((context_vars_dict, dynamic_ctx_dict, mu, std, zlow, zhigh)) + + return samples \ No newline at end of file diff --git a/cents/models/stats_head_registry.py b/cents/models/stats_head_registry.py new file mode 100644 index 0000000..19cbf02 --- /dev/null +++ b/cents/models/stats_head_registry.py @@ -0,0 +1,55 @@ +from typing import Dict + +_STATS_HEAD_REGISTRY = {} + + +def register_stats_head(*names): + """ + Decorator: registers a stats head class under one or more names. + + Args: + *names: One or more names to register the class under. + + Example: + @register_stats_head("default", "mlp") + class MLPStatsHead(nn.Module): + pass + """ + + def decorator(cls): + for name in names: + _STATS_HEAD_REGISTRY[name] = cls + return cls + + return decorator + + +def get_stats_head_cls(key: str) -> type: + """ + Fetch the stats head class for `key`. Raises if not found. + + Args: + key: The name of the stats head to retrieve. + + Returns: + The stats head class. + + Raises: + ValueError: If the key is not found in the registry. + """ + try: + return _STATS_HEAD_REGISTRY[key] + except KeyError: + raise ValueError( + f"Unknown stats head '{key}'. Available: {list(_STATS_HEAD_REGISTRY.keys())}" + ) + + +def get_available_stats_heads() -> list[str]: + """ + Get a list of all available stats head names. + + Returns: + List of available stats head names. + """ + return list(_STATS_HEAD_REGISTRY.keys()) diff --git a/cents/trainer.py b/cents/trainer.py index 70dc8ed..a8797b9 100644 --- a/cents/trainer.py +++ b/cents/trainer.py @@ -1,9 +1,10 @@ +import csv from pathlib import Path from typing import Dict, List, Optional import pytorch_lightning as pl import wandb -from hydra import compose, initialize_config_dir +from datetime import datetime from omegaconf import DictConfig, OmegaConf from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loggers import WandbLogger @@ -13,6 +14,7 @@ from cents.eval.eval import Evaluator from cents.models.registry import get_model_cls from cents.utils.utils import get_normalizer_training_config +from cents.utils.config_loader import load_yaml, apply_overrides PKG_ROOT = Path(__file__).resolve().parent CONF_DIR = PKG_ROOT / "config" @@ -69,22 +71,28 @@ def __init__( self.model = self._instantiate_model() self.pl_trainer = self._instantiate_trainer() - def fit(self) -> "Trainer": + def fit(self, ckpt_path: Optional[str] = None) -> "Trainer": """ Start training. + Args: + ckpt_path: Optional path to checkpoint file (.ckpt) to resume training from. + If provided, training will resume from this checkpoint. + Returns: Self, to allow method chaining. """ if self.model_type == "normalizer": - self.pl_trainer.fit(self.model) + self.pl_trainer.fit(self.model, ckpt_path=ckpt_path) else: train_loader = self.dataset.get_train_dataloader( batch_size=self.cfg.trainer.batch_size, shuffle=True, - num_workers=4, + num_workers=4, # Maximum for 7.5GB/10GB GPU usage + persistent_workers=True, ) - self.pl_trainer.fit(self.model, train_loader, None) + print(f"[Cents] Training model on {len(train_loader)} batches") + self.pl_trainer.fit(self.model, train_loader, None, ckpt_path=ckpt_path) return self def get_data_generator(self) -> DataGenerator: @@ -135,22 +143,40 @@ def evaluate(self, **kwargs) -> Dict: def _compose_cfg(self, ov: List[str]) -> DictConfig: """ - Compose the full Hydra configuration by merging defaults, - dataset-specific config, and any user overrides. + Compose configuration by loading YAMLs and applying overrides. - Args: - ov: List of Hydra-style overrides. - - Returns: - OmegaConf DictConfig. + Structure: + cfg.model <- config/model/{model_type}.yaml + cfg.trainer <- config/trainer/{model_type}.yaml + cfg.dataset <- provided dataset.cfg (if any) """ - base_ov = [f"model={self.model_type}", f"trainer={self.model_type}"] - with initialize_config_dir(str(CONF_DIR), version_base=None): - cfg = compose(config_name="config", overrides=base_ov + ov) + model_cfg = load_yaml(CONF_DIR / "model" / f"{self.model_type}.yaml") + trainer_cfg = load_yaml(CONF_DIR / "trainer" / f"{self.model_type}.yaml") + + cfg = OmegaConf.create({}) + cfg.model = model_cfg + cfg.trainer = trainer_cfg + if self.dataset is not None: cfg.dataset = OmegaConf.create( OmegaConf.to_container(self.dataset.cfg, resolve=True) ) + + cfg = apply_overrides(cfg, ov) + + # Ensure required top-level fields exist without Hydra + if not hasattr(cfg, "device"): + cfg.device = "auto" + if not hasattr(cfg, "job_name"): + ds_name = getattr(cfg, "dataset", {}).get("name", "dataset") if isinstance(getattr(cfg, "dataset", {}), dict) else getattr(cfg.dataset, "name", "dataset") + ds_group = getattr(cfg, "dataset", {}).get("user_group", "all") if isinstance(getattr(cfg, "dataset", {}), dict) else getattr(cfg.dataset, "user_group", "all") + model_name = getattr(cfg.model, "name", self.model_type) + cfg.job_name = f"{model_name}_{ds_name}_{ds_group}" + if not hasattr(cfg, "run_dir") or not cfg.run_dir: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + cfg.run_dir = str(PKG_ROOT / "outputs" / cfg.job_name / timestamp) + # Checkpoint dir: run_dir/checkpoints so run root stays clean + cfg.checkpoint_dir = str(Path(cfg.run_dir) / "checkpoints") return cfg def _instantiate_model(self): @@ -170,32 +196,92 @@ def _instantiate_model(self): def _instantiate_trainer(self) -> pl.Trainer: """ Build a PyTorch Lightning Trainer with ModelCheckpoint and loggers. - - Returns: - Configured pl.Trainer instance. + Saves checkpoints every N epochs (if configured) and always keeps last.ckpt. """ tc = self.cfg.trainer callbacks = [] + + # ---- Build descriptive base filename ---- + filename_parts = [ + self.cfg.dataset.name, + self.model_type, + f"dim{self.cfg.dataset.time_series_dims}", + ] + + from cents.utils.utils import get_context_config + context_cfg = get_context_config() + + if context_cfg.static_context.type: + filename_parts.append(f"ctx{context_cfg.static_context.type}") + + if context_cfg.dynamic_context.type: + filename_parts.append(f"dyn{context_cfg.dynamic_context.type}") + + if context_cfg.normalizer.stats_head_type: + filename_parts.append(f"stats{context_cfg.normalizer.stats_head_type}") + + base_name = "_".join(filename_parts) + + # ---- Checkpoint directory ---- + checkpoint_dir = (Path(self.cfg.run_dir) / "checkpoints") + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # ---- Periodic saving config ---- + every_n = getattr(tc.checkpoint, "every_n_epochs", None) + + if every_n and every_n > 0: + filename = f"{base_name}_epoch={{epoch:04d}}" + every_n_epochs = every_n + save_top_k = -1 # keep ALL periodic checkpoints + else: + filename = base_name + every_n_epochs = None + save_top_k = getattr(tc.checkpoint, "save_top_k", 1) + + print(f"Saving every {every_n_epochs} epochs (if configured)") + callbacks.append( ModelCheckpoint( - dirpath=self.cfg.run_dir, - filename=( - f"{self.cfg.dataset.name}_{self.model_type}" - f"_dim{self.cfg.dataset.time_series_dims}" - ), - save_last=tc.checkpoint.save_last, + dirpath=checkpoint_dir, + filename=filename, + every_n_epochs=every_n_epochs, save_on_train_epoch_end=True, + save_last=True, # always keep last.ckpt + save_top_k=save_top_k, + auto_insert_metric_name=False, ) ) + callbacks.append(EvalAfterTraining(self.cfg, self.dataset)) + + fid_cfg = tc.get("intermediate_fid", {}) + if fid_cfg.get("enabled", False) and self.dataset is not None: + callbacks.append(IntermediateFIDCallback( + cfg=self.cfg, + dataset=self.dataset, + every_n_epochs=fid_cfg.get("every_n_epochs", 20), + n_samples=fid_cfg.get("n_samples", 3500), + fast_timesteps=fid_cfg.get("fast_timesteps", 50), + top_k=fid_cfg.get("top_k", 3), + )) + + if getattr(self.cfg, "run_dir", None): + callbacks.append(LogLossToCsv(self.cfg.run_dir)) + + # ---- Logger ---- logger = False if getattr(self.cfg, "wandb", None) and self.cfg.wandb.enabled: + wandb_id_file = Path(self.cfg.run_dir) / "wandb_run_id.txt" + existing_run_id = wandb_id_file.read_text().strip() if wandb_id_file.exists() else None logger = WandbLogger( project=self.cfg.wandb.project or "cents", entity=self.cfg.wandb.entity, name=self.cfg.wandb.name, save_dir=self.cfg.run_dir, + id=existing_run_id, + resume="must" if existing_run_id else None, ) + callbacks.append(_WandbRunIdSaver(self.cfg.run_dir)) return pl.Trainer( max_epochs=tc.max_epochs, @@ -205,12 +291,64 @@ def _instantiate_trainer(self) -> pl.Trainer: precision=tc.precision, log_every_n_steps=tc.get("log_every_n_steps", 1), accumulate_grad_batches=tc.get("gradient_accumulate_every", 1), + gradient_clip_val=tc.get("gradient_clip_val", None), callbacks=callbacks, logger=logger, default_root_dir=self.cfg.run_dir, ) + +class LogLossToCsv(Callback): + """Append epoch loss values to runs//train_losses.csv.""" + + def __init__(self, run_dir: str): + super().__init__() + self.run_dir = Path(run_dir) + self._csv_path = self.run_dir / "train_losses.csv" + # Don't rewrite the header if the file already exists (resume case) + self._header_written = self._csv_path.exists() + + def _ensure_header(self, metric_names: List[str]) -> None: + if self._header_written: + return + self.run_dir.mkdir(parents=True, exist_ok=True) + with open(self._csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["epoch"] + metric_names) + self._header_written = True + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + metrics = trainer.callback_metrics + if not metrics: + return + # Filter to loss-like keys and sort for consistent column order + loss_keys = sorted(k for k in metrics if "loss" in k.lower()) + if not loss_keys: + return + self._ensure_header(loss_keys) + row = [trainer.current_epoch] + for k in loss_keys: + v = metrics[k] + row.append(float(v) if hasattr(v, "item") else float(v)) + with open(self._csv_path, "a", newline="") as f: + csv.writer(f).writerow(row) + + +class _WandbRunIdSaver(Callback): + """Persist the W&B run ID to disk so --resume-from-checkpoint can continue the same run.""" + + def __init__(self, run_dir: str): + super().__init__() + self._id_file = Path(run_dir) / "wandb_run_id.txt" + + def on_train_start(self, trainer, pl_module): # noqa: ARG002 + run = getattr(trainer.logger, "experiment", None) + if run is not None and not self._id_file.exists(): + self._id_file.write_text(run.id) + print(f"[Cents] Saved W&B run ID {run.id} to {self._id_file}") + + class EvalAfterTraining(Callback): """Run full evaluator at the *end* of training and log metrics to W&B.""" @@ -229,3 +367,233 @@ def on_train_end(self, trainer, pl_module): run = getattr(trainer.logger, "experiment", None) if run is not None: run.log(results["metrics"]) + + +class IntermediateFIDCallback(Callback): + """ + Every N epochs: generate samples, compute context-FID, save a checkpoint. + + Checkpoint retention policy: + - Top-k epochs by FID (lower = better) are kept permanently. + - The FID-check epochs immediately before and after each top-k epoch are + kept as "neighbors" (i.e. ±every_n_epochs). + - The two most recent FID-check checkpoints are always kept as a rolling + buffer so that a newly-crowned top-k epoch's before-neighbor is available. + - Everything else is deleted after each FID check. + + Results are written to {run_dir}/intermediate_fid.csv and logged to W&B. + Checkpoints live in {run_dir}/fid_checkpoints/. + """ + + def __init__( + self, + cfg, + dataset, + every_n_epochs: int = 20, + n_samples: int = 3500, + fast_timesteps: int = 50, + top_k: int = 3, + ): + super().__init__() + self.cfg = cfg + self.dataset = dataset + self.every_n_epochs = every_n_epochs + self.n_samples = n_samples + self.fast_timesteps = fast_timesteps + self.top_k = top_k + self._csv_path = Path(cfg.run_dir) / "intermediate_fid.csv" + self._fid_ckpt_dir = Path(cfg.run_dir) / "fid_checkpoints" + self._fid_ckpt_dir.mkdir(parents=True, exist_ok=True) + # (fid, epoch, ckpt_path) — only currently-kept records + self._fid_records: List[tuple] = [] + self._header_written = self._csv_path.exists() + self._reload_records() + + def _reload_records(self) -> None: + """On resume: rebuild _fid_records from the CSV, keeping only entries whose checkpoint still exists on disk.""" + if not self._csv_path.exists(): + return + import csv as _csv + with open(self._csv_path, newline="") as f: + reader = _csv.DictReader(f) + for row in reader: + try: + epoch = int(row["epoch"]) - 1 # CSV stores epoch+1 + fid = float(row["context_fid"]) + except (KeyError, ValueError): + continue + ckpt_path = str(self._fid_ckpt_dir / f"fid_epoch={epoch:04d}.ckpt") + if Path(ckpt_path).exists(): + self._fid_records.append((fid, epoch, ckpt_path)) + if self._fid_records: + print(f"[IntermediateFID] Resumed with {len(self._fid_records)} prior FID records " + f"(best={min(r[0] for r in self._fid_records):.4f})") + + # ------------------------------------------------------------------ + # Main hook + # ------------------------------------------------------------------ + + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + epoch = trainer.current_epoch + if (epoch + 1) % self.every_n_epochs != 0: + return + + # Checkpoint saving must happen on all ranks — DDP collects state across processes. + ckpt_path = str(self._fid_ckpt_dir / f"fid_epoch={epoch:04d}.ckpt") + trainer.save_checkpoint(ckpt_path) + + # Everything else (generation, FID, logging, pruning) only on rank 0. + if not trainer.is_global_zero: + return + + import numpy as np + import torch + from cents.eval.eval_metrics import Context_FID + + device = pl_module.device + dataset = self.dataset + + # Snapshot dataset.data, add rarity column, then restore after + orig_data = dataset.data + dataset.data = dataset.get_combined_rarity() + all_len = len(dataset.data) + n = min(self.n_samples, all_len) + rng = np.random.default_rng(42) + idx = rng.choice(all_len, size=n, replace=False) + real_data_subset = dataset.data.iloc[idx].reset_index(drop=True) + + # Build context tensors + continuous_vars = getattr(dataset, "continuous_vars", []) + static_context_vars = {} + for name in dataset.static_context_vars: + vals = real_data_subset[name].values + dtype = torch.float32 if name in continuous_vars else torch.long + static_context_vars[name] = torch.tensor(vals, dtype=dtype, device=device) + + dynamic_context_vars = {} + categorical_ts = getattr(dataset, "categorical_time_series", {}) + for name in dataset.dynamic_context_vars: + vals = real_data_subset[name].values + if len(vals) and hasattr(vals[0], "__len__") and not isinstance(vals[0], (str, bytes)): + arr = np.stack([ + np.asarray(v, dtype=np.float32 if name not in categorical_ts else np.int64) + for v in vals + ]) + else: + arr = np.asarray(vals, dtype=np.float32 if name not in categorical_ts else np.int64) + dtype = torch.long if name in categorical_ts else torch.float32 + dynamic_context_vars[name] = torch.tensor(arr, dtype=dtype, device=device) + + # Temporarily reduce sampling timesteps for speed, then restore + orig_sampling_timesteps = pl_module.sampling_timesteps + orig_fast_sampling = pl_module.fast_sampling + pl_module.sampling_timesteps = self.fast_timesteps + pl_module.fast_sampling = True + was_training = pl_module.training + pl_module.eval() + try: + generated_ts = pl_module.generate(static_context_vars, dynamic_context_vars).cpu().numpy() + finally: + pl_module.sampling_timesteps = orig_sampling_timesteps + pl_module.fast_sampling = orig_fast_sampling + dataset.data = orig_data + if was_training: + pl_module.train() + + if generated_ts.ndim == 2: + generated_ts = generated_ts.reshape(generated_ts.shape[0], -1, generated_ts.shape[1]) + + # Inverse-transform (mirrors evaluate_subset logic) + syn_data_subset = real_data_subset.copy() + syn_data_subset["timeseries"] = list(generated_ts) + normalizer = getattr(dataset, "_normalizer", None) + if not getattr(dataset, "normalize", True) and normalizer is not None: + def _inv(df): + split = dataset.split_timeseries(df.copy()) + split = normalizer.inverse_transform(split) + return dataset.merge_timeseries_columns(split) + real_data_inv = _inv(real_data_subset) + syn_data_inv = _inv(syn_data_subset) + else: + real_data_inv = dataset.inverse_transform(real_data_subset) + syn_data_inv = dataset.inverse_transform(syn_data_subset) + + real_data_array = np.stack(real_data_inv["timeseries"]) + syn_data_array = np.stack(syn_data_inv["timeseries"]) + + fid = Context_FID(real_data_array, syn_data_array) + print(f"[IntermediateFID] Epoch {epoch + 1}: Context-FID = {fid:.4f}") + + self._fid_records.append((fid, epoch, ckpt_path)) + self._prune_fid_checkpoints() + + self._log_csv(epoch + 1, fid) + self._log_wandb(trainer, epoch, fid) + pl_module.log("intermediate_context_fid", fid, on_step=False, on_epoch=True, prog_bar=True, sync_dist=False) + + # ------------------------------------------------------------------ + # Checkpoint pruning + # ------------------------------------------------------------------ + + def _prune_fid_checkpoints(self) -> None: + """Delete FID checkpoints outside the keep-set.""" + if len(self._fid_records) < 2: + return + + all_epochs = [r[1] for r in self._fid_records] + epoch_set = set(all_epochs) + + # Top-k by FID (ascending — lower is better) + sorted_records = sorted(self._fid_records, key=lambda x: x[0]) + top_k_epochs = {r[1] for r in sorted_records[: self.top_k]} + + # Neighbors: ±1 FID-check interval around each top-k epoch + neighbor_epochs: set = set() + for ep in top_k_epochs: + neighbor_epochs.add(ep - self.every_n_epochs) + neighbor_epochs.add(ep + self.every_n_epochs) + neighbor_epochs &= epoch_set # only those we actually have on disk + + # Rolling buffer: always keep the two most recent FID-check epochs + recent_epochs = set(sorted(all_epochs)[-2:]) + + keep_epochs = top_k_epochs | neighbor_epochs | recent_epochs + + new_records = [] + for fid, ep, path in self._fid_records: + if ep in keep_epochs: + new_records.append((fid, ep, path)) + else: + p = Path(path) + if p.exists(): + p.unlink() + print(f"[IntermediateFID] Pruned checkpoint epoch={ep} (FID={fid:.4f})") + self._fid_records = new_records + + # ------------------------------------------------------------------ + # Logging helpers + # ------------------------------------------------------------------ + + def _log_wandb(self, trainer, epoch: int, fid: float) -> None: + run = getattr(trainer.logger, "experiment", None) + if run is None: + return + sorted_records = sorted(self._fid_records, key=lambda x: x[0]) + top_k_epochs = [r[1] for r in sorted_records[: self.top_k]] + run.log( + { + "intermediate_context_fid": fid, + "fid_best": sorted_records[0][0] if sorted_records else float("nan"), + "fid_top_k_epochs": str(top_k_epochs), + }, + step=trainer.global_step, + ) + + def _log_csv(self, epoch: int, fid: float) -> None: + Path(self.cfg.run_dir).mkdir(parents=True, exist_ok=True) + if not self._header_written: + with open(self._csv_path, "w", newline="") as f: + csv.writer(f).writerow(["epoch", "context_fid"]) + self._header_written = True + with open(self._csv_path, "a", newline="") as f: + csv.writer(f).writerow([epoch, float(fid)]) diff --git a/cents/utils/config_loader.py b/cents/utils/config_loader.py new file mode 100644 index 0000000..7882107 --- /dev/null +++ b/cents/utils/config_loader.py @@ -0,0 +1,68 @@ +import os +from pathlib import Path +from typing import Any, List, Union + +from omegaconf import OmegaConf + + +Config = Union[dict, Any] + + +def load_yaml(path: Union[str, Path]) -> Config: + """ + Load a YAML file into an OmegaConf object. + """ + return OmegaConf.load(str(path)) + + +def deep_merge(*cfgs: Config) -> Config: + """ + Deep-merge multiple OmegaConf configs into one. + Later arguments override earlier ones. + """ + cfg = OmegaConf.create({}) + for c in cfgs: + if c is None: + continue + cfg = OmegaConf.merge(cfg, c) + return cfg + + +def _coerce_scalar(value: str): + v = value.strip() + if v.lower() in ("true", "false"): + return v.lower() == "true" + if v.lower() in ("null", "none"): + return None + try: + return int(v) + except ValueError: + try: + return float(v) + except ValueError: + return v + + +def apply_overrides(cfg: Config, overrides: List[str]) -> Config: + """ + Apply dot-path overrides like ["trainer.max_epochs=10", "dataset.normalize=False"]. + Works on OmegaConf configs. + """ + if not overrides: + return cfg + # materialize to a plain container then recreate, to ensure mutability + cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True)) + for s in overrides: + if "=" not in s: + continue + key, val = s.split("=", 1) + parts = key.split(".") + cur = cfg + for p in parts[:-1]: + if p not in cur or cur[p] is None: + cur[p] = {} + cur = cur[p] + cur[parts[-1]] = _coerce_scalar(val) + return cfg + + diff --git a/cents/utils/utils.py b/cents/utils/utils.py index 2beb7ff..5c4b3cf 100644 --- a/cents/utils/utils.py +++ b/cents/utils/utils.py @@ -7,8 +7,48 @@ ROOT_DIR = Path(__file__).parent.parent -def _ckpt_name(dataset: str, model: str, dims: int, *, ext: str = "ckpt") -> str: - return f"{dataset}_{model}_dim{dims}.{ext}" +def _ckpt_name( + dataset: str, + model: str, + dims: int, + *, + ext: str = "ckpt", + static_module_type: str = None, + stats_head_type: str = None, + dynamic_module_type: str = None, + use_global_stats_preprocessing: bool = True, +) -> str: + """ + Generate checkpoint filename with optional context_module_type and stats_head_type. + + Args: + dataset: Dataset name + model: Model name + dims: Number of dimensions + ext: File extension (default: "ckpt") + static_module_type: Optional context module type (e.g., "mlp", "sep_mlp") + stats_head_type: Optional stats head type (e.g., "mlp") + dynamic_module_type: Optional dynamic module type + use_global_stats_preprocessing: If False, suffix "noglobal" so direct-prediction normalizer uses a separate cache. + + Returns: + Formatted checkpoint filename + """ + parts = [dataset, model, f"dim{dims}"] + + if static_module_type: + parts.append(f"ctx{static_module_type}") + + if stats_head_type: + parts.append(f"stats{stats_head_type}") + + if dynamic_module_type: + parts.append(f"dyn{dynamic_module_type}") + + if not use_global_stats_preprocessing: + parts.append("noglobal") + + return "_".join(parts) + f".{ext}" def parse_dims_from_name(model_name: str) -> str: @@ -36,6 +76,68 @@ def get_normalizer_training_config(): ) return OmegaConf.load(config_path) +_context_config_path = None +_context_overrides = [] + + +def set_context_config_path(path: str): + """ + Set a custom path for the context configuration file. + This path will be used by get_context_config() instead of the default. + + Args: + path: Path to the context config YAML file. If None, resets to default. + """ + global _context_config_path + _context_config_path = path + + +def set_context_overrides(overrides: list): + """ + Set overrides to apply to the context configuration. + + Args: + overrides: List of override strings (e.g., ["static_context.type=mlp", "dynamic_context.type=cnn"]) + """ + global _context_overrides + _context_overrides = overrides if overrides else [] + + +def get_context_config(path: str = None): + """ + Load the context configuration from config/context/default.yaml or a custom path. + Overrides can be applied if set via set_context_overrides(). + + Args: + path: Optional path to a custom context config file. If None, uses the path + set by set_context_config_path() or defaults to config/context/default.yaml. + + Returns: + OmegaConf config with static_context, normalizer, and dynamic_context sections. + """ + if path is not None: + config_path = path + elif _context_config_path is not None: + print(f"Using custom context config path: {_context_config_path}") + config_path = _context_config_path + else: + config_path = os.path.join( + ROOT_DIR, + "config", + "context", + "default.yaml", + ) + + cfg = OmegaConf.load(config_path) + + # Apply overrides if any + if _context_overrides: + from cents.utils.config_loader import apply_overrides + cfg = apply_overrides(cfg, _context_overrides) + print(f"Applied context config overrides: {_context_overrides}") + + return cfg + def get_default_trainer_config(): config_path = os.path.join( diff --git a/scripts/eval_pretrained.py b/scripts/eval_pretrained.py index 5b4f0ea..2f7a260 100644 --- a/scripts/eval_pretrained.py +++ b/scripts/eval_pretrained.py @@ -1,54 +1,531 @@ import logging -from datetime import datetime +import math +import os +import random from pathlib import Path +import json + +import numpy as np +import torch +import torch.nn.functional as F -from hydra import compose, initialize_config_dir from omegaconf import OmegaConf +import argparse -import wandb from cents.data_generator import DataGenerator +from cents.models.registry import get_model_type_from_hf_name from cents.datasets.pecanstreet import PecanStreetDataset +from cents.datasets.commercial import CommercialDataset +from cents.datasets.airquality import AirQualityDataset +from cents.datasets.metraq import MetraqDataset +from cents.datasets.walmart import WalmartDataset from cents.eval.eval import Evaluator +from cents.utils.config_loader import load_yaml, apply_overrides +from cents.utils.utils import set_context_config_path, set_context_overrides + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +DATASET_OVERRIDES = ["normalize=False", "max_samples=10000"] +PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] + +CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" + + +def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: + """Load dataset-specific config from config/dataset/{dataset_name}.yaml and apply overrides.""" + config_path = CONFIG_DATASET_DIR / f"{dataset_name}.yaml" + if not config_path.exists(): + raise ValueError( + f"Dataset config not found for '{dataset_name}' at {config_path}. " + f"Available: {[p.name for p in CONFIG_DATASET_DIR.glob('*.yaml')]}" + ) + cfg = load_yaml(str(config_path)) + if overrides: + cfg = apply_overrides(cfg, overrides) + return cfg + + +def _load_dataset(name: str, dataset_cfg: OmegaConf, run_dir: str = None): + """Load a dataset by name using dataset-specific config. Optionally pass run_dir for normalizer/cache.""" + kwargs = {"cfg": dataset_cfg} + if run_dir is not None: + kwargs["run_dir"] = run_dir + if name == "pecanstreet": + return PecanStreetDataset(**kwargs) + if name == "commercial": + return CommercialDataset(**kwargs) + if name == "airquality": + return AirQualityDataset(**kwargs) + if name == "metraq": + return MetraqDataset(**kwargs) + if name == "walmart": + return WalmartDataset(**kwargs) + raise ValueError(f"Dataset {name} not supported. Use: pecanstreet, commercial, airquality.") -MODEL_KEY = "Watts_1_2D" -OVERRIDES = [ - "dataset.user_group=pv_users", - "dataset.time_series_dims=2", - "evaluator.eval_disentanglement=False", - "wandb.enabled=True", - "wandb.project=cents", - "wandb.entity=michael-fuest-technical-university-of-munich", - f"wandb.name=eval_{MODEL_KEY}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_dim2", -] + +def _find_checkpoint_by_epoch(checkpoint_dir: Path, epoch: int) -> Path: + """Return path to a checkpoint matching the given epoch (e.g. epoch=0699). Prefer 4-digit zero-padded.""" + checkpoint_dir = Path(checkpoint_dir) + if not checkpoint_dir.exists(): + raise FileNotFoundError(f"Checkpoint dir not found: {checkpoint_dir}") + # Lightning ModelCheckpoint with epoch in filename: ..._epoch=0699.ckpt or ..._epoch=699.ckpt + pattern_4 = f"*epoch={epoch:04d}*" + pattern_1 = f"*epoch={epoch}*" + for pattern in (pattern_4, pattern_1): + matches = list(checkpoint_dir.glob(pattern)) + if matches: + return matches[0] + raise FileNotFoundError(f"No checkpoint for epoch {epoch} in {checkpoint_dir}") + + +def _resolve_run_path_config(run_path: Path, epoch: int = None): + """ + Load configs from run_path/config/ (or run_path/summary.yaml for older runs) and resolve checkpoint and epoch. + Returns (dataset_cfg, model_cfg, context_path, model_ckpt_path, normalizer_dir, metrics_epoch). + metrics_epoch is the epoch number or 'last' for saving metrics_{epoch}.json. + """ + run_path = Path(run_path) + config_dir = run_path / "config" + if config_dir.exists(): + dataset_cfg = OmegaConf.load(str(config_dir / "dataset.yaml")) + model_cfg = OmegaConf.load(str(config_dir / "model.yaml")) + context_path = config_dir / "context.yaml" + if not context_path.exists(): + context_path = None + else: + summary_path = run_path / "summary.yaml" + if not summary_path.exists(): + raise FileNotFoundError( + f"Run has no config/ or summary.yaml at {run_path}. Train with current code to write configs." + ) + summary = load_yaml(str(summary_path)) + dataset_cfg = OmegaConf.create(summary.get("dataset", {})) + model_cfg = OmegaConf.create(summary.get("model", {})) + config_dir.mkdir(parents=True, exist_ok=True) + if summary.get("context"): + OmegaConf.save(OmegaConf.create(summary["context"]), str(config_dir / "context.yaml")) + context_path = config_dir / "context.yaml" if summary.get("context") else None + checkpoint_dir = run_path / "checkpoints" + if epoch is not None: + model_ckpt_path = _find_checkpoint_by_epoch(checkpoint_dir, epoch) + metrics_epoch = epoch + else: + last_ckpt = checkpoint_dir / "last.ckpt" + if not last_ckpt.exists(): + raise FileNotFoundError( + f"No last.ckpt in {checkpoint_dir}. Specify --epoch or ensure training saved last.ckpt." + ) + model_ckpt_path = last_ckpt + metrics_epoch = "last" + normalizer_dir = run_path / "normalizer" + return dataset_cfg, model_cfg, context_path, model_ckpt_path, normalizer_dir, metrics_epoch def main() -> None: - logging.basicConfig( - level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" + parser = argparse.ArgumentParser( + description="Evaluate a trained model using comprehensive metrics.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-ckpt", + type=str, + default=None, + help="Path to model checkpoint (.ckpt or .pt). Required unless --model-key is set.", + ) + parser.add_argument( + "--model-key", + type=str, + default=None, + help="HuggingFace model key (e.g. Watts_2_1D). If set, model and normalizer are loaded from HF instead of --model-ckpt.", + ) + parser.add_argument( + "--normalizer-ckpt", + type=str, + default=None, + help="Path to normalizer checkpoint. If omitted, evaluation uses normalized space (or HF normalizer when using --model-key).", + ) + parser.add_argument( + "--model-type", + type=str, + default=None, + help="Model type (e.g. diffusion_ts) used to load the checkpoint. Inferred from --model-key when loading from HF.", + ) + parser.add_argument( + "--dataset", + type=str, + default="pecanstreet", + choices=("pecanstreet", "commercial", "airquality", "metraq", "walmart"), + help="Dataset name (must match the one used to train the model).", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="Device index to use for evaluation.", + ) + parser.add_argument( + "--dataset-overrides", + type=str, + nargs="*", + default=[], + help="Extra dataset overrides, e.g. time_series_dims=2 for multivariate. Config after overrides sets model shape.", + ) + parser.add_argument( + "--save-dir", + type=str, + default=None, + help="Directory to save evaluation results. If None, uses checkpoint parent + /eval or outputs/eval/ when using --model-key.", + ) + parser.add_argument( + "--job-name", + type=str, + default=None, + help="Job name for evaluation run. If None, uses default from evaluator config.", + ) + parser.add_argument( + "--evaluator-config", + type=str, + default="cents/config/evaluator/default.yaml", + help="Path to evaluator config YAML file.", + ) + parser.add_argument( + "--config", + type=str, + default="cents/config/config.yaml", + help="Path to main config YAML file.", + ) + parser.add_argument( + "--ema", + action="store_true", + help="Enable EMA sampling.", + ) + parser.add_argument( + "--eval-pv-shift", + action="store_true", + help="Enable PV shift evaluation.", + ) + parser.add_argument( + "--no-eval-metrics", + action="store_true", + help="Disable evaluation metrics computation.", + ) + parser.add_argument( + "--no-eval-context-sparse", + action="store_true", + help="Disable sparse context evaluation.", + ) + parser.add_argument( + "--no-eval-disentanglement", + action="store_true", + help="Disable disentanglement evaluation.", + ) + parser.add_argument( + "--no-save-results", + action="store_true", + help="Disable saving evaluation results.", + ) + parser.add_argument( + "--context-config-path", + type=str, + default=None, + help="Path to custom context config YAML file (optional).", ) + parser.add_argument( + "--context-overrides", + type=str, + nargs="*", + default=[], + help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn').", + ) + parser.add_argument( + "--no-normalizer-global-preprocessing", + action="store_true", + help="Use normalizer without global-stats preprocessing (match training that used --no-normalizer-global-preprocessing).", + ) + parser.add_argument( + "--run-path", + type=str, + default=None, + help="Path to a run directory (e.g. runs/commercial_noglobal). Load config from run/config/, checkpoint from run/checkpoints/. Saves metrics to run/metrics/metrics_{epoch}.json.", + ) + parser.add_argument( + "--epoch", + type=int, + default=None, + help="Epoch number to evaluate when using --run-path (e.g. 699). If omitted, use last.ckpt and save as metrics_last.json.", + ) + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Limit evaluation to this many samples (applied as dataset max_samples override).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducible sampling (sets Python, NumPy, and PyTorch seeds).", + ) + parser.add_argument( + "--cfg-scale", + type=float, + default=1.0, + help=( + "Classifier-free guidance scale (default 1.0 = no guidance). " + "Values >1 blend unconditional and conditional noise predictions: " + "pred = uncond + scale*(cond - uncond). " + "Requires model trained with context_embed_dropout > 0. " + "Only applies to fast (DDIM) sampling." + ), + ) + parser.add_argument( + "--save-path", + type=str, + default=None, + help="Path to save evaluation results." + ) + parser.add_argument( + "--model-config", + type=str, + default=None, + help="Path to a model config YAML file. Overrides the default cents/config/model/{model_type}.yaml when using --model-ckpt.", + ) + + args = parser.parse_args() + + use_run_path = args.run_path is not None + if use_run_path and args.model_key: + parser.error("Do not use --model-key with --run-path.") + if not use_run_path and not args.model_ckpt and not args.model_key: + parser.error("One of --model-ckpt, --model-key, or --run-path is required.") + if use_run_path and args.model_ckpt: + parser.error("Do not use --model-ckpt with --run-path; checkpoint is resolved from run-path and --epoch.") + + # Set random seed before any dataset loading so that subsampling is reproducible + if args.seed is not None: + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + logging.info("Random seed set to %d", args.seed) - if wandb.run is None: - wandb.init( - project="cents", - name=f"{MODEL_KEY}-eval-only-run_{datetime.now().strftime('%Y%m%d-%H%M%S')}", - entity="michael-fuest-technical-university-of-munich", + # Set custom context config path if provided + if args.context_config_path: + set_context_config_path(args.context_config_path) + + # Set context config overrides if provided + if args.context_overrides: + set_context_overrides(args.context_overrides) + + if use_run_path: + run_path = Path(args.run_path).resolve() + logging.info("Using run-path: %s (epoch=%s)", run_path, args.epoch) + dataset_cfg, model_cfg, context_path, model_ckpt_path, normalizer_dir, metrics_epoch = _resolve_run_path_config( + run_path, args.epoch ) + if context_path is not None: + set_context_config_path(str(context_path)) + # Apply dataset overrides (e.g. max_samples) so eval uses the requested subset. + # normalize=False prevents the dataset from re-training or reloading a normalizer + # during init; the normalizer checkpoint is loaded explicitly below instead. + overrides = DATASET_OVERRIDES + if getattr(args, "max_samples", None) is not None: + overrides.append(f"max_samples={args.max_samples}") + if args.dataset_overrides: + overrides.extend(args.dataset_overrides) + dataset_cfg = apply_overrides(dataset_cfg, overrides) + logging.info("Applied dataset overrides: %s", overrides) + dataset_name = dataset_cfg.get("name", "pecanstreet") + dataset = _load_dataset(dataset_name, dataset_cfg, run_dir=str(run_path)) + model_type = model_cfg.get("name", "diffusion_ts") + + # Resolve normalizer checkpoint from the run's normalizer directory so that + # z-space stats use the exact same normalizer the model was trained with. + normalizer_ckpts = sorted(normalizer_dir.glob("*.ckpt")) + if not normalizer_ckpts: + raise FileNotFoundError( + f"No normalizer checkpoint found in {normalizer_dir}. " + "Ensure the run was trained with the current code that saves normalizer/*.ckpt." + ) + run_normalizer_ckpt = normalizer_ckpts[0] + if len(normalizer_ckpts) > 1: + logging.warning( + "Multiple normalizer checkpoints found in %s; using %s", + normalizer_dir, run_normalizer_ckpt.name, + ) + logging.info("Using normalizer checkpoint: %s", run_normalizer_ckpt) + eval_cfg = load_yaml(args.evaluator_config) + top_cfg = load_yaml(args.config) + cfg = OmegaConf.create({}) + cfg.evaluator = eval_cfg + cfg.wandb = top_cfg.get("wandb", {}) + cfg.device = f"cuda:{args.device}" + cfg.model = model_cfg + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + cfg.model.use_ema_sampling = args.ema + cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) + cfg.eval_metrics = False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) + cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) + cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) + cfg.save_results = False if args.no_save_results else True + cfg.job_name = args.job_name or eval_cfg.get("job_name", "default_job") + cfg.save_dir = run_path / "metrics" + cfg.save_dir.mkdir(parents=True, exist_ok=True) + logging.info("Loading dataset from run config (run_dir=%s)...", run_path) + logging.info("Model checkpoint: %s", model_ckpt_path) + gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) + gen.load_from_checkpoint(str(model_ckpt_path), normalizer_ckpt=str(run_normalizer_ckpt)) + args._metrics_epoch = metrics_epoch + else: + args._metrics_epoch = None + logging.info("Loading dataset %s...", args.dataset) + overrides = list(DATASET_OVERRIDES) + if args.dataset == "pecanstreet": + overrides = overrides + PECAN_OVERRIDES + if args.model_key: + overrides = overrides + ["normalize=False"] + if args.model_key: + overrides = overrides + ["scale=True"] + if args.dataset_overrides: + overrides = overrides + list(args.dataset_overrides) + if getattr(args, "max_samples", None) is not None: + overrides = overrides + [f"max_samples={args.max_samples}"] + dataset_cfg = _load_dataset_config(args.dataset, overrides) + dataset = _load_dataset(args.dataset, dataset_cfg) - CONF_DIR = Path(__file__).resolve().parents[1] / "cents" / "config" - with initialize_config_dir(str(CONF_DIR), version_base=None): - cfg = compose(config_name="config", overrides=[f"model=acgan"] + OVERRIDES) + if args.model_key: + model_type = get_model_type_from_hf_name(args.model_key) + else: + model_type = args.model_type or "diffusion_ts" - ds_overrides = [ - o.split("dataset.")[1] for o in OVERRIDES if o.startswith("dataset.") - ] - dataset = PecanStreetDataset(overrides=ds_overrides) - cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + eval_cfg = load_yaml(args.evaluator_config) + top_cfg = load_yaml(args.config) - # Use the fixed checkpoint with DataGenerator - gen = DataGenerator(MODEL_KEY, cfg=cfg) + cfg = OmegaConf.create({}) + cfg.evaluator = eval_cfg + cfg.wandb = top_cfg.get("wandb", {}) + cfg.device = f"cuda:{args.device}" + model_config_path = args.model_config if args.model_config else f"cents/config/model/{model_type}.yaml" + cfg.model = OmegaConf.create( + OmegaConf.to_container(OmegaConf.load(model_config_path), resolve=True) + ) + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + if args.no_normalizer_global_preprocessing: + cfg.dataset.normalizer_use_global_stats_preprocessing = False + + cfg.model.use_ema_sampling = args.ema + cfg.eval_pv_shift = args.eval_pv_shift if args.eval_pv_shift else eval_cfg.get("eval_pv_shift", False) + cfg.eval_metrics = False if args.no_eval_metrics else eval_cfg.get("eval_metrics", True) + cfg.eval_context_sparse = False if args.no_eval_context_sparse else eval_cfg.get("eval_context_sparse", True) + cfg.eval_disentanglement = False if args.no_eval_disentanglement else eval_cfg.get("eval_disentanglement", True) + cfg.save_results = False if args.no_save_results else True + cfg.job_name = args.job_name if args.job_name else eval_cfg.get("job_name", "default_job") + + if args.save_dir: + cfg.save_dir = Path(args.save_dir) + elif args.model_key: + cfg.save_dir = Path("outputs/eval") / args.model_key + else: + model_ckpt_path = Path(args.model_ckpt) + cfg.save_dir = model_ckpt_path.parent / "eval" + + if not os.path.exists(cfg.save_dir): + os.makedirs(cfg.save_dir, exist_ok=True) + logging.info("Created evaluation directory: %s", cfg.save_dir) + + use_hf = args.model_key is not None + if use_hf: + logging.info("Setting up DataGenerator from HuggingFace (model_key=%s)...", args.model_key) + gen = DataGenerator(model_name=args.model_key, dataset=dataset, cfg=cfg) + else: + logging.info("Setting up DataGenerator (model_type=%s)...", model_type) + gen = DataGenerator(model_type=model_type, dataset=dataset, cfg=cfg) + logging.info("Loading checkpoint... EMA sampling %s", "enabled" if cfg.model.use_ema_sampling else "disabled") + gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) + + # Ensure EMA setting is applied to the config used by the model at generate time + target = getattr(gen.model, "cfg", None) or gen.cfg + if target is not None and hasattr(target, "model"): + target.model.use_ema_sampling = cfg.model.use_ema_sampling + + # gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) + cfg.dataset = gen.model.cfg.dataset + + # Set CFG scale on the model instance (read by generate() at inference time) + if args.cfg_scale != 1.0: + logging.info("Classifier-free guidance scale: %.2f", args.cfg_scale) + gen.model._cfg_scale = args.cfg_scale + + print("\n" + "=" * 60) + print("EVALUATION RESULTS") + print("=" * 60 + "\n") results = Evaluator(cfg, dataset).evaluate_model(data_generator=gen) - print(results) + + print("\n📊 METRICS (raw domain):") + print("-" * 60) + metrics = results.get("metrics", {}) + normalized = metrics.pop("normalized_domain", None) + + def _print_metrics(m, prefix=" "): + for key, value in m.items(): + if key == "rare_subset": + print(f"\n{prefix}rare_subset:") + _print_metrics(value, prefix=prefix + " ") + elif isinstance(value, dict) and "mean" in value and "std" in value: + print(f"{prefix}{key}: mean={value['mean']:.6f}, std={value['std']:.6f}") + elif isinstance(value, dict): + print(f"\n{prefix}{key}:") + _print_metrics(value, prefix=prefix + " ") + elif isinstance(value, (int, float)): + print(f"{prefix}{key}: {value:.6f}") + else: + print(f"{prefix}{key}: {value}") + + for key, value in metrics.items(): + if isinstance(value, dict) and "rare_subset" not in value: + print(f"\n{key}:") + _print_metrics(value) + elif isinstance(value, dict): + print(f"\n{key}:") + _print_metrics(value) + else: + print(f"{key}: {value:.6f}" if isinstance(value, (int, float)) else f"{key}: {value}") + + if normalized is not None: + print("\n📊 METRICS (normalized domain, z-space — comparable across domains):") + print("-" * 60) + _print_metrics(normalized, prefix="") + metrics["normalized_domain"] = normalized # restore for save + + def _sanitize_for_json(obj): + """Recursively replace NaN/Inf with None so json.dump produces valid JSON.""" + if isinstance(obj, dict): + return {k: _sanitize_for_json(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_sanitize_for_json(v) for v in obj] + if isinstance(obj, float) and not math.isfinite(obj): + return None + return obj + + # Results are automatically saved if save_results=True + if use_run_path and getattr(args, "_metrics_epoch", None) is not None: + metrics_file = cfg.save_dir / f"metrics_{args._metrics_epoch}.json" + with open(metrics_file, "w") as f: + json.dump(_sanitize_for_json(metrics), f, indent=4) + print(f"\n✅ Results saved to {metrics_file}") + elif args.save_dir: + with open(Path(args.save_dir) / "metrics.json", "w") as f: + json.dump(_sanitize_for_json(metrics), f, indent=4) + print(f"\n✅ Results saved to {Path(args.save_dir) / "metrics.json"}") + elif args.save_path: + with open(args.save_path, "w") as f: + json.dump(_sanitize_for_json(metrics), f, indent=4) + print(f"\n✅ Results saved to {args.save_path}") + print("\n" + "=" * 60) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/scripts/generate.py b/scripts/generate.py new file mode 100644 index 0000000..9d88682 --- /dev/null +++ b/scripts/generate.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +Generate synthetic time series samples from a trained model. + +Supports: + - Random context: sample context from the dataset's support (including continuous). + - Explicit context: provide context as JSON (categorical: int codes; continuous: z-scored floats). + - Sample rows: sample full context (static + dynamic) from real dataset rows, preserving correlations. + - Output to Parquet (default) or CSV. +""" + +import argparse +import json +import logging +import os +import random +from pathlib import Path + +import numpy as np +import torch +from omegaconf import OmegaConf + +from cents.data_generator import DataGenerator +from cents.datasets.pecanstreet import PecanStreetDataset +from cents.datasets.commercial import CommercialDataset +from cents.datasets.airquality import AirQualityDataset +from cents.datasets.walmart import WalmartDataset +from cents.datasets.utils import convert_generated_data_to_df +from cents.utils.config_loader import load_yaml, apply_overrides +from cents.utils.utils import set_context_config_path, set_context_overrides + +CONFIG_DATASET_DIR = Path(__file__).resolve().parent.parent / "cents" / "config" / "dataset" + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", +) +DATASET_OVERRIDES = ["normalize=False", "max_samples=10000", "skip_heavy_processing=True"] +PECAN_OVERRIDES = ["time_series_dims=1", "user_group=all"] + + +def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: + config_path = CONFIG_DATASET_DIR / f"{dataset_name}.yaml" + cfg = load_yaml(str(config_path)) + if overrides: + cfg = apply_overrides(cfg, overrides) + return cfg + + +def _load_dataset(name: str, dataset_cfg: OmegaConf): + kwargs = {"cfg": dataset_cfg} + if name == "pecanstreet": + return PecanStreetDataset(**kwargs) + if name == "commercial": + return CommercialDataset(**kwargs) + if name == "airquality": + return AirQualityDataset(**kwargs) + if name == "walmart": + return WalmartDataset(**kwargs) + raise ValueError(f"Dataset {name} not supported. Use: pecanstreet, commercial, airquality, walmart.") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate synthetic time series from a trained model.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-ckpt", + type=str, + required=True, + help="Path to model checkpoint (.ckpt or .pt).", + ) + parser.add_argument( + "--normalizer-ckpt", + type=str, + default=None, + help="Path to normalizer checkpoint. If omitted, output stays in normalized space.", + ) + parser.add_argument( + "--model-type", + type=str, + default="diffusion_ts", + help="Model type (e.g. diffusion_ts) used to load the checkpoint.", + ) + parser.add_argument( + "--dataset", + type=str, + default="pecanstreet", + choices=("pecanstreet", "commercial", "airquality", "walmart"), + help="Dataset name (must match the one used to train the model).", + ) + parser.add_argument( + "-n", + "--num-samples", + type=int, + default=100, + help="Number of samples to generate.", + ) + parser.add_argument( + "-o", + "--out", + type=str, + default="samples.parquet", + help="Output path for generated samples.", + ) + parser.add_argument( + "--random-context", + action="store_true", + help="Sample context randomly from the dataset support (categorical and continuous).", + ) + parser.add_argument( + "--sample-rows", + action="store_true", + help=( + "Sample full context (static + dynamic) from real dataset rows. " + "Preserves correlations between covariates. " + "Samples with replacement if n > len(dataset)." + ), + ) + parser.add_argument( + "--context", + type=str, + default=None, + help='Explicit context as JSON, e.g. \'{"weekday":0,"month":3}\'. ' + "Categorical: int codes. Continuous: z-scored (normalized) floats.", + ) + parser.add_argument( + "--context-config-path", + type=str, + default=None, + help="Path to custom context config YAML (optional).", + ) + parser.add_argument( + "--dataset-overrides", + type=str, + nargs="*", + default=[], + help="Extra dataset overrides, e.g. time_series_dims=1.", + ) + parser.add_argument( + "--context-overrides", + type=str, + nargs="*", + default=[], + help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn').", + ) + parser.add_argument( + "--no-ema", + action="store_true", + help="Disable EMA sampling (EMA is used by default when present in the checkpoint).", + ) + parser.add_argument( + "--stochastic-round", + action="store_true", + help=( + "Round output to non-negative integers using stochastic rounding " + "(floor + Bernoulli on fractional part). Applied after inverse-transform. " + "Useful for count-valued datasets such as Walmart unit sales." + ), + ) + parser.add_argument( + "--model-config", + type=str, + default=None, + help="Path to a model config YAML file. Overrides the default cents/config/model/{model_type}.yaml.", + ) + parser.add_argument( + "--dataset-config", + type=str, + default=None, + help="Path to a dataset config YAML file. Overrides the default cents/config/dataset/{dataset}.yaml.", + ) + args = parser.parse_args() + + use_random = args.random_context + use_explicit = args.context is not None and args.context.strip() != "" + use_rows = args.sample_rows + n_modes = sum([use_random, use_explicit, use_rows]) + if n_modes == 0: + parser.error("Provide one of --random-context, --context (JSON), or --sample-rows.") + if n_modes > 1: + parser.error("Provide only one of --random-context, --context, or --sample-rows.") + + if args.context_config_path: + set_context_config_path(args.context_config_path) + + if args.context_overrides: + set_context_overrides(args.context_overrides) + + base_overrides = list(DATASET_OVERRIDES) + if args.dataset == "pecanstreet": + base_overrides += PECAN_OVERRIDES + if args.dataset_overrides: + base_overrides += list(args.dataset_overrides) + + logging.info("Loading dataset %s...", args.dataset) + if args.dataset_config: + dataset_cfg = OmegaConf.load(args.dataset_config) + dataset_cfg = apply_overrides(dataset_cfg, base_overrides) + else: + dataset_cfg = _load_dataset_config(args.dataset, base_overrides) + dataset = _load_dataset(args.dataset, dataset_cfg) + cfg = OmegaConf.create({}) + model_config_path = args.model_config if args.model_config else f"cents/config/model/{args.model_type}.yaml" + cfg.model = OmegaConf.create(OmegaConf.to_container(OmegaConf.load(model_config_path), resolve=True)) + cfg.dataset = OmegaConf.create(OmegaConf.to_container(dataset.cfg, resolve=True)) + cfg.model.use_ema_sampling = not args.no_ema + + logging.info("Setting up DataGenerator (model_type=%s)...", args.model_type) + gen = DataGenerator(model_type=args.model_type, dataset=dataset, cfg=cfg) + gen.load_from_checkpoint(args.model_ckpt, args.normalizer_ckpt) + # Ensure EMA setting is applied to the config used by the model at generate time + target = getattr(gen.model, "cfg", None) or gen.cfg + if target is not None and hasattr(target, "model"): + target.model.use_ema_sampling = cfg.model.use_ema_sampling + gen.set_dataset_spec(gen.model.cfg.dataset, dataset.get_context_var_codes()) + + if use_rows: + # Sample full context (static + dynamic) from real dataset rows. + # With-replacement sampling handles n > len(dataset). + indices = random.choices(range(len(dataset)), k=args.num_samples) + samples = [dataset[i] for i in indices] + static_batch = { + k: torch.stack([s[1][k] for s in samples]).to(gen.device) + for k in samples[0][1].keys() + } + dynamic_batch = { + k: torch.stack([s[2][k] for s in samples]).to(gen.device) + for k in samples[0][2].keys() + } if samples[0][2] else {} + logging.info( + "Generating %d samples conditioned on %d real rows (static: %s, dynamic: %s)...", + args.num_samples, args.num_samples, + list(static_batch.keys()), list(dynamic_batch.keys()), + ) + with torch.no_grad(): + ts = gen.model.generate(static_batch, dynamic_batch or None) + ctx_batch = {**static_batch, **dynamic_batch} + df = convert_generated_data_to_df(ts, ctx_batch, decode=False) + elif use_random: + # Sample a new random context for each of the n samples + contexts = [dataset.sample_random_context_vars() for _ in range(args.num_samples)] + ctx_batch = { + k: torch.stack([c[k] for c in contexts]).to(gen.device) + for k in contexts[0].keys() + } + logging.info("Generating %d samples with %d random contexts...", args.num_samples, args.num_samples) + with torch.no_grad(): + ts = gen.model.generate(ctx_batch) + df = convert_generated_data_to_df(ts, ctx_batch, decode=False) + else: + context_dict = json.loads(args.context) + for k, v in context_dict.items(): + if isinstance(v, float): + context_dict[k] = v + else: + context_dict[k] = int(v) + gen.set_context(**context_dict) + logging.info("Generating %d samples with context %s...", args.num_samples, context_dict) + df = gen.generate(args.num_samples) + + if gen.normalizer is not None: + df = gen.normalizer.inverse_transform(df) + logging.info("Inverse-normalized outputs to original scale.") + else: + logging.warning("No normalizer loaded; outputs are in normalized space.") + + if args.stochastic_round: + def _stochastic_round(x): + x = np.clip(x, 0, None) + floor = np.floor(x).astype(int) + return floor + (np.random.random(x.shape) < (x - floor)).astype(int) + + col_name = dataset_cfg.time_series_columns[0] + df[col_name] = df[col_name].apply(_stochastic_round) + logging.info("Applied stochastic rounding to integer counts.") + + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(out, index=False) + logging.info("Wrote %d samples to %s", len(df), out.resolve()) + + +if __name__ == "__main__": + main() diff --git a/scripts/train.py b/scripts/train.py index 8bada3f..11961a5 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -1,26 +1,174 @@ from datetime import datetime +import yaml +from pathlib import Path +import numpy as np +import os +import torch +import random from cents.datasets.pecanstreet import PecanStreetDataset +from cents.datasets.commercial import CommercialDataset +from cents.datasets.airquality import AirQualityDataset +from cents.datasets.metraq import MetraqDataset +from cents.datasets.walmart import WalmartDataset + from cents.trainer import Trainer +from cents.utils.utils import set_context_config_path, set_context_overrides, get_context_config +from cents.utils.config_loader import load_yaml, apply_overrides +from omegaconf import OmegaConf +import warnings +import argparse + +warnings.simplefilter(action='ignore', category=FutureWarning) + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +RUNS_DIR = PROJECT_ROOT / "runs" +CONFIG_DATASET_DIR = PROJECT_ROOT / "cents" / "config" / "dataset" + + +def _load_dataset_config(dataset_name: str, overrides: list) -> OmegaConf: + """Load dataset-specific config from config/dataset/{dataset_name}.yaml and apply overrides.""" + config_path = CONFIG_DATASET_DIR / f"{dataset_name}.yaml" + if not config_path.exists(): + raise ValueError( + f"Dataset config not found for '{dataset_name}' at {config_path}. " + f"Available: {[p.name for p in CONFIG_DATASET_DIR.glob('*.yaml')]}" + ) + cfg = load_yaml(str(config_path)) + if overrides: + cfg = apply_overrides(cfg, overrides) + return cfg + + +def _write_run_summary(run_dir: Path, run_name: str, trainer: Trainer, random_seed: int) -> None: + """Write a summary YAML of run choices (context, model, dataset, trainer) to run_dir.""" + cfg = trainer.cfg + context_cfg = get_context_config() + summary = { + "run_name": run_name, + "run_dir": str(run_dir), + "random_seed": random_seed, + "dataset": OmegaConf.to_container(cfg.dataset, resolve=True) if hasattr(cfg, "dataset") and cfg.dataset else {}, + "model": OmegaConf.to_container(cfg.model, resolve=True) if hasattr(cfg, "model") and cfg.model else {}, + "context": OmegaConf.to_container(context_cfg, resolve=True) if context_cfg else {}, + "trainer": OmegaConf.to_container(cfg.trainer, resolve=True) if hasattr(cfg, "trainer") and cfg.trainer else {}, + } + path = run_dir / "summary.yaml" + with open(path, "w") as f: + yaml.dump(summary, f, default_flow_style=False, sort_keys=False) + print(f"[Cents] Wrote run summary to {path}") + + +def _write_run_configs(run_dir: Path, trainer: Trainer) -> None: + """Write dataset, model, context, and trainer configs to run_dir/config/ for eval and reproducibility.""" + cfg = trainer.cfg + context_cfg = get_context_config() + config_dir = run_dir / "config" + config_dir.mkdir(parents=True, exist_ok=True) + if hasattr(cfg, "dataset") and cfg.dataset: + with open(config_dir / "dataset.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(cfg.dataset, resolve=True), f, default_flow_style=False, sort_keys=False) + if hasattr(cfg, "model") and cfg.model: + with open(config_dir / "model.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(cfg.model, resolve=True), f, default_flow_style=False, sort_keys=False) + if hasattr(cfg, "trainer") and cfg.trainer: + with open(config_dir / "trainer.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(cfg.trainer, resolve=True), f, default_flow_style=False, sort_keys=False) + if context_cfg: + with open(config_dir / "context.yaml", "w") as f: + yaml.dump(OmegaConf.to_container(context_cfg, resolve=True), f, default_flow_style=False, sort_keys=False) + print(f"[Cents] Wrote run configs to {config_dir}") + + +def main(args) -> None: + MODEL_NAME = args.model_name + CR_LOSS_WEIGHT = args.cr_loss_weight + TC_LOSS_WEIGHT = args.tc_loss_weight + run_name = args.run_name + np.random.seed(args.random_seed) + os.environ['PYTHONHASHSEED'] = str(args.random_seed) + random.seed(args.random_seed) + torch.manual_seed(args.random_seed) -def main() -> None: - MODEL_NAME = "acgan" - CR_LOSS_WEIGHT = 0.1 - TC_LOSS_WEIGHT = 0.1 - dataset = PecanStreetDataset(overrides=["user_group=all", "time_series_dims=2"]) + + # Create run directory under runs/{dataset}/{run_name} + RUNS_DIR.mkdir(parents=True, exist_ok=True) + run_dir = RUNS_DIR / args.dataset / run_name + run_dir.mkdir(parents=True, exist_ok=True) + print(f"[Cents] Run directory: {run_dir}") + + # Set custom context config path if provided + if args.context_config_path: + set_context_config_path(args.context_config_path) + + # Set context config overrides if provided + if args.context_overrides: + set_context_overrides(args.context_overrides) + + # Build dataset-specific overrides (key=value list; config is loaded from config/dataset/{dataset}.yaml) + dataset_overrides = [f"skip_heavy_processing={args.skip_heavy_processing}"] + if getattr(args, "no_normalizer_global_preprocessing", False): + dataset_overrides.append("normalizer_use_global_stats_preprocessing=false") + if args.dataset == "pecanstreet": + dataset_overrides.extend(["time_series_dims=1", "user_group=all"]) + dataset_cfg = _load_dataset_config(args.dataset, dataset_overrides) + + if args.dataset == "pecanstreet": + dataset = PecanStreetDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) + elif args.dataset == "commercial": + dataset = CommercialDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) + elif args.dataset == "airquality": + dataset = AirQualityDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) + + elif args.dataset == "metraq": + dataset = MetraqDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) + elif args.dataset == "walmart": + dataset = WalmartDataset( + cfg=dataset_cfg, + force_retrain_normalizer=args.force_retrain_normalizer, + run_dir=str(run_dir), + ) + else: + raise ValueError(f"Dataset {args.dataset} not supported") + + print("Initialized Dataset") trainer_overrides = [ - "trainer.max_epochs=5000", - "trainer.strategy=ddp_find_unused_parameters_true", - "trainer.eval_after_training=True", - "wandb.enabled=True", - "wandb.project=cents", - "wandb.entity=michael-fuest-technical-university-of-munich", + f"run_dir={run_dir}", + f"trainer.max_epochs={args.epochs}", + f"trainer.checkpoint.every_n_epochs={args.every_n_epochs}", + f"trainer.strategy={args.ddp_strategy}", + f"trainer.devices={args.devices}", + f"trainer.accelerator={args.accelerator}", + f"trainer.eval_after_training={args.eval_after_training}", + f"trainer.intermediate_fid.enabled={args.intermediate_fid}", + f"trainer.intermediate_fid.every_n_epochs={args.fid_every_n_epochs}", + f"wandb.enabled={args.wandb_enabled}", + f"wandb.project={args.wandb_project}", + f"wandb.entity={args.wandb_entity}", + f"wandb.name={run_name}_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}", f"model.context_reconstruction_loss_weight={CR_LOSS_WEIGHT}", f"model.tc_loss_weight={TC_LOSS_WEIGHT}", - f"wandb.name=training_dai_{MODEL_NAME}_{datetime.now().strftime('%Y%m%d-%H%M%S')}_L{CR_LOSS_WEIGHT}_TC_{TC_LOSS_WEIGHT}_dim2", ] + if args.model_overrides: + trainer_overrides.extend(args.model_overrides) trainer = Trainer( model_type=MODEL_NAME, @@ -28,8 +176,58 @@ def main() -> None: overrides=trainer_overrides, ) - trainer.fit() + _write_run_summary(run_dir, run_name, trainer, args.random_seed) + _write_run_configs(run_dir, trainer) + trainer.fit(ckpt_path=args.resume_from_checkpoint) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--every_n_epochs", type=int, default=250) + parser.add_argument("--devices", type=str, default="auto") + parser.add_argument("--accelerator", type=str, default="gpu") + parser.add_argument("--model_name", type=str, default="diffusion_ts") + parser.add_argument("--cr_loss_weight", type=float, default=0.1) + parser.add_argument("--tc_loss_weight", type=float, default=0.1) + parser.add_argument("--dataset", type=str, default="pecanstreet") + parser.add_argument("--epochs", type=int, default=2500) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--wandb-enabled", action="store_true", + help="Enable Weights and Biases logging") + parser.add_argument("--wandb-project", type=str, default="cents") + parser.add_argument("--wandb-entity", type=str, default=None) + parser.add_argument("--eval-after-training", action="store_true", + help="Evaluate after training") + parser.add_argument("--skip-heavy-processing", action="store_true", + help="Skip heavy processing of dataset") + parser.add_argument("--ddp-strategy", type=str, default="ddp_find_unused_parameters_true") + parser.add_argument("--enable-checkpointing", action="store_true", + help="Enable checkpointing") + parser.add_argument("--context-config-path", type=str, default=None, + help="Path to custom context config YAML file (optional)") + parser.add_argument("--context-overrides", type=str, nargs="*", default=[], + help="Override context config values (e.g., 'static_context.type=mlp' 'dynamic_context.type=cnn')") + parser.add_argument("--force-retrain-normalizer", action="store_true", + help="Force retraining of normalizer even if cached version exists") + parser.add_argument("--no-normalizer-global-preprocessing", action="store_true", + help="Predict normalizer mu/sigma directly (no global-stats scaling). Use if commercial runs had NaNs with global preprocessing.") + parser.add_argument("--resume-from-checkpoint", type=str, default=None, + help="Path to checkpoint file (.ckpt) to resume training from", + ) + parser.add_argument("--model-overrides", type=str, nargs="*", default=[], + help="Override model config values (e.g., 'model.cond_emb_dim=16' 'model.attn_pd=0.0')", + ) + parser.add_argument("--run-name", type=str, required=True, + help="Name of this run. A directory runs// will be created for checkpoints, cache, and summary.", + ) + parser.add_argument("--random-seed", type=int, default=42, + help="Random seed for reproducibility",) + parser.add_argument("--intermediate-fid", action="store_true", default=True, + help="Enable intermediate context-FID checks during training (saves top-k checkpoints)") + parser.add_argument("--no-intermediate-fid", dest="intermediate_fid", action="store_false", + help="Disable intermediate context-FID checks") + parser.add_argument("--fid-every-n-epochs", type=int, default=20, + help="How often (in epochs) to compute context-FID during training") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index e4af31a..ea5af39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,8 @@ import numpy as np import pandas as pd import pytest -from hydra import compose, initialize_config_dir from omegaconf import DictConfig, OmegaConf +from cents.utils.config_loader import load_yaml from cents.datasets.timeseries_dataset import TimeSeriesDataset from cents.trainer import Trainer @@ -135,32 +135,25 @@ def normalized_dataset_2d(raw_df_2d): def load_top_level_config() -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs") - with initialize_config_dir(config_dir=config_dir, version_base=None): - cfg = compose(config_name="test_config", overrides=[]) - return cfg + path = os.path.join(ROOT_DIR, "tests", "test_configs", "test_config.yaml") + return load_yaml(path) def load_dataset_config(case: str) -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs", "dataset") - with initialize_config_dir(config_dir=config_dir, version_base=None): - ds_cfg = compose(config_name=case, overrides=[]) + path = os.path.join(ROOT_DIR, "tests", "test_configs", "dataset", f"{case}.yaml") + ds_cfg = load_yaml(path) OmegaConf.set_struct(ds_cfg, False) return ds_cfg def load_model_config(model_type: str) -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs", "model") - with initialize_config_dir(config_dir=config_dir, version_base=None): - model_cfg = compose(config_name=model_type, overrides=[]) - return model_cfg + path = os.path.join(ROOT_DIR, "tests", "test_configs", "model", f"{model_type}.yaml") + return load_yaml(path) def load_trainer_config(trainer_name: str) -> DictConfig: - config_dir = os.path.join(ROOT_DIR, "tests", "test_configs", "trainer") - with initialize_config_dir(config_dir=config_dir, version_base=None): - trainer_cfg = compose(config_name=trainer_name, overrides=[]) - return trainer_cfg + path = os.path.join(ROOT_DIR, "tests", "test_configs", "trainer", f"{trainer_name}.yaml") + return load_yaml(path) @pytest.fixture diff --git a/tests/test_configs/model/diffusion_ts.yaml b/tests/test_configs/model/diffusion_ts.yaml index 1750bc9..e52666c 100644 --- a/tests/test_configs/model/diffusion_ts.yaml +++ b/tests/test_configs/model/diffusion_ts.yaml @@ -11,6 +11,9 @@ n_steps: 1000 sampling_timesteps: 1000 sampling_batch_size: 1 loss_type: l1 #l2 +training_objective: x0 +loss_weighting: min_snr +min_snr_gamma: 5.0 beta_schedule: cosine #linear n_heads: 4 mlp_hidden_times: 4