Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9, 3.10.6, 3.11]
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v3

Expand All @@ -21,9 +21,10 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
echo "${{ runner.tool_cache }}/.local/bin" >> $GITHUB_PATH
run: pip install poetry

- name: Upgrade virtualenv
run: pip install --upgrade virtualenv

- name: Configure Poetry (disable virtualenv creation)
run: poetry config virtualenvs.create false
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ poetry install
Once installed, activate the virtual environment:

```bash
poetry shell
poetry env activate
```

This gives you a clean, reproducible setup for development.
Expand Down
28 changes: 14 additions & 14 deletions cents/config/dataset/default.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# name: default
# normalize: True
# scale: True
# use_learned_normalizer: True
# shuffle: True
# threshold: 6
# time_series_dims: 1
# time_series_columns: []
# seq_len: 8
# user_group: null
name: default
normalize: True
scale: True
use_learned_normalizer: True
shuffle: True
threshold: 6
time_series_dims: 1
time_series_columns: []
seq_len: 8
user_group: null

# numeric_context_bins: 5
# context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous)
# continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned)
# stats_head_type: mlp
numeric_context_bins: 5
context_vars: {} # Dict mapping variable names to category counts (for categorical) or placeholders (for continuous)
continuous_context_vars: [] # Optional: list of variable names that should be kept as continuous (not binned)
stats_head_type: mlp
4 changes: 2 additions & 2 deletions cents/config/trainer/normalizer.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
strategy: ddp_find_unused_parameters_true
accelerator: gpu
devices: 1,
accelerator: auto
devices: auto
log_every_n_steps: 1
hidden_dim: 512
embedding_dim: 256
Expand Down
20 changes: 18 additions & 2 deletions cents/datasets/timeseries_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
context_var_column_names = list(context_var_column_names)

self.time_series_column_names = time_series_column_names
self.time_series_dims = self.cfg.time_series_dims
self.time_series_dims = len(time_series_column_names)
self.context_vars = context_var_column_names or []
self.seq_len = seq_len

Expand Down Expand Up @@ -618,7 +618,15 @@ def _get_normalization_cache_path(self):
context_cfg = get_context_config()
context_module_type = context_cfg.dynamic_context.type
stats_head_type = context_cfg.normalizer.stats_head_type
cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}_{context_module_type or ''}_{stats_head_type or ''}"
try:
ts_parts = []
for col in self.time_series_column_names:
ts_parts.extend(np.asarray(arr, dtype=np.float32).flatten() for arr in self.data[col].values)
ts_bytes = np.concatenate(ts_parts).tobytes() if ts_parts else b""
data_hash = hashlib.md5(ts_bytes).hexdigest()[:8]
except Exception:
data_hash = "nohash"
cache_key = f"{self.name}_{len(self.data)}_{self.seq_len}_{self.normalize}_{self.scale}_{self.time_series_dims}_{context_module_type or ''}_{stats_head_type or ''}_{data_hash}"
cache_hash = hashlib.md5(cache_key.encode()).hexdigest()[:8]
if self.run_dir is not None:
cache_dir = self.run_dir / "cache" / "normalized_data"
Expand Down Expand Up @@ -728,6 +736,14 @@ def _init_normalizer(self) -> None:
elif self.force_retrain_normalizer and cache_path.exists():
print(f"[Cents] Force retrain enabled, ignoring cached normalizer at {cache_path}")

# Unconditional normalizer: no model to train, just compute global stats from data
if self._normalizer.normalizer_model is None:
print("[Cents] Unconditional normalizer: computing global statistics (no model training).")
self._normalizer.setup()
torch.save(self._normalizer.state_dict(), cache_path)
print(f"[Cents] Saved normalizer to {cache_path}")
return

