Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
54955a9
Minimal replacement of Hydra
pmfeen Oct 3, 2025
98275aa
Remove dependency on Hydra
pmfeen Oct 3, 2025
be3e37c
Added commercial energy dataset; WIP
pmfeen Oct 7, 2025
13cde34
Fixed training for PecanStreet on GPU, fixed indexing errors in Comme…
Oct 13, 2025
5bdccff
Removed debugging print statements, moved back to GPU
Oct 13, 2025
121fb56
Fixed merge error in CommercialDataset, Introduced logic for sampling…
Oct 20, 2025
c18968d
Eval pretrain fixes, normalization caching for subprocesses, EMA weig…
Oct 27, 2025
4f63eb4
Modularized ContextModule, StatsHead for future work
Oct 28, 2025
05d9950
Changed train script to use argparser
Oct 29, 2025
c9568e7
minor fixes
Oct 30, 2025
8dc3156
Added new context embedder
Nov 3, 2025
9eb35ca
Added registries that were untracked
Nov 5, 2025
7ffc6df
stable normalized training
Nov 13, 2025
ea2ba44
Stabilized training with cont. cvs
Nov 13, 2025
f0f9647
Began adding airquality dataset
Nov 19, 2025
8ece83a
Modified configs for easier customization of context variables and mo…
Jan 7, 2026
e90410a
Dynamic Context Added; CNN Support, Airquality Dataset for Testing
Jan 20, 2026
11e8f46
Transformer context, distributed normalizer training
Jan 20, 2026
83da97c
Added Gaussian NLL Objective for Normalizer Training
Jan 20, 2026
d8c8877
Resume from checkpoint
Jan 21, 2026
bdb78ed
Added simple generate script
Jan 21, 2026
fef7f88
changed eval script to use commandline args
Jan 21, 2026
74034a4
UNSTABLE - fixes to dyn context in base model
Jan 23, 2026
288ab18
Tracking tools for gradient flow to context
Jan 26, 2026
ddbc53c
Vehicle Training Runs
Feb 2, 2026
a899028
Added vehicle files; fixes for binning numerics
Feb 2, 2026
eff190c
Added additional diffusion training objectives ; snr ; revised EMA
Feb 3, 2026
9f77f1d
Added better run tracking
Feb 4, 2026
d87d2db
Added AdaLN for stronger conditioning
Feb 9, 2026
606be7c
cleanup
Feb 9, 2026
bd5d13d
fixes to eval
Feb 10, 2026
0128858
Removed comments, config changes
Feb 10, 2026
adb16f6
config changes for training, eval, remove nan checks
Feb 13, 2026
32dc636
Removed print statement
Feb 13, 2026
c5fac35
Global smoothing for normalizer
Feb 13, 2026
ae48d13
intermittent checkpointing, better result display
Feb 17, 2026
3653211
Fixed checkpointing, use ff loss, optional conditional guidance
Feb 18, 2026
e1ed0e8
New run tracking, eval in normalized and raw domain, stabilization fo…
Feb 24, 2026
8fd58cf
Eval in raw, normalized domain, select vars for normalization, dynami…
Feb 28, 2026
a4756d1
cross-attention for dynamic context, cfg guidance, dropout regulariza…
Mar 2, 2026
fb5f645
Select normalization, dropout regularization (massive improvement)
Mar 3, 2026
7a58811
Added metraq dataset
Mar 8, 2026
da10776
metraq dataset implementation
Mar 11, 2026
eeff929
ema + gradient clipping
Mar 11, 2026
051dd8e
Intermittent FID eval
Mar 16, 2026
15393de
changes to generate. also, good 750 epoch aqr run
Mar 18, 2026
10525f8
Implemented walmart dataset; layernorm for dyn context; config changes
Mar 31, 2026
bb5f213
Joint dyn transformer, new eval metrics
Apr 12, 2026
b07a57f
Context Recovery Score, rounding for walmart dataset, additional stat…
Apr 24, 2026
e442ddf
Removed unused code, removes nan checks, removed print statements
May 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ ENV/
.*.swp

# Repository Specific
runs/
cents/data/*
cents/data/pecanstreet/*
cents/data/commercial/*
cents/data/custom/
.DS_Store
.ipynb_checkpoints
Expand Down
26 changes: 0 additions & 26 deletions cents/config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,31 +1,5 @@
defaults:
- model: null
- dataset: pecanstreet
- evaluator: default
- trainer: null
- _self_

device: auto
job_name: ${model.name}_${dataset.name}_${dataset.user_group}
run_dir: outputs/${job_name}/${now:%Y-%m-%d_%H-%M-%S}
model_ckpt: null
hydra:
job_logging:
version: 1
formatters:
simple:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers:
console:
class: logging.StreamHandler
formatter: simple
level: INFO
root:
handlers: [console]
level: INFO
run:
dir: ${run_dir}


wandb:
enabled: false
Expand Down
20 changes: 20 additions & 0 deletions cents/config/context/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Context configuration
# This file defines the context modules used across the codebase

static_context:
type: mlp # Options: "mlp", "sep_mlp", "transformer"
# TransformerStaticContextModule hyperparameters (ignored by mlp/sep_mlp):
# n_heads: 4
# n_layers: 2
# dropout: 0.1
# dim_feedforward: 256

# Normalizer: stats head configuration for the normalizer
normalizer:
stats_head_type: mlp # Stats head type (e.g., "mlp")
n_layers: 5
# hidden_dim: 512

# Dynamic context: context module used by the normalizer for time series context variables
dynamic_context:
type: null # Context module type for dynamic context (e.g., "cnn")
59 changes: 59 additions & 0 deletions cents/config/dataset/airquality.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
name: airquality
geography: null
normalize: True
scale: False
use_learned_normalizer: True
threshold: 8
seq_len: 24
shuffle: True
skip_heavy_processing: False
max_samples: null
path: "./data/airquality"
numeric_context_bins: 1
reduce_cardinality: False
time_series_dims: 1
normalizer_stats_mode: group
# Normalizer conditions only on these (e.g. per-station); diffusion still gets full context_vars
normalizer_group_vars: ["station"]

# Targets (what becomes the merged "timeseries" dims)
# NOTE: use PMcoarse instead of PM10
time_series_columns: ["PM2.5"]

# Raw CSV columns to load
# Keep wd/WSPM because we need them to engineer wind_u/wind_v
# Keep PM10 because we need it to engineer PMcoarse
data_columns:
- "No"
- "year"
- "month"
- "day"
- "hour"
- "PM2.5"
- "PM10"
- "SO2"
- "NO2"
- "CO"
- "TEMP"
- "DEWP"
- "PRES"
- "RAIN"
- "WSPM"
- "wd"
- "station"

context_vars:
# static categorical
year: ["categorical", 5]
month: ["categorical", 12]
weekday: ["categorical", 7]
station: ["categorical", 12]

# dynamic time-series context
TEMP: ["time_series", null]
DEWP: ["time_series", null]
PRES: ["time_series", null]
RAIN: ["time_series", null]
wind_u: ["time_series", null]
wind_v: ["time_series", null]
wd_valid: ["time_series", null]
30 changes: 30 additions & 0 deletions cents/config/dataset/commercial.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: commercial
geography: null
user_group: all
normalize: True
scale: False
use_learned_normalizer: True
threshold: 8
seq_len: 24
time_series_dims: 1
shuffle: True
skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP)
max_samples: null # Limit dataset size (null = use all data)
path: "./data/commercial/csv"
time_series_columns: "energy_meter"
data_columns: ["dataid","energy_meter","timestamp"]
metadata_columns: ["building_id", "site_id", "primaryspaceusage", "sqft", "yearbuilt", "sub_primaryspaceusage"]
numeric_context_bins: 5
reduce_cardinality: False
normalizer_stats_mode: group
normalizer_group_vars: null

context_vars:
year: ["categorical", 2]
month: ["categorical", 12]
weekday: ["categorical", 7]
site_id: ["categorical", 19]
primaryspaceusage: ["categorical", 16]
sqft: ["categorical", null]
yearbuilt: ["categorical", null]
sub_primaryspaceusage: ["categorical", 104]
26 changes: 14 additions & 12 deletions cents/config/dataset/default.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name: default
normalize: True
scale: True
use_learned_normalizer: True
shuffle: True
threshold: 6
time_series_dims: 1
time_series_columns: []
seq_len: 8
user_group: null
# name: default
# normalize: True
# scale: True
# use_learned_normalizer: True
# shuffle: True
# threshold: 6
# time_series_dims: 1
# time_series_columns: []
# seq_len: 8
# user_group: null

numeric_context_bins: 5
context_vars: {}
# numeric_context_bins: 5
# context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous)
# continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned)
# stats_head_type: mlp
54 changes: 54 additions & 0 deletions cents/config/dataset/metraq.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
name: metraq
geography: null
normalize: True
scale: False
use_learned_normalizer: True
threshold: 8
seq_len: 24
shuffle: True
skip_heavy_processing: False
max_samples: null
path: "./data/metraq"
numeric_context_bins: 1
reduce_cardinality: False
time_series_dims: 1
normalizer_stats_mode: group
# Normalizer conditions only on these (e.g. per-station); diffusion still gets full context_vars
normalizer_group_vars: ["sensor_name"]
max_z_threshold: 15.0

# Targets (what becomes the merged "timeseries" dims)
# NOTE: use PMcoarse instead of PM10
time_series_columns: ["PM2.5"]

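# Raw CSV columns to load
# NOTE(review): the two remarks below appear to be carried over from airquality.yaml;
# none of wd, WSPM, or PM10 is listed in data_columns here — confirm and update.
# Keep wd/WSPM because we need them to engineer wind_u/wind_v
# Keep PM10 because we need it to engineer PMcoarse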
data_columns:
- "entry_date"
- "magnitude_name"
- "sensor_name"
- "value"
# - "utm_x"
# - "utm_y"

context_vars:
# static categorical
year: ["categorical", 6]
month: ["categorical", 12]
weekday: ["categorical", 7]
sensor_name: ["categorical", 24]
# utm_x: ["continuous", null]
# utm_y: ["continuous", null]

# dynamic time-series context
# WS and WD are decomposed into wind_u/wind_v in preprocessing to handle
# the circularity of wind direction (WD=355° ≈ WD=5°, but z-score would give opposite signs).
T: ["time_series", null]
# wind_u: ["time_series", null]
# wind_v: ["time_series", null]
# wd_valid: ["time_series", null]
# RH, AP, R dropped — per-sample correlation with PM2.5 < 0.025 across all stations
# Traffic: TI = vehicles/hour (Kriging interpolation); SP = avg speed km/h
TI: ["time_series", null]
SP: ["time_series", null]
26 changes: 15 additions & 11 deletions cents/config/dataset/pecanstreet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@ threshold: 8
seq_len: 96
time_series_dims: 1
shuffle: True
skip_heavy_processing: False # Skip rarity computation (for faster loading/DDP)
max_samples: null # Limit dataset size (null = use all data)
path: "./data/pecanstreet/csv"
time_series_columns: ["grid", "solar"]
time_series_columns: ["grid"]
data_columns: ["dataid","local_15min","car1","grid","solar"]
metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"]
user_group: all # non_pv_users, all, pv_users
numeric_context_bins: 5
normalizer_stats_mode: group

context_vars: # for each desired context variable, add the name and number of categories
month: 12
weekday: 7
building_type: 3
has_solar: 2 # note that the metadata csv file column name is 'solar', which is renamed to avoid conflicts with the 'solar' column in the data csv.
car1: 2
city: 7
state: 3
total_square_footage: 5
house_construction_year: 5

context_vars:
month: ["categorical", 12]
weekday: ["categorical", 7]
building_type: ["categorical", 3]
has_solar: ["categorical", 2]
car1: ["categorical", 2]
city: ["categorical", 7]
state: ["categorical", 3]
total_square_footage: ["categorical", null]
house_construction_year: ["categorical", null]
37 changes: 37 additions & 0 deletions cents/config/dataset/walmart.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: walmart
geography: null
normalize: True
scale: False
use_learned_normalizer: True
threshold: 8
seq_len: 28
shuffle: True
skip_heavy_processing: False
max_samples: null
path: "./data/walmart"
numeric_context_bins: 1
reduce_cardinality: False
time_series_dims: 1
normalizer_stats_mode: group
# Normalizer conditions on category × store to capture per-group sales distributions
normalizer_group_vars: ["cat_id", "store_id"]
max_z_threshold: 15.0

# Target: daily unit sales
time_series_columns: ["sales"]

context_vars:
# Static categorical — characterise the window by when it starts
year: ["categorical", 6] # 2011–2016
month: ["categorical", 12]
# Static categorical — item / store identity
cat_id: ["categorical", 3] # FOODS, HOBBIES, HOUSEHOLD
dept_id: ["categorical", 7] # e.g. FOODS_1 … HOUSEHOLD_2
store_id: ["categorical", 10] # CA_1 … WI_3
state_id: ["categorical", 3] # CA, TX, WI

# Dynamic time-series context (co-occurring with target, length = seq_len)
sell_price: ["time_series", null] # weekly price broadcast to daily, z-scored
snap: ["time_series", null] # binary SNAP eligibility for the item's state
event_binary: ["time_series", null] # 1 if a named calendar event falls on that day
weekday: ["time_series", null] # day of week encoded as 0 (Mon) – 6 (Sun), z-scored
32 changes: 32 additions & 0 deletions cents/config/evaluator/airquality.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
model:
name: diffusion_ts
dataset:
name: airquality
eval_pv_shift: False
eval_metrics: True
eval_context_sparse: True
save_results: False
eval_disentanglement: True
eval_context_recovery: True
job_name: diffusion_ts_airquality
save_dir: outputs/diffusion_ts_airquality/eval

# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP).
# Runs only when enabled=True AND either:
# - the generated signal has multiple dimensions (multivariate), OR
# - the dataset uses dynamic (time-series) context variables.
#
# pairs: list of {x, c} dicts specifying which time series to evaluate against each other.
# x — name of a generated output dimension (must match time_series_columns in dataset config)
# c — name of a dynamic context variable (from context_vars with type "time_series")
# OR another generated output dimension (multivariate case, GCP only)
#
# CFS is computed only when c is a dynamic context variable (shared context).
# GCP is computed for all pairs.
#
eval_context_faithfulness:
enabled: true
gcp_max_lag: 5
pairs:
- {x: "PM2.5", c: "TEMP"}
- {x: "PM2.5", c: "DEWP"}
37 changes: 35 additions & 2 deletions cents/config/evaluator/default.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,40 @@
model_name: ${model.name}
model:
name: diffusion_ts # Set this to your model name
dataset:
name: commercial # Set this to your dataset name (e.g., "commercial")
eval_pv_shift: False
eval_metrics: True
pred_score_trtr: True # If True, also trains on real data (TRTR) and reports MAE delta alongside TSTR MAE
eval_context_sparse: True
save_results: False
eval_disentanglement: True
save_dir: ${run_dir}/eval
eval_context_recovery: True
job_name: diffusion_ts_commercial
save_dir: outputs/diffusion_ts_commercial/eval

# Context Faithfulness Score (CFS) and Granger Causality Preservation (GCP).
# Runs only when enabled=True AND either:
# - the generated signal has multiple dimensions (multivariate), OR
# - the dataset uses dynamic (time-series) context variables.
#
# pairs: list of {x, c} dicts specifying which time series to evaluate against each other.
# x — name of a generated output dimension (must match time_series_columns in dataset config)
# c — name of a dynamic context variable (from context_vars with type "time_series")
# OR another generated output dimension (multivariate case, GCP only)
#
# CFS is computed only when c is a dynamic context variable (shared context).
# GCP is computed for all pairs.
#
# Example for airquality dataset (PM2.5 generated, TEMP/DEWP as context):
# pairs:
# - {x: "PM2.5", c: "TEMP"}
# - {x: "PM2.5", c: "DEWP"}
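#
# NOTE(review): the active pairs below use PM2.5 with T/TI/SP, which match the
# time_series_columns and time-series context_vars of dataset/metraq.yaml, while
# dataset.name above is "commercial" — confirm which dataset this default targets.
#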
eval_context_faithfulness:
enabled: true
gcp_max_lag: 5
pairs:
- {x: "PM2.5", c: "T"}
- {x: "PM2.5", c: "TI"}
- {x: "PM2.5", c: "SP"}
# - {x: "PM2.5", c: "TEMP"}
Loading
Loading