Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1a6e50f
innit with contents from other neural network modules that are shared
JonasETHZ Nov 17, 2025
cc90a5f
changed __innit____ to __init__
JonasETHZ Nov 17, 2025
11c285a
start of implementing conditional inputs for
JonasETHZ Nov 22, 2025
acc01de
model mostly implemented, needs testing
JonasETHZ Nov 28, 2025
ca91de7
loss function added
JonasETHZ Nov 30, 2025
b44d771
loss function added
JonasETHZ Nov 30, 2025
94be7a6
removed unneccessary file
JonasETHZ Nov 30, 2025
2e8e5d9
partially debugged
JonasETHZ Dec 2, 2025
9900866
working pixel_cnn_pp_2d.py version
JonasETHZ Dec 3, 2025
52b0caa
working pixel_cnn_pp_2d.py model, added function to save model every …
JonasETHZ Dec 3, 2025
ead782b
finished up evaluate_pixel_cnn_pp_2d.py and fixed parameter handling …
JonasETHZ Dec 5, 2025
7081abd
Added multiple GPU support for sampling
JonasETHZ Dec 5, 2025
bd34475
increased number of generated samples to 25
JonasETHZ Dec 5, 2025
f5c475a
multi GPU v2
JonasETHZ Dec 5, 2025
66a1e8f
added info message for multi-GPU sampling
JonasETHZ Dec 5, 2025
d690781
model with max settings and 1 GPU
JonasETHZ Dec 5, 2025
1ef0254
fixed scaling (input/output) and removed multi GPU sampling
JonasETHZ Dec 8, 2025
eeb1812
sample interval 600
JonasETHZ Dec 8, 2025
811167a
updated evaluate_pixel_cnn_pp_2d.py
JonasETHZ Dec 10, 2025
e3081c3
added some debugging lines
JonasETHZ Dec 10, 2025
c5a143a
added sampling in batches for pixel_cnn_pp_2d
JonasETHZ Dec 10, 2025
b73b420
removed comments and time.sleep()
JonasETHZ Dec 10, 2025
93e34ec
added info text
JonasETHZ Dec 10, 2025
c246690
another print statement
JonasETHZ Dec 10, 2025
e8fbf36
lowered sampling batch size to reduce memory consumption during evalu…
JonasETHZ Dec 10, 2025
fd2bb14
nr of samples changed to 25
JonasETHZ Dec 12, 2025
b2858de
evaluation preparation and removal of some debug prints
JonasETHZ Dec 15, 2025
9497b4e
Final version of PixelCNN++ implementation
JonasETHZ Dec 17, 2025
f29cca1
Fixed typos
JonasETHZ Dec 17, 2025
23b9091
Merge branch 'main' into main
JonasETHZ Dec 17, 2025
ea6fe9a
Formatting with ruff
JonasETHZ Dec 17, 2025
b2abfe9
Merge branch 'main' of https://github.com/JonasETHZ/EngiOpt
JonasETHZ Dec 17, 2025
c2f726b
replaced NIN by NetworkInNetwork
JonasETHZ Dec 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
164 changes: 164 additions & 0 deletions engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Evaluation of the PixelCNN++ model."""

from __future__ import annotations

from dataclasses import dataclass
import os

from engibench.utils.all_problems import BUILTIN_PROBLEMS
import numpy as np
import pandas as pd
import torch as th
import tyro
import wandb

from engiopt import metrics
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import PixelCNNpp
from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import sample_from_discretized_mix_logistic


@dataclass
class Args:
"""Command-line arguments for a single-seed PixelCNN++ 2D evaluation."""

problem_id: str = "beams2d"
"""Problem identifier."""
seed: int = 1
"""Random seed to run."""
sampling_batch_size: int = 5
"""Batch size to use during sampling."""
wandb_project: str = "engiopt"
"""Wandb project name."""
wandb_entity: str | None = None
"""Wandb entity name."""
n_samples: int = 50
"""Number of generated samples per seed."""
sigma: float = 10.0
"""Kernel bandwidth for MMD and DPP metrics."""
output_csv: str = "pixel_cnn_pp_2d_{problem_id}_metrics.csv"
"""Output CSV path template; may include {problem_id}."""


if __name__ == "__main__":
    # Script entry point: load the trained PixelCNN++ checkpoint logged to W&B for one seed,
    # sample ``n_samples`` designs conditioned on sampled dataset conditions, compute the
    # evaluation metrics and append them as one row to the output CSV.
    args = tyro.cli(Args)

    seed = args.seed
    problem = BUILTIN_PROBLEMS[args.problem_id]()
    problem.reset(seed=seed)

    # Seeding for reproducibility
    th.manual_seed(seed)
    th.backends.cudnn.deterministic = True

    # Select device
    if th.backends.mps.is_available():
        device = th.device("mps")
    elif th.cuda.is_available():
        device = th.device("cuda")
    else:
        device = th.device("cpu")

    # Set up testing conditions
    conditions_tensor, sampled_conditions, sampled_designs_np, selected_indices = sample_conditions(
        problem=problem,
        n_samples=args.n_samples,
        device=device,
        seed=seed,
    )

    # adapt to PixelCNN++ input shape requirements: append two singleton trailing dimensions
    conditions_tensor = conditions_tensor.unsqueeze(-1).unsqueeze(-1)
    design_shape = (problem.design_space.shape[0], problem.design_space.shape[1])

    # Set up PixelCNN++ model: the artifact is addressed by problem id and tagged with the seed
    if args.wandb_entity is not None:
        artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}"
    else:
        artifact_path = f"{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}"

    api = wandb.Api()
    artifact = api.artifact(artifact_path, type="model")

    class RunRetrievalError(ValueError):
        """Raised when the run that logged the model artifact cannot be retrieved."""

        def __init__(self):
            super().__init__("Failed to retrieve the run")

    # The training run's config holds the hyperparameters needed to rebuild the architecture
    run = artifact.logged_by()
    if run is None or not hasattr(run, "config"):
        raise RunRetrievalError

    artifact_dir = artifact.download()
    ckpt_path = os.path.join(artifact_dir, "model.pth")
    # NOTE(review): th.load unpickles the checkpoint, so only evaluate artifacts from a trusted
    # W&B project; consider weights_only=True if the checkpoint holds nothing but state dicts.
    ckpt = th.load(ckpt_path, map_location=device)

    model = PixelCNNpp(
        nr_resnet=run.config["nr_resnet"],
        nr_filters=run.config["nr_filters"],
        nr_logistic_mix=run.config["nr_logistic_mix"],
        resnet_nonlinearity=run.config["resnet_nonlinearity"],
        dropout_p=run.config["dropout_p"],
        input_channels=1,
        nr_conditions=conditions_tensor.shape[1],
    )

    model.load_state_dict(ckpt["model"])
    model.eval()
    model.to(device)

    # Sample designs in batches
    all_batches: list[th.Tensor] = []
    nr_logistic_mix = run.config["nr_logistic_mix"]  # loop-invariant, looked up once

    # Sampling needs one forward pass per pixel. Gradients are never used here, so disable
    # autograd to avoid retaining activations for a backward pass that never happens.
    with th.no_grad():
        for start in range(0, args.n_samples, args.sampling_batch_size):
            end = min(args.n_samples, start + args.sampling_batch_size)
            b = end - start

            # prepare batch-local tensors on the same device as the model
            batch_conds = conditions_tensor[start:end]
            data = th.zeros((b, 1, *design_shape), device=device)

            # Autoregressive pixel sampling for this batch: every pass re-evaluates the model on
            # the partially filled canvas and only the sample at position (i, j) is kept.
            for i in range(design_shape[0]):
                for j in range(design_shape[1]):
                    out = model(data, batch_conds)
                    out_sample = sample_from_discretized_mix_logistic(out, nr_logistic_mix)
                    data[:, :, i, j] = out_sample[:, :, i, j]

            # move completed batch to CPU to free GPU memory and store
            all_batches.append(data.cpu())

    # concatenate all batches on CPU
    gen_designs = th.cat(all_batches, dim=0)

    gen_designs_np = gen_designs.numpy()
    gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape)  # remove channel dim
    # NOTE(review): assumes the sampled values are already in the design value range; confirm
    # against the training script's input/output scaling before relying on this clip.
    gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0)

    # Compute metrics
    metrics_dict = metrics.metrics(
        problem,
        gen_designs_np,
        sampled_designs_np,
        sampled_conditions,
        sigma=args.sigma,
    )

    # Add metadata to metrics
    metrics_dict.update(
        {
            "seed": seed,
            "problem_id": args.problem_id,
            "model_id": "pixel_cnn_pp_2d",
            "n_samples": args.n_samples,
            "sigma": args.sigma,
        }
    )

    # Append result row to CSV; the header is written only when the file does not exist yet
    metrics_df = pd.DataFrame([metrics_dict])
    out_path = args.output_csv.format(problem_id=args.problem_id)
    write_header = not os.path.exists(out_path)
    metrics_df.to_csv(out_path, mode="a", header=write_header, index=False)

    print(f"Seed {seed} done; appended to {out_path}")
Loading
Loading