# train and cache a single state dict
print("[Cents] Training normalizer…")
print(f"[Cents] devices: {ncfg.devices}")
Expand Down
4 changes: 2 additions & 2 deletions cents/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,13 +492,13 @@ def evaluate_subset(
real_data_subset = dataset.data.iloc[indices].reset_index(drop=True)
continuous_vars = getattr(dataset, "continuous_vars", [])
static_context_vars = {}
for name in dataset.static_context_vars:
for name in getattr(dataset, "static_context_vars", []):
vals = real_data_subset[name].values
dtype = torch.float32 if name in continuous_vars else torch.long
static_context_vars[name] = torch.tensor(vals, dtype=dtype, device=self.device)
dynamic_context_vars = {}
categorical_ts = getattr(dataset, "categorical_time_series", {})
for name in dataset.dynamic_context_vars:
for name in getattr(dataset, "dynamic_context_vars", []):
vals = real_data_subset[name].values
# Dynamic module expects tensors (training path uses torch.from_numpy in dataset __getitem__)
if len(vals) and hasattr(vals[0], "__len__") and not isinstance(vals[0], (str, bytes)):
Expand Down
1 change: 1 addition & 0 deletions cents/eval/eval_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def calculate_banded_mse(
Dict with keys "band_1" … "band_N", each containing
{"mean": float, "std": float, "range": [lo, hi]}.
"""
print(real_data.shape, syn_data.shape)
assert real_data.shape == syn_data.shape, "real_data and syn_data must have the same shape"
N, T, D = real_data.shape

Expand Down
2 changes: 1 addition & 1 deletion cents/models/acgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.base_channels = base_channels

self.context_vars = context_vars
self.context_module = get_context_module_cls(context_module_type)(context_vars, embedding_dim, continuous_vars=continuous_vars)
self.context_module = get_context_module_cls(context_module_type)(context_vars, embedding_dim)

in_dim = noise_dim + (embedding_dim if context_vars else 0)
self.fc = nn.Linear(in_dim, self.final_window_length * base_channels)
Expand Down
46 changes: 31 additions & 15 deletions cents/models/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(

def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_dict: dict = None):
embeddings = []

# Process static context variables
if self.static_cond_module is not None:
if static_context_vars_dict:
Expand All @@ -159,7 +159,7 @@ def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_di
}
static_embedding, _ = self.static_cond_module(static_context_vars_dict)
embeddings.append(static_embedding)

# Process dynamic context variables
if self.dynamic_cond_module is not None:
if dynamic_context_vars_dict:
Expand All @@ -172,7 +172,7 @@ def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_di
if torch.isnan(dynamic_embedding).any() or torch.isinf(dynamic_embedding).any():
raise ValueError(f"NaN/Inf detected in dynamic embedding.")
embeddings.append(dynamic_embedding)

# Combine embeddings
if len(embeddings) == 2:
combined = torch.cat(embeddings, dim=1)
Expand All @@ -181,7 +181,7 @@ def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_di
embedding = embeddings[0]
else:
raise ValueError("No context variables provided")

return self.stats_head(embedding)


Expand Down Expand Up @@ -291,31 +291,39 @@ def __init__(
seq_len=dynamic_seq_len,
)

self.normalizer_model = _NormalizerModule(
static_cond_module=self.static_context_module,
dynamic_cond_module=self.dynamic_context_module,
hidden_dim=512,
time_series_dims=self.time_series_dims,
do_scale=self.do_scale,
stats_head_type=self.stats_head_type,
n_layers=context_cfg.normalizer.n_layers,
)
has_context = bool(self.static_context_module or self.dynamic_context_module)
if has_context:
self.normalizer_model = _NormalizerModule(
static_cond_module=self.static_context_module,
dynamic_cond_module=self.dynamic_context_module,
hidden_dim=512,
time_series_dims=self.time_series_dims,
do_scale=self.do_scale,
stats_head_type=self.stats_head_type,
n_layers=context_cfg.normalizer.n_layers,
)
else:
self.normalizer_model = None

# Will be populated in setup()
self.sample_stats = []
self._verify_parameters()

def _verify_parameters(self):
if self.normalizer_model is None:
print("[Normalizer] Unconditional mode: no context vars, will use global statistics only.")
return

all_param_names = [name for name, _ in self.named_parameters()]
context_param_names = [name for name in all_param_names if 'cond_module' in name or 'context_module' in name]
stats_head_param_names = [name for name in all_param_names if 'stats_head' in name]

if not context_param_names:
raise RuntimeError(
"Context module parameters not found! "
f"Found parameter names: {all_param_names[:10]}..."
)

print(f"[Normalizer] Found {len(context_param_names)} context module parameters")
print(f"[Normalizer] Found {len(stats_head_param_names)} stats head parameters")
print(f"[Normalizer] Total trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):,}")
Expand Down Expand Up @@ -383,6 +391,14 @@ def forward(self, static_context_vars_dict: dict = None, dynamic_context_vars_di
Returns:
Tuple of (pred_mu_real, pred_sigma_real, pred_z_min, pred_z_max, pred_log_sigma_raw).
"""
if self.normalizer_model is None:
dev = self.global_mu_mean.device
mu = self.global_mu_mean.expand(1, self.time_series_dims).to(dev)
sigma = torch.exp(self.global_log_sigma_mean).clamp(min=self.min_sigma).expand(1, self.time_series_dims).to(dev)
zeros = torch.zeros(1, self.time_series_dims, device=dev)
log_sigma = self.global_log_sigma_mean.expand(1, self.time_series_dims).to(dev)
return mu, sigma, zeros, zeros, log_sigma

pred_mu_raw, pred_sigma, pred_zmin, pred_zmax, pred_log_sigma_raw = self.normalizer_model(static_context_vars_dict, dynamic_context_vars_dict)

pred_mu_real = self._raw_mu_to_real(pred_mu_raw)
Expand Down
Loading
Loading