From 1a6e50f337fabb3078c8308b4951077f428fac9c Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 17 Nov 2025 11:16:48 +0100 Subject: [PATCH 01/31] innit with contents from other neural network modules that are shared --- engiopt/pixel_cnn_pp_2d/__innit__.py | 0 .../evaluate_pixel_cnn_pp_2d.py | 124 +++++++++++ engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 192 ++++++++++++++++++ 3 files changed, 316 insertions(+) create mode 100644 engiopt/pixel_cnn_pp_2d/__innit__.py create mode 100644 engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py create mode 100644 engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py diff --git a/engiopt/pixel_cnn_pp_2d/__innit__.py b/engiopt/pixel_cnn_pp_2d/__innit__.py new file mode 100644 index 0000000..e69de29 diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py new file mode 100644 index 0000000..0fcc5df --- /dev/null +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +import pandas as pd +import os +from engibench.utils.all_problems import BUILTIN_PROBLEMS +import numpy as np +import torch as th +import tyro +import wandb +from engiopt import metrics +from engiopt.dataset_sample_conditions import sample_conditions + + +@dataclass +class Args: + """Command-line arguments for a single-seed PixelCNN++ 2D evaluation.""" + + problem_id: str = "beams2d" + """Problem identifier.""" + seed: int = 1 + """Random seed to run.""" + wandb_project: str = "engiopt" + """Wandb project name.""" + wandb_entity: str | None = None + """Wandb entity name.""" + n_samples: int = 50 + """Number of generated samples per seed.""" + sigma: float = 10.0 + """Kernel bandwidth for MMD and DPP metrics.""" + output_csv: str = "pixel_cnn_pp_2d_{problem_id}_metrics.csv" + """Output CSV path template; may include {problem_id}.""" + + +if __name__ == "__main__": + args = tyro.cli(Args) + + seed = args.seed + problem = BUILTIN_PROBLEMS[args.problem_id]() + problem.reset(seed=seed) + + # Seeding for reproducibility + th.manual_seed(seed) + rng = np.random.default_rng(seed) + th.backends.cudnn.deterministic = True + + # Select device + if th.backends.mps.is_available(): + device = th.device("mps") + elif th.cuda.is_available(): + device = th.device("cuda") + else: + device = th.device("cpu") + + ### Set up testing conditions ### + conditions_tensor, sampled_conditions, sampled_designs_np, _ = sample_conditions( + problem=problem, + n_samples=args.n_samples, + device=device, + seed=seed, + ) + + # -------------------------------------------------------- + # adapt to PixelCNN++ input shape requirements + conditions_tensor = conditions_tensor.unsqueeze(1) + + ### Set Up Diffusion Model ### + if args.wandb_entity is not None: + artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_diffusion_2d_cond_model:seed_{seed}" + else: + artifact_path = f"{args.wandb_project}/{args.problem_id}_diffusion_2d_cond_model:seed_{seed}" + + api = wandb.Api() + artifact = api.artifact(artifact_path, type="model") + + class RunRetrievalError(ValueError): + def __init__(self): + super().__init__("Failed to retrieve the run") + + run = artifact.logged_by() + if run is None or not hasattr(run, "config"): + raise RunRetrievalError + + artifact_dir = artifact.download() + ckpt_path = os.path.join(artifact_dir, "model.pth") # change model.pth if necessary + ckpt = th.load(ckpt_path, map_location=device) # or th.device(device) + + + # Build PixelCNN++ Model + # ... + + + # Generate a batch of designs + # ... + + gen_designs_np = None, #gen_designs.detach().cpu().numpy().reshape(args.n_samples, *problem.design_space.shape) + # Clip to boundaries for running THIS IS PROBLEM DEPENDENT + gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0) + + # Compute metrics + metrics_dict = metrics.metrics( + problem, + gen_designs_np, + sampled_designs_np, + sampled_conditions, + sigma=args.sigma, + ) + + # Add metadata to metrics + metrics_dict.update( + { + "seed": seed, + "problem_id": args.problem_id, + "model_id": "pixel_cnn_pp_2d", + "n_samples": args.n_samples, + "sigma": args.sigma, + } + ) + + # Append result row to CSV + metrics_df = pd.DataFrame([metrics_dict]) + out_path = args.output_csv.format(problem_id=args.problem_id) + write_header = not os.path.exists(out_path) + metrics_df.to_csv(out_path, mode="a", header=write_header, index=False) + + print(f"Seed {seed} done; appended to {out_path}") \ No newline at end of file diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py new file mode 100644 index 0000000..bd50c94 --- /dev/null +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -0,0 +1,192 @@ +from dataclasses import dataclass +import os +import tyro +from engibench.utils.all_problems import BUILTIN_PROBLEMS +import torch as th +import numpy as np +import random +import tqdm +import time +import matplotlib.pyplot as plt + +import wandb + + +@dataclass +class Args: + """Command-line arguments.""" + + problem_id: str = "beams2d" + """Problem identifier.""" + algo: str = os.path.basename(__file__)[: -len(".py")] + """The name of this algorithm.""" + + # Tracking + track: bool = True + """Track the experiment with wandb.""" + wandb_project: str = "engiopt" + """Wandb project name.""" + wandb_entity: str | None = None + """Wandb entity name.""" + seed: int = 1 + """Random seed.""" + save_model: bool = False + """Saves the model to disk.""" + + # CHANGE! + # Algorithm specific + n_epochs: int = 200 + """number of epochs of training""" + batch_size: int = 32 + """size of the batches""" + lr: float = 0.0001 + """learning rate""" + b1: float = 0.5 + """decay of first order momentum of gradient""" + b2: float = 0.999 + """decay of first order momentum of gradient""" + n_cpu: int = 8 + """number of cpu threads to use during batch generation""" + latent_dim: int = 32 + """dimensionality of the latent space""" + sample_interval: int = 400 + """interval between image samples""" + + +# IMPLEMENT PIXELCNN++ HERE + + + +if __name__ == "__main__": + args = tyro.cli(Args) + + problem = BUILTIN_PROBLEMS[args.problem_id]() + problem.reset(seed=args.seed) + + design_shape = problem.design_space.shape + + + # Logging + run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" + if args.track: + wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name) + + # Seeding + th.manual_seed(args.seed) + rng = np.random.default_rng(args.seed) + random.seed(args.seed) + th.backends.cudnn.deterministic = True + + os.makedirs("images", exist_ok=True) + + if th.backends.mps.is_available(): + device = th.device("mps") + elif th.cuda.is_available(): + device = th.device("cuda") + else: + device = th.device("cpu") + + # Loss function + # ... implement + + # Initialize model + # ... implement + + # model.to(device) + # loss.to(device) + + # Configure data loader + training_ds = problem.dataset.with_format("torch", device=device)["train"] + # ... + + dataloader = th.utils.data.DataLoader( + training_ds, + batch_size=args.batch_size, + shuffle=True, + ) + + # Training loop + # optimizer = th.optim.Adam(model.parameters(), lr=args.lr) # add other args if necessary + + # @th.no_grad() + # def sample_designs(model: ---, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: + # ... implement + + + + # ---------- + # Training + # ---------- + for epoch in tqdm.trange(args.n_epochs): + for i, data in enumerate(dataloader): + batch_start_time = time.time() + # ... implement + + # Backpropagation + # loss.backward() + # optimizer.step() + + + # ---------- + # Logging + # ---------- + if args.track: + batches_done = epoch * len(dataloader) + i + wandb.log( + { + "loss": None, #loss.item(), + "epoch": epoch, + "batch": batches_done, + } + ) + print( + f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(dataloader)}] [loss: {None}]] [{time.time() - batch_start_time:.2f} sec]" #loss.item() + ) + + # This saves a grid image of 25 generated designs every sample_interval + if batches_done % args.sample_interval == 0: + # Extract 25 designs + + designs, hidden_states = None #sample_designs(model, 25) + fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + + # Flatten axes for easy indexing + axes = axes.flatten() + + # Plot the image created by each output + for j, tensor in enumerate(designs): + img = tensor.cpu().numpy() # Extract x and y coordinates + dc = hidden_states[j, 0, :].cpu() + axes[j].imshow(img[0]) # image plot + title = [(problem.conditions[i][0], f"{dc[i]:.2f}") for i in range(len(problem.conditions))] + title_string = "\n ".join(f"{condition}: {value}" for condition, value in title) + axes[j].title.set_text(title_string) # Set title + axes[j].set_xticks([]) # Hide x ticks + axes[j].set_yticks([]) # Hide y ticks + + plt.tight_layout() + img_fname = f"images/{batches_done}.png" + plt.savefig(img_fname) + plt.close() + wandb.log({"designs": wandb.Image(img_fname)}) + + # -------------- + # Save models + # -------------- + if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: + ckpt_model = { + "epoch": epoch, + "batches_done": batches_done, + "model": None, # model.state_dict(), + "optimizer_generator": None, # optimizer.state_dict(), + "loss": None, # loss.item(), + } + + th.save(ckpt_model, "model.pth") + if args.track: + artifact_model = wandb.Artifact(f"{args.problem_id}_{args.algo}_model", type="model") + artifact_model.add_file("model.pth") + + wandb.log_artifact(artifact_model, aliases=[f"seed_{args.seed}"]) + + wandb.finish() \ No newline at end of file From cc90a5fd14b034c18a29285bc6accc5836cece80 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 17 Nov 2025 14:27:08 +0100 Subject: [PATCH 02/31] changed __innit____ to __init__ --- engiopt/pixel_cnn_pp_2d/{__innit__.py => __init__.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename engiopt/pixel_cnn_pp_2d/{__innit__.py => __init__.py} (100%) diff --git a/engiopt/pixel_cnn_pp_2d/__innit__.py b/engiopt/pixel_cnn_pp_2d/__init__.py similarity index 100% rename from engiopt/pixel_cnn_pp_2d/__innit__.py rename to engiopt/pixel_cnn_pp_2d/__init__.py From 11c285a5255ca93990efa6be3178944acddaa0a1 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sat, 22 Nov 2025 10:28:56 +0100 Subject: [PATCH 03/31] start of implementing conditional inputs for PixelCNN++ --- .../evaluate_pixel_cnn_pp_2d.py | 2 +- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 120 ++++++-- .../pixel_cnn_pp_2d/print_dataset_element.py | 284 ++++++++++++++++++ 3 files changed, 384 insertions(+), 22 deletions(-) create mode 100644 engiopt/pixel_cnn_pp_2d/print_dataset_element.py diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 0fcc5df..69805bd 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -121,4 +121,4 @@ def __init__(self): write_header = not os.path.exists(out_path) metrics_df.to_csv(out_path, mode="a", header=write_header, index=False) - print(f"Seed {seed} done; appended to {out_path}") \ No newline at end of file + print(f"Seed {seed} done; appended to {out_path}") diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index bd50c94..7bbc22d 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -1,8 +1,10 @@ from dataclasses import dataclass +from email import generator import os import tyro from engibench.utils.all_problems import BUILTIN_PROBLEMS import torch as th +from torch import nn import numpy as np import random import tqdm @@ -22,7 +24,7 @@ class Args: """The name of this algorithm.""" # Tracking - track: bool = True + track: bool = False """Track the experiment with wandb.""" wandb_project: str = "engiopt" """Wandb project name.""" @@ -35,26 +37,80 @@ class Args: # CHANGE! # Algorithm specific - n_epochs: int = 200 + n_epochs: int = 1 """number of epochs of training""" batch_size: int = 32 """size of the batches""" lr: float = 0.0001 """learning rate""" - b1: float = 0.5 - """decay of first order momentum of gradient""" - b2: float = 0.999 - """decay of first order momentum of gradient""" - n_cpu: int = 8 - """number of cpu threads to use during batch generation""" - latent_dim: int = 32 - """dimensionality of the latent space""" - sample_interval: int = 400 - """interval between image samples""" + # b1: float = 0.5 + # """decay of first order momentum of gradient""" + # b2: float = 0.999 + # """decay of first order momentum of gradient""" + # n_cpu: int = 8 + # """number of cpu threads to use during batch generation""" + # latent_dim: int = 32 + # """dimensionality of the latent space""" + # sample_interval: int = 400 + # """interval between image samples""" + nr_resnet: int = 5 + """Number of residual blocks per stage of the model.""" + nr_filters: int = 160 + """Number of filters to use across the model. Higher = larger model.""" + nr_logistic_mix: int = 10 + """Number of logistic components in the mixture. Higher = more flexible model.""" + resnet_nonlinearity: str = "concat_elu" + """Nonlinearity to use in the ResNet blocks. One of 'concat_elu', 'elu', 'relu'.""" # IMPLEMENT PIXELCNN++ HERE - +class PixelCNNpp(nn.Module): + def __init__(self): + super().__init__() + + def discretized_mix_logistic_loss(x, l): + pass + # """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ + # xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) + # ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100) + # nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics + # logit_probs = l[:,:,:,:nr_mix] + # l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3]) + # means = l[:,:,:,:,:nr_mix] + # log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.) + # coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix]) + # x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels + # m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix]) + # m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix]) + # means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3) + # centered_x = x - means + # inv_stdv = tf.exp(-log_scales) + # plus_in = inv_stdv * (centered_x + 1./255.) + # cdf_plus = tf.nn.sigmoid(plus_in) + # min_in = inv_stdv * (centered_x - 1./255.) + # cdf_min = tf.nn.sigmoid(min_in) + # log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling) + # log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling) + # cdf_delta = cdf_plus - cdf_min # probability for all other cases + # mid_in = inv_stdv * centered_x + # log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) + + # # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) + + # # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() + # # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) + + # # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) + # # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs + # # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue + # # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value + # log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) + + # log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs) + # if sum_all: + # return -tf.reduce_sum(log_sum_exp(log_probs)) + # else: + # return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) if __name__ == "__main__": @@ -64,6 +120,9 @@ class Args: problem.reset(seed=args.seed) design_shape = problem.design_space.shape + print(f"Design shape: {design_shape}") + conditions = problem.conditions + nr_conditions = len(conditions) # Logging @@ -87,7 +146,7 @@ class Args: device = th.device("cpu") # Loss function - # ... implement + loss = PixelCNNpp.discretized_mix_logistic_loss # Initialize model # ... implement @@ -97,28 +156,47 @@ class Args: # Configure data loader training_ds = problem.dataset.with_format("torch", device=device)["train"] - # ... - + condition_tensors = [training_ds[key][:] for key in problem.conditions_keys] + + training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:].flatten(1), *condition_tensors) + dataloader = th.utils.data.DataLoader( training_ds, batch_size=args.batch_size, shuffle=True, ) - # Training loop + # Optimzer # optimizer = th.optim.Adam(model.parameters(), lr=args.lr) # add other args if necessary - # @th.no_grad() - # def sample_designs(model: ---, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - # ... implement + @th.no_grad() + def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: + """Sample designs from trained model.""" + # Is that needed? + # z = th.randn((n_designs, args.latent_dim), device=device, dtype=th.float) + + # Create condition grid + all_conditions = th.stack(condition_tensors, dim=1) + linspaces = [ + th.linspace(all_conditions[:, i].min(), all_conditions[:, i].max(), n_designs, device=device) + for i in range(all_conditions.shape[1]) + ] + desired_conds = th.stack(linspaces, dim=1) + + # implement sampling from model here + gen_imgs = None + + return desired_conds, gen_imgs + - # ---------- # Training # ---------- for epoch in tqdm.trange(args.n_epochs): for i, data in enumerate(dataloader): + print(data[0]) + print(data[1:]) batch_start_time = time.time() # ... implement diff --git a/engiopt/pixel_cnn_pp_2d/print_dataset_element.py b/engiopt/pixel_cnn_pp_2d/print_dataset_element.py new file mode 100644 index 0000000..1ff47a1 --- /dev/null +++ b/engiopt/pixel_cnn_pp_2d/print_dataset_element.py @@ -0,0 +1,284 @@ +"""This code is largely based on the excellent PyTorch GAN repo: https://github.com/eriklindernoren/PyTorch-GAN. + +We essentially refreshed the Python style, use wandb for logging, and made a few little improvements. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import os +import random +import time + +from engibench.utils.all_problems import BUILTIN_PROBLEMS +import matplotlib.pyplot as plt +import numpy as np +import torch as th +from torch import nn +import tqdm +import tyro + +import wandb + + +@dataclass +class Args: + """Command-line arguments.""" + + problem_id: str = "beams2d" + """Problem identifier.""" + algo: str = os.path.basename(__file__)[: -len(".py")] + """The name of this algorithm.""" + + # Tracking + track: bool = True + """Track the experiment with wandb.""" + wandb_project: str = "engiopt" + """Wandb project name.""" + wandb_entity: str | None = None + """Wandb entity name.""" + seed: int = 1 + """Random seed.""" + save_model: bool = False + """Saves the model to disk.""" + + # Algorithm specific + n_epochs: int = 1 + """number of epochs of training""" + batch_size: int = 32 + """size of the batches""" + lr_gen: float = 0.0001 + """learning rate for the generator""" + lr_disc: float = 0.0004 + """learning rate for the discriminator""" + b1: float = 0.5 + """decay of first order momentum of gradient""" + b2: float = 0.999 + """decay of first order momentum of gradient""" + n_cpu: int = 8 + """number of cpu threads to use during batch generation""" + latent_dim: int = 32 + """dimensionality of the latent space""" + sample_interval: int = 400 + """interval between image samples""" + + +class Generator(nn.Module): + def __init__(self, latent_dim: int, design_shape: tuple): + super().__init__() + self.design_shape = design_shape # Store design shape + + def block(in_feat: int, out_feat: int, *, normalize: bool = True) -> list[nn.Module]: + layers: list[nn.Module] = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *block(latent_dim, 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(np.prod(design_shape))), + nn.Tanh(), + ) + + def forward(self, z: th.Tensor) -> th.Tensor: + """Forward pass to generate an image from latent space.""" + img = self.model(z) + return img.view(img.size(0), *self.design_shape) + + +class Discriminator(nn.Module): + def __init__(self): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(int(np.prod(design_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + nn.Sigmoid(), + ) + + def forward(self, img: th.Tensor) -> th.Tensor: + """Forward pass to compute the validity of an input image.""" + img_flat = img.view(img.size(0), -1) + return self.model(img_flat) + + +if __name__ == "__main__": + args = tyro.cli(Args) + + problem = BUILTIN_PROBLEMS[args.problem_id]() + problem.reset(seed=args.seed) + + design_shape = problem.design_space.shape + + # Logging + run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" + if args.track: + wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name) + + # Seeding + th.manual_seed(args.seed) + rng = np.random.default_rng(args.seed) + random.seed(args.seed) + th.backends.cudnn.deterministic = True + + os.makedirs("images", exist_ok=True) + + if th.backends.mps.is_available(): + device = th.device("mps") + elif th.cuda.is_available(): + device = th.device("cuda") + else: + device = th.device("cpu") + + # Loss function + adversarial_loss = th.nn.BCELoss() + + # Initialize generator and discriminator + generator = Generator(args.latent_dim, design_shape) + discriminator = Discriminator() + + generator.to(device) + discriminator.to(device) + adversarial_loss.to(device) + + # Configure data loader + training_ds = problem.dataset.with_format("torch", device=device)["train"] + + training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:].flatten(1)) + dataloader = th.utils.data.DataLoader( + training_ds, + batch_size=args.batch_size, + shuffle=True, + ) + + # Optimizers + optimizer_generator = th.optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(args.b1, args.b2)) + optimizer_discriminator = th.optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(args.b1, args.b2)) + + @th.no_grad() + def sample_designs(n_designs: int) -> th.Tensor: + """Samples n_designs from the generator.""" + # Sample noise + z = th.randn((n_designs, args.latent_dim), device=device, dtype=th.float) + return generator(z) + + # ---------- + # Training + # ---------- + for epoch in tqdm.trange(args.n_epochs): + for i, data in enumerate(dataloader): + designs = data[0] + print(designs.shape) + # Adversarial ground truths + valid = th.ones((designs.size(0), 1), requires_grad=False, device=device) + fake = th.zeros((designs.size(0), 1), requires_grad=False, device=device) + + # ----------------- + # Train Generator + # min log(1 - D(G(z))) <==> max log(D(G(z))) + # ----------------- + optimizer_generator.zero_grad() + + # Sample noise as generator input + z = th.randn((designs.size(0), args.latent_dim), device=device, dtype=th.float) + + # Generate a batch of images + gen_designs = generator(z) + + # Loss measures generator's ability to fool the discriminator + g_loss = adversarial_loss(discriminator(gen_designs), valid) + + g_loss.backward() + optimizer_generator.step() + + # --------------------- + # Train Discriminator + # max log(D(real)) + log(1 - D(G(z))) + # --------------------- + optimizer_discriminator.zero_grad() + + # Measure discriminator's ability to classify real from generated samples + real_loss = adversarial_loss(discriminator(designs), valid) + fake_loss = adversarial_loss(discriminator(gen_designs.detach()), fake) + d_loss = (real_loss + fake_loss) / 2 + + d_loss.backward() + optimizer_discriminator.step() + + # ---------- + # Logging + # ---------- + if args.track: + batches_done = epoch * len(dataloader) + i + wandb.log( + { + "d_loss": d_loss.item(), + "g_loss": g_loss.item(), + "epoch": epoch, + "batch": batches_done, + } + ) + print( + f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]" + ) + + # This saves a grid image of 25 generated designs every sample_interval + if batches_done % args.sample_interval == 0: + # Extract 25 designs + designs = sample_designs(25) + fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + + # Flatten axes for easy indexing + axes = axes.flatten() + + # Plot each tensor as a scatter plot + for j, tensor in enumerate(designs): + img = tensor.cpu().numpy() # Extract x and y coordinates + axes[j].imshow(img) # Scatter plot + axes[j].set_xticks([]) # Hide x ticks + axes[j].set_yticks([]) # Hide y ticks + + plt.tight_layout() + img_fname = f"images/{batches_done}.png" + plt.savefig(img_fname) + plt.close() + wandb.log({"designs": wandb.Image(img_fname)}) + + # -------------- + # Save models + # -------------- + if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: + ckpt_gen = { + "epoch": epoch, + "batches_done": batches_done, + "generator": generator.state_dict(), + "optimizer_generator": optimizer_generator.state_dict(), + "loss": g_loss.item(), + } + ckpt_disc = { + "epoch": epoch, + "batches_done": batches_done, + "discriminator": discriminator.state_dict(), + "optimizer_discriminator": optimizer_discriminator.state_dict(), + "loss": d_loss.item(), + } + + th.save(ckpt_gen, "generator.pth") + th.save(ckpt_disc, "discriminator.pth") + if args.track: + artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model") + artifact_gen.add_file("generator.pth") + artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model") + artifact_disc.add_file("discriminator.pth") + + wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"]) + wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"]) + + wandb.finish() From acc01de544d55785ae2c4a13257b7103b2813ae8 Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 28 Nov 2025 08:58:11 +0100 Subject: [PATCH 04/31] model mostly implemented, needs testing --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 319 ++++++++++++++++++++- 1 file changed, 308 insertions(+), 11 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 7bbc22d..d48bdf6 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -4,7 +4,9 @@ import tyro from engibench.utils.all_problems import BUILTIN_PROBLEMS import torch as th +import torch.nn.functional as F from torch import nn +from torch.nn.utils import weight_norm import numpy as np import random import tqdm @@ -63,12 +65,7 @@ class Args: """Nonlinearity to use in the ResNet blocks. One of 'concat_elu', 'elu', 'relu'.""" -# IMPLEMENT PIXELCNN++ HERE -class PixelCNNpp(nn.Module): - def __init__(self): - super().__init__() - - def discretized_mix_logistic_loss(x, l): +def discretized_mix_logistic_loss(x, l): pass # """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ # xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) @@ -112,6 +109,284 @@ def discretized_mix_logistic_loss(x, l): # else: # return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) +def concat_elu(x): + """Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU.""" + # Pytorch ordering + axis = len(x.size()) - 3 + #print(axis) + return F.elu(th.cat([x, -x], dim=axis)) + +class nin(nn.Module): + def __init__(self, nr_filters_in, nr_filters_out): + super().__init__() + self.lin_a = weight_norm(nn.Linear(nr_filters_in, nr_filters_out)) + self.nr_filters_out = nr_filters_out + + def forward(self, x): + x = x.permute(0, 2, 3, 1) # BCHW -> BHWC + xs = list(x.shape) + x = x.reshape(-1, xs[3]) # -> [B*H*W, C] + out = self.lin_a(x) + out = out.view(xs[0], xs[1], xs[2], self.nr_filters_out) + return out.permute(0, 3, 1, 2) # BHWC -> BCHW + +class GatedResnet(nn.Module): + def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_connection=0, dropout_p=0.5): + super().__init__() + self.skip_connection = skip_connection + self.resnet_nonlinearity = resnet_nonlinearity + + if resnet_nonlinearity == "concat_elu": + self.filter_doubling = 2 + else: + self.filter_doubling = 1 + + self.conv_input = conv_op(self.filter_doubling * nr_filters, nr_filters) + + if skip_connection != 0: + self.nin_skip = nin(self.filter_doubling * skip_connection * nr_filters, nr_filters) + + self.dropout = nn.Dropout2d(dropout_p) + self.conv_out = conv_op(self.filter_doubling * nr_filters, 2 * nr_filters) # output has to be doubled for gating + + self.h_lin = nn.Linear(nr_conditions, 2 * nr_filters) + + + def forward(self, x, a=None, h=None): + c1 = self.conv_input(self.resnet_nonlinearity(x)) + if a is not None : + c1 += self.nin_skip(self.resnet_nonlinearity(a)) + c1 = self.resnet_nonlinearity(c1) + c1 = self.dropout(c1) + c2 = self.conv_out(c1) + if h is not None: + # in forward, when `h` is [B, nr_conditions, 1, 1] + h_flat = h.view(h.size(0), -1) # [B, nr_conditions] + h_proj = self.h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] + c2 += h_proj + a, b = th.chunk(c2, 2, dim=1) + c3 = a * F.sigmoid(b) + return x + c3 + +def downShift(x, pad): + x = pad(x) + return x + +def downRightShift(x, pad): + pad(x) + return x + +class DownShiftedConv2d(nn.Module): + def __init__(self, + nr_filters_in, + nr_filters_out, + filter_size=(2,3), + stride=(1,1), + shift_output_down=False): + + super().__init__() + self.pad = nn.ZeroPad2d((int((filter_size[1] - 1) / 2), int((filter_size[1] - 1) / 2), filter_size[0]-1, 0)) # padding left, right, top, bottom + self.conv = weight_norm(nn.Conv2d(nr_filters_in, nr_filters_out, filter_size, stride)) + self.shift_output_down = shift_output_down + self.down_shift = downShift + self.down_shift_pad = nn.ZeroPad2d((0,0,1,0)) + + def forward(self, x): + x = self.pad(x) + x = self.conv(x) + if self.shift_output_down: + x = self.down_shift(x, pad=self.down_shift_pad) + return x + +class DownShiftedDeconv2d(nn.Module): + def __init__(self, + nr_filters_in, + nr_filters_out, + filter_size=(2,3), + stride=(1,1)): + + super().__init__() + self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=1)) + self.filter_size = filter_size + self.stride = stride + + def forward(self, x): + x = self.deconv(x) + xs = list(x.shape) + return x[:, :, :(xs[2] - self.filter_size[0] + 1), int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))] + +class DownRightShiftedConv2d(nn.Module): + def __init__(self, + nr_filters_in, + nr_filters_out, + filter_size=(2,2), + stride=(1,1), + shift_output_right_down=False): + + super().__init__() + self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0]-1, 0)) # padding left, right, top, bottom + self.conv = weight_norm(nn.Conv2d(nr_filters_in, nr_filters_out, filter_size, stride)) + self.shift_output_right_down = shift_output_right_down + self.down_right_shift = downRightShift + self.down_right_shift_pad = nn.ZeroPad2d((1,0,0,0)) + + def forward(self, x): + x = self.pad(x) + x = self.conv(x) + if self.shift_output_right_down: + x = self.down_right_shift(x, pad=self.down_right_shift_pad) + return x + +class DownRightShiftedDeconv2d(nn.Module): + def __init__(self, + nr_filters_in, + nr_filters_out, + filter_size=(2,2), + stride=(1,1)): + + super().__init__() + self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=1)) + self.filter_size = filter_size + self.stride = stride + + def forward(self, x): + x = self.deconv(x) + xs = list(x.shape) + return x[:, :, :(xs[2] - self.filter_size[0] + 1), :(xs[3] - self.filter_size[1] + 1)] + + +# IMPLEMENT PIXELCNN++ HERE +class PixelCNNpp(nn.Module): + def __init__(self, + nr_resnet: int, + nr_filters: int, + nr_logistic_mix: int, + resnet_nonlinearity: str, + dropout_p: float, + input_channels: int = 1): + + super().__init__() + if resnet_nonlinearity == "concat_elu" : + self.resnet_nonlinearity = concat_elu + elif resnet_nonlinearity == "elu" : + self.resnet_nonlinearity = F.elu + elif resnet_nonlinearity == "relu" : + self.resnet_nonlinearity = F.relu + else: + raise Exception("Only concat elu, elu and relu are supported as resnet nonlinearity.") # noqa: TRY002 + + self.nr_resnet = nr_resnet + self.nr_filters = nr_filters + self.nr_logistic_mix = nr_logistic_mix + self.dropout_p = dropout_p + self.input_channels = input_channels + + + # UP PASS blocks + self.u_init = DownShiftedConv2d(input_channels + 1, nr_filters, filter_size=(2,3), stride=(1,1), shift_output_down=True) + self.ul_init = nn.ModuleList([DownShiftedConv2d(input_channels + 1, nr_filters, filter_size=(1,3), stride=(1,1), shift_output_down=True), + DownRightShiftedConv2d(input_channels + 1, nr_filters, filter_size=(2,1), stride=(1,1), shift_output_right_down=True)]) + + self.gated_resnet_block_u_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + + self.downsize_u_1 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) + self.downsize_ul_1 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) + + self.gated_resnet_block_u_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + + self.downsize_u_2 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) + self.downsize_ul_2 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) + + self.gated_resnet_block_u_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + + + # DOWN PASS blocks + self.gated_resnet_block_u_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet)]) + + self.upsize_u_1 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) + self.upsize_ul_1 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) + + self.gated_resnet_block_u_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet)]) + + self.upsize_u_2 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) + self.upsize_ul_2 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) + + self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet)]) + + num_mix = 3 if self.input_channels == 1 else 10 + self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix) + + def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: + xs = list(x.shape) + padding = th.ones(xs[0], 1, xs[2], xs[3], device=x.device) + x = th.cat((x, padding), dim=1) # add extra channel for padding + + # UP PASS ("encoder") + u_list = [self.u_init(x)] + ul_list = [self.ul_init[0](x) + self.ul_init[1](x)] + + for i in range(self.nr_resnet): + u_list.append(self.gated_resnet_block_u_up_1[i](u_list[-1], a=None, h=c)) + ul_list.append(self.gated_resnet_block_ul_up_1[i](ul_list[-1], a=u_list[-1], h=c)) + + u_list.append(self.downsize_u_1(u_list[-1])) + ul_list.append(self.downsize_ul_1(ul_list[-1])) + + for i in range(self.nr_resnet): + u_list.append(self.gated_resnet_block_u_up_2[i](u_list[-1], a=None, h=c)) + ul_list.append(self.gated_resnet_block_ul_up_2[i](ul_list[-1], a=u_list[-1], h=c)) + + u_list.append(self.downsize_u_2(u_list[-1])) + ul_list.append(self.downsize_ul_2(ul_list[-1])) + + for i in range(self.nr_resnet): + u_list.append(self.gated_resnet_block_u_up_3[i](u_list[-1], a=None, h=c)) + ul_list.append(self.gated_resnet_block_ul_up_3[i](ul_list[-1], a=u_list[-1], h=c)) + + # DOWN PASS ("decoder") + u = u_list.pop() + ul = ul_list.pop() + + for i in range(self.nr_resnet): + u = self.gated_resnet_block_u_down_1[i](u, a=u_list.pop(), h=c) + ul = self.gated_resnet_block_ul_down_1[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) + + u = self.upsize_u_1(u) + ul = self.upsize_ul_1(ul) + + for i in range(self.nr_resnet): + u = self.gated_resnet_block_u_down_2[i](u, a=u_list.pop(), h=c) + ul = self.gated_resnet_block_ul_down_2[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) + + u = self.upsize_u_2(u) + ul = self.upsize_ul_2(ul) + + for i in range(self.nr_resnet): + u = self.gated_resnet_block_u_down_3[i](u, a=u_list.pop(), h=c) + ul = self.gated_resnet_block_ul_down_3[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) + + + x_out = self.nin_out(F.elu(ul)) + + assert len(u_list) == 0 + assert len(ul_list) == 0 + + return x_out + +def log_sum_exp(x): + """Numerically stable log_sum_exp implementation that prevents overflow.""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = th.max(x, dim=axis) + m2, _ = th.max(x, dim=axis, keepdim=True) + return m + th.log(th.sum(th.exp(x - m2), dim=axis)) + if __name__ == "__main__": args = tyro.cli(Args) @@ -121,7 +396,7 @@ def discretized_mix_logistic_loss(x, l): design_shape = problem.design_space.shape print(f"Design shape: {design_shape}") - conditions = problem.conditions + conditions = problem.conditions_keys nr_conditions = len(conditions) @@ -144,9 +419,10 @@ def discretized_mix_logistic_loss(x, l): device = th.device("cuda") else: device = th.device("cpu") + device = th.device("cpu") # Loss function - loss = PixelCNNpp.discretized_mix_logistic_loss + loss = discretized_mix_logistic_loss # Initialize model # ... implement @@ -158,7 +434,7 @@ def discretized_mix_logistic_loss(x, l): training_ds = problem.dataset.with_format("torch", device=device)["train"] condition_tensors = [training_ds[key][:] for key in problem.conditions_keys] - training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:].flatten(1), *condition_tensors) + training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:], *condition_tensors) # .flatten(1) ? dataloader = th.utils.data.DataLoader( training_ds, @@ -195,10 +471,31 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # ---------- for epoch in tqdm.trange(args.n_epochs): for i, data in enumerate(dataloader): - print(data[0]) - print(data[1:]) + designs = data[0].unsqueeze(dim=1) # add channel dim (for concat_elu) + + print(designs.shape) + #print(data[1:]) + + # reshape needed? + conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) + #print(conds.shape) + #print(conds) + # in PixelCNNpp.__init__ + h_lin = nn.Linear(nr_conditions, 2 * args.nr_filters) + + # in forward, when `h` is [B, nr_conditions, 1, 1] + h_flat = conds.view(conds.size(0), -1) # [B, nr_conditions] + h_proj = h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] + print(h_proj.shape) + print(h_proj) + + # conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) + # print(conds.shape) + # print(conds) + batch_start_time = time.time() # ... implement + # optimizer.zero_grad() # Backpropagation # loss.backward() From ca91de7b6a859acb57a4a9fc1a0483e68ffb0d04 Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 30 Nov 2025 19:49:13 +0100 Subject: [PATCH 05/31] loss function added --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 157 ++++++++++++--------- 1 file changed, 89 insertions(+), 68 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index d48bdf6..a2d236f 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -5,6 +5,7 @@ from engibench.utils.all_problems import BUILTIN_PROBLEMS import torch as th import torch.nn.functional as F +from torch.autograd import Variable from torch import nn from torch.nn.utils import weight_norm import numpy as np @@ -43,18 +44,8 @@ class Args: """number of epochs of training""" batch_size: int = 32 """size of the batches""" - lr: float = 0.0001 + lr: float = 0.001 """learning rate""" - # b1: float = 0.5 - # """decay of first order momentum of gradient""" - # b2: float = 0.999 - # """decay of first order momentum of gradient""" - # n_cpu: int = 8 - # """number of cpu threads to use during batch generation""" - # latent_dim: int = 32 - # """dimensionality of the latent space""" - # sample_interval: int = 400 - # """interval between image samples""" nr_resnet: int = 5 """Number of residual blocks per stage of the model.""" nr_filters: int = 160 @@ -63,52 +54,10 @@ class Args: """Number of logistic components in the mixture. Higher = more flexible model.""" resnet_nonlinearity: str = "concat_elu" """Nonlinearity to use in the ResNet blocks. One of 'concat_elu', 'elu', 'relu'.""" + dropout_p: float = 0.5 + """Dropout probability.""" -def discretized_mix_logistic_loss(x, l): - pass - # """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ - # xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) - # ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100) - # nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics - # logit_probs = l[:,:,:,:nr_mix] - # l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3]) - # means = l[:,:,:,:,:nr_mix] - # log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.) - # coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix]) - # x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels - # m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix]) - # m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix]) - # means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3) - # centered_x = x - means - # inv_stdv = tf.exp(-log_scales) - # plus_in = inv_stdv * (centered_x + 1./255.) - # cdf_plus = tf.nn.sigmoid(plus_in) - # min_in = inv_stdv * (centered_x - 1./255.) - # cdf_min = tf.nn.sigmoid(min_in) - # log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling) - # log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling) - # cdf_delta = cdf_plus - cdf_min # probability for all other cases - # mid_in = inv_stdv * centered_x - # log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) - - # # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) - - # # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() - # # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) - - # # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) - # # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs - # # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue - # # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value - # log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) - - # log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs) - # if sum_all: - # return -tf.reduce_sum(log_sum_exp(log_probs)) - # else: - # return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) - def concat_elu(x): """Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU.""" # Pytorch ordering @@ -387,6 +336,71 @@ def log_sum_exp(x): m2, _ = th.max(x, dim=axis, keepdim=True) return m + th.log(th.sum(th.exp(x - m2), dim=axis)) +def log_prob_from_logits(x): + """Numerically stable log_softmax implementation that prevents overflow.""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = th.max(x, dim=axis, keepdim=True) + return x - m - th.log(th.sum(th.exp(x - m), dim=axis, keepdim=True)) + + +def discretized_mix_logistic_loss(x, l): + """Log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval.""" + # Pytorch ordering + x = x.permute(0, 2, 3, 1) + l = l.permute(0, 2, 3, 1) + xs = list(x.shape) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) + ls = list(l.shape) # predicted distribution, e.g. (B,32,32,100) + # different for 3 channels: / 10 + nr_mix = int(ls[-1] / 3) # here and below: unpacking the params of the mixture of logistics + logit_probs = l[:,:,:,:nr_mix] + # different for 3 channels: nr_mix * 3 #, coeff + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # 3 for mean, scale + means = l[:,:,:,:,:nr_mix] + log_scales = th.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) + # for 3 channels: + # coeffs = F.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) + x = x.contiguous() + x = x.unsqueeze(-1) + Variable(th.zeros([*xs, nr_mix]).cuda(), requires_grad=False) + # for 3 channels: + # m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] + # * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix) + + # m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + + # coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix) + # means = th.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3) + + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1./255.) + cdf_plus = F.sigmoid(plus_in) + min_in = inv_stdv * (centered_x - 1./255.) + cdf_min = F.sigmoid(min_in) + # log probability for edge case of 0 (before scaling) + log_cdf_plus = plus_in - F.softplus(plus_in) + # log probability for edge case of 255 (before scaling) + log_one_minus_cdf_min = -F.softplus(min_in) + cdf_delta = cdf_plus - cdf_min # probability for all other cases + mid_in = inv_stdv * centered_x + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in this code) + log_pdf_mid = mid_in - log_scales - 2.*F.softplus(mid_in) + + # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen here) + + # this is what is really done, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() + # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) + + # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) + # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs + # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue + # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value + + # log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) + log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.where(cdf_delta > 1e-5, th.log(th.clamp(cdf_delta, min=1e-12)), log_pdf_mid - np.log(127.5)))) # noqa: PLR2004 + log_probs = th.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) + return -th.sum(log_sum_exp(log_probs)) + if __name__ == "__main__": args = tyro.cli(Args) @@ -425,10 +439,17 @@ def log_sum_exp(x): loss = discretized_mix_logistic_loss # Initialize model - # ... implement + model = PixelCNNpp( + nr_resnet=args.nr_resnet, + nr_filters=args.nr_filters, + nr_logistic_mix=args.nr_logistic_mix, + resnet_nonlinearity=args.resnet_nonlinearity, + dropout_p=args.dropout_p, + input_channels=1 + ) - # model.to(device) - # loss.to(device) + model.to(device) + loss.to(device) # Configure data loader training_ds = problem.dataset.with_format("torch", device=device)["train"] @@ -443,7 +464,7 @@ def log_sum_exp(x): ) # Optimzer - # optimizer = th.optim.Adam(model.parameters(), lr=args.lr) # add other args if necessary + optimizer = th.optim.Adam(model.parameters(), lr=args.lr) # add other args if necessary @th.no_grad() def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: @@ -495,11 +516,11 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: batch_start_time = time.time() # ... implement - # optimizer.zero_grad() + optimizer.zero_grad() # Backpropagation - # loss.backward() - # optimizer.step() + loss.backward() + optimizer.step() # ---------- @@ -509,13 +530,13 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: batches_done = epoch * len(dataloader) + i wandb.log( { - "loss": None, #loss.item(), + "loss": loss.item(), "epoch": epoch, "batch": batches_done, } ) print( - f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(dataloader)}] [loss: {None}]] [{time.time() - batch_start_time:.2f} sec]" #loss.item() + f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(dataloader)}] [loss: {loss.item()}]] [{time.time() - batch_start_time:.2f} sec]" ) # This saves a grid image of 25 generated designs every sample_interval @@ -552,9 +573,9 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: ckpt_model = { "epoch": epoch, "batches_done": batches_done, - "model": None, # model.state_dict(), - "optimizer_generator": None, # optimizer.state_dict(), - "loss": None, # loss.item(), + "model": model.state_dict(), + "optimizer_generator": optimizer.state_dict(), + "loss": loss.item(), } th.save(ckpt_model, "model.pth") @@ -564,4 +585,4 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: wandb.log_artifact(artifact_model, aliases=[f"seed_{args.seed}"]) - wandb.finish() \ No newline at end of file + wandb.finish() From b44d7711619cfb26fa72cd218a6ecef889160d1e Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 30 Nov 2025 19:49:41 +0100 Subject: [PATCH 06/31] loss function added --- engiopt/cgan_cnn_2d/cgan_cnn_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engiopt/cgan_cnn_2d/cgan_cnn_2d.py b/engiopt/cgan_cnn_2d/cgan_cnn_2d.py index 80cf116..23119df 100644 --- a/engiopt/cgan_cnn_2d/cgan_cnn_2d.py +++ b/engiopt/cgan_cnn_2d/cgan_cnn_2d.py @@ -44,7 +44,7 @@ class Args: """Saves the model to disk.""" # Algorithm specific - n_epochs: int = 200 + n_epochs: int = 2 """number of epochs of training""" batch_size: int = 32 """size of the batches""" From 94be7a6f4f931e88cd8aeae41bcfbb6f6cb08d7c Mon Sep 17 00:00:00 2001 From: Jonas Date: Sun, 30 Nov 2025 19:58:37 +0100 Subject: [PATCH 07/31] removed unneccessary file --- .../pixel_cnn_pp_2d/print_dataset_element.py | 284 ------------------ 1 file changed, 284 deletions(-) delete mode 100644 engiopt/pixel_cnn_pp_2d/print_dataset_element.py diff --git a/engiopt/pixel_cnn_pp_2d/print_dataset_element.py b/engiopt/pixel_cnn_pp_2d/print_dataset_element.py deleted file mode 100644 index 1ff47a1..0000000 --- a/engiopt/pixel_cnn_pp_2d/print_dataset_element.py +++ /dev/null @@ -1,284 +0,0 @@ -"""This code is largely based on the excellent PyTorch GAN repo: https://github.com/eriklindernoren/PyTorch-GAN. - -We essentially refreshed the Python style, use wandb for logging, and made a few little improvements. -""" - -from __future__ import annotations - -from dataclasses import dataclass -import os -import random -import time - -from engibench.utils.all_problems import BUILTIN_PROBLEMS -import matplotlib.pyplot as plt -import numpy as np -import torch as th -from torch import nn -import tqdm -import tyro - -import wandb - - -@dataclass -class Args: - """Command-line arguments.""" - - problem_id: str = "beams2d" - """Problem identifier.""" - algo: str = os.path.basename(__file__)[: -len(".py")] - """The name of this algorithm.""" - - # Tracking - track: bool = True - """Track the experiment with wandb.""" - wandb_project: str = "engiopt" - """Wandb project name.""" - wandb_entity: str | None = None - """Wandb entity name.""" - seed: int = 1 - """Random seed.""" - save_model: bool = False - """Saves the model to disk.""" - - # Algorithm specific - n_epochs: int = 1 - """number of epochs of training""" - batch_size: int = 32 - """size of the batches""" - lr_gen: float = 0.0001 - """learning rate for the generator""" - lr_disc: float = 0.0004 - """learning rate for the discriminator""" - b1: float = 0.5 - """decay of first order momentum of gradient""" - b2: float = 0.999 - """decay of first order momentum of gradient""" - n_cpu: int = 8 - """number of cpu threads to use during batch generation""" - latent_dim: int = 32 - """dimensionality of the latent space""" - sample_interval: int = 400 - """interval between image samples""" - - -class Generator(nn.Module): - def __init__(self, latent_dim: int, design_shape: tuple): - super().__init__() - self.design_shape = design_shape # Store design shape - - def block(in_feat: int, out_feat: int, *, normalize: bool = True) -> list[nn.Module]: - layers: list[nn.Module] = [nn.Linear(in_feat, out_feat)] - if normalize: - layers.append(nn.BatchNorm1d(out_feat, 0.8)) - layers.append(nn.LeakyReLU(0.2, inplace=True)) - return layers - - self.model = nn.Sequential( - *block(latent_dim, 128, normalize=False), - *block(128, 256), - *block(256, 512), - *block(512, 1024), - nn.Linear(1024, int(np.prod(design_shape))), - nn.Tanh(), - ) - - def forward(self, z: th.Tensor) -> th.Tensor: - """Forward pass to generate an image from latent space.""" - img = self.model(z) - return img.view(img.size(0), *self.design_shape) - - -class Discriminator(nn.Module): - def __init__(self): - super().__init__() - - self.model = nn.Sequential( - nn.Linear(int(np.prod(design_shape)), 512), - nn.LeakyReLU(0.2, inplace=True), - nn.Linear(512, 256), - nn.LeakyReLU(0.2, inplace=True), - nn.Linear(256, 1), - nn.Sigmoid(), - ) - - def forward(self, img: th.Tensor) -> th.Tensor: - """Forward pass to compute the validity of an input image.""" - img_flat = img.view(img.size(0), -1) - return self.model(img_flat) - - -if __name__ == "__main__": - args = tyro.cli(Args) - - problem = BUILTIN_PROBLEMS[args.problem_id]() - problem.reset(seed=args.seed) - - design_shape = problem.design_space.shape - - # Logging - run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" - if args.track: - wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name) - - # Seeding - th.manual_seed(args.seed) - rng = np.random.default_rng(args.seed) - random.seed(args.seed) - th.backends.cudnn.deterministic = True - - os.makedirs("images", exist_ok=True) - - if th.backends.mps.is_available(): - device = th.device("mps") - elif th.cuda.is_available(): - device = th.device("cuda") - else: - device = th.device("cpu") - - # Loss function - adversarial_loss = th.nn.BCELoss() - - # Initialize generator and discriminator - generator = Generator(args.latent_dim, design_shape) - discriminator = Discriminator() - - generator.to(device) - discriminator.to(device) - adversarial_loss.to(device) - - # Configure data loader - training_ds = problem.dataset.with_format("torch", device=device)["train"] - - training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:].flatten(1)) - dataloader = th.utils.data.DataLoader( - training_ds, - batch_size=args.batch_size, - shuffle=True, - ) - - # Optimizers - optimizer_generator = th.optim.Adam(generator.parameters(), lr=args.lr_gen, betas=(args.b1, args.b2)) - optimizer_discriminator = th.optim.Adam(discriminator.parameters(), lr=args.lr_disc, betas=(args.b1, args.b2)) - - @th.no_grad() - def sample_designs(n_designs: int) -> th.Tensor: - """Samples n_designs from the generator.""" - # Sample noise - z = th.randn((n_designs, args.latent_dim), device=device, dtype=th.float) - return generator(z) - - # ---------- - # Training - # ---------- - for epoch in tqdm.trange(args.n_epochs): - for i, data in enumerate(dataloader): - designs = data[0] - print(designs.shape) - # Adversarial ground truths - valid = th.ones((designs.size(0), 1), requires_grad=False, device=device) - fake = th.zeros((designs.size(0), 1), requires_grad=False, device=device) - - # ----------------- - # Train Generator - # min log(1 - D(G(z))) <==> max log(D(G(z))) - # ----------------- - optimizer_generator.zero_grad() - - # Sample noise as generator input - z = th.randn((designs.size(0), args.latent_dim), device=device, dtype=th.float) - - # Generate a batch of images - gen_designs = generator(z) - - # Loss measures generator's ability to fool the discriminator - g_loss = adversarial_loss(discriminator(gen_designs), valid) - - g_loss.backward() - optimizer_generator.step() - - # --------------------- - # Train Discriminator - # max log(D(real)) + log(1 - D(G(z))) - # --------------------- - optimizer_discriminator.zero_grad() - - # Measure discriminator's ability to classify real from generated samples - real_loss = adversarial_loss(discriminator(designs), valid) - fake_loss = adversarial_loss(discriminator(gen_designs.detach()), fake) - d_loss = (real_loss + fake_loss) / 2 - - d_loss.backward() - optimizer_discriminator.step() - - # ---------- - # Logging - # ---------- - if args.track: - batches_done = epoch * len(dataloader) + i - wandb.log( - { - "d_loss": d_loss.item(), - "g_loss": g_loss.item(), - "epoch": epoch, - "batch": batches_done, - } - ) - print( - f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]" - ) - - # This saves a grid image of 25 generated designs every sample_interval - if batches_done % args.sample_interval == 0: - # Extract 25 designs - designs = sample_designs(25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) - - # Flatten axes for easy indexing - axes = axes.flatten() - - # Plot each tensor as a scatter plot - for j, tensor in enumerate(designs): - img = tensor.cpu().numpy() # Extract x and y coordinates - axes[j].imshow(img) # Scatter plot - axes[j].set_xticks([]) # Hide x ticks - axes[j].set_yticks([]) # Hide y ticks - - plt.tight_layout() - img_fname = f"images/{batches_done}.png" - plt.savefig(img_fname) - plt.close() - wandb.log({"designs": wandb.Image(img_fname)}) - - # -------------- - # Save models - # -------------- - if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: - ckpt_gen = { - "epoch": epoch, - "batches_done": batches_done, - "generator": generator.state_dict(), - "optimizer_generator": optimizer_generator.state_dict(), - "loss": g_loss.item(), - } - ckpt_disc = { - "epoch": epoch, - "batches_done": batches_done, - "discriminator": discriminator.state_dict(), - "optimizer_discriminator": optimizer_discriminator.state_dict(), - "loss": d_loss.item(), - } - - th.save(ckpt_gen, "generator.pth") - th.save(ckpt_disc, "discriminator.pth") - if args.track: - artifact_gen = wandb.Artifact(f"{args.problem_id}_{args.algo}_generator", type="model") - artifact_gen.add_file("generator.pth") - artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model") - artifact_disc.add_file("discriminator.pth") - - wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"]) - wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"]) - - wandb.finish() From 2e8e5d9cbf4551ef27f80170626fc58e8c800bf8 Mon Sep 17 00:00:00 2001 From: Jonas Date: Tue, 2 Dec 2025 09:08:12 +0100 Subject: [PATCH 08/31] partially debugged --- .../evaluate_pixel_cnn_pp_2d.py | 29 +- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 283 ++++++++++++++---- 2 files changed, 240 insertions(+), 72 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 69805bd..487b025 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -8,6 +8,7 @@ import wandb from engiopt import metrics from engiopt.dataset_sample_conditions import sample_conditions +from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import PixelCNNpp @dataclass @@ -60,13 +61,13 @@ class Args: # -------------------------------------------------------- # adapt to PixelCNN++ input shape requirements - conditions_tensor = conditions_tensor.unsqueeze(1) + conditions_tensor = conditions_tensor.unsqueeze(-1).unsqueeze(-1) - ### Set Up Diffusion Model ### + ### Set Up PixelCNN++ Model ### if args.wandb_entity is not None: - artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_diffusion_2d_cond_model:seed_{seed}" + artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}" else: - artifact_path = f"{args.wandb_project}/{args.problem_id}_diffusion_2d_cond_model:seed_{seed}" + artifact_path = f"{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}" api = wandb.Api() artifact = api.artifact(artifact_path, type="model") @@ -85,13 +86,27 @@ def __init__(self): # Build PixelCNN++ Model - # ... + model = PixelCNNpp( + nr_resnet=run.config["nr_resnet"], + nr_filters=run.config["nr_filters"], + nr_logistic_mix=run.config["nr_logistic_mix"], + resnet_nonlinearity=run.config["resnet_nonlinearity"], + dropout_p=run.config["dropout_p"], + input_channels=1 + ) + + model.load_state_dict(ckpt["generator"]) + model.eval() # Set to evaluation mode + model.to(device) + # Sample noise as generator input + z = th.randn((args.n_samples, run.config["latent_dim"], 1, 1), device=device, dtype=th.float) # Generate a batch of designs - # ... + gen_designs = model(z, conditions_tensor) - gen_designs_np = None, #gen_designs.detach().cpu().numpy().reshape(args.n_samples, *problem.design_space.shape) + gen_designs_np = gen_designs.detach().cpu().numpy() + gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) # Clip to boundaries for running THIS IS PROBLEM DEPENDENT gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index a2d236f..7e0a7b3 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -1,19 +1,22 @@ +"""PixelCNN++ 2D model implementation. + +Provides the model classes, shifted convolutional blocks, and the +discretized mixture of logistics loss used for training and sampling. +""" from dataclasses import dataclass -from email import generator import os -import tyro +import random +import time + from engibench.utils.all_problems import BUILTIN_PROBLEMS +import matplotlib.pyplot as plt +import numpy as np import torch as th -import torch.nn.functional as F -from torch.autograd import Variable from torch import nn -from torch.nn.utils import weight_norm -import numpy as np -import random +import torch.nn.functional as F +from torch.nn.utils.parametrizations import weight_norm import tqdm -import time -import matplotlib.pyplot as plt - +import tyro import wandb @@ -27,7 +30,7 @@ class Args: """The name of this algorithm.""" # Tracking - track: bool = False + track: bool = True """Track the experiment with wandb.""" wandb_project: str = "engiopt" """Wandb project name.""" @@ -40,12 +43,18 @@ class Args: # CHANGE! # Algorithm specific - n_epochs: int = 1 + n_epochs: int = 2 """number of epochs of training""" + sample_interval: int = 1 + """interval between image samples""" batch_size: int = 32 """size of the batches""" lr: float = 0.001 """learning rate""" + b1: float = 0.95 + """decay of first order momentum of gradient""" + b2: float = 0.9995 + """decay of first order momentum of gradient""" nr_resnet: int = 5 """Number of residual blocks per stage of the model.""" nr_filters: int = 160 @@ -62,7 +71,6 @@ def concat_elu(x): """Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU.""" # Pytorch ordering axis = len(x.size()) - 3 - #print(axis) return F.elu(th.cat([x, -x], dim=axis)) class nin(nn.Module): @@ -85,7 +93,7 @@ def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_con self.skip_connection = skip_connection self.resnet_nonlinearity = resnet_nonlinearity - if resnet_nonlinearity == "concat_elu": + if resnet_nonlinearity is concat_elu: self.filter_doubling = 2 else: self.filter_doubling = 1 @@ -103,7 +111,9 @@ def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_con def forward(self, x, a=None, h=None): c1 = self.conv_input(self.resnet_nonlinearity(x)) + # print(f"c1 shape: {c1.shape}") if a is not None : + # print(f"a shape: {a.shape}") c1 += self.nin_skip(self.resnet_nonlinearity(a)) c1 = self.resnet_nonlinearity(c1) c1 = self.dropout(c1) @@ -117,12 +127,17 @@ def forward(self, x, a=None, h=None): c3 = a * F.sigmoid(b) return x + c3 + def downShift(x, pad): + xs = list(x.shape) + x = x[:, :, :xs[2] - 1, :] x = pad(x) return x def downRightShift(x, pad): - pad(x) + xs = list(x.shape) + x = x[:, :, :, :xs[3] - 1] + x = pad(x) return x class DownShiftedConv2d(nn.Module): @@ -152,15 +167,31 @@ def __init__(self, nr_filters_in, nr_filters_out, filter_size=(2,3), - stride=(1,1)): + stride=(1,1), + output_padding=(0,1)): super().__init__() - self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=1)) + self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding)) self.filter_size = filter_size self.stride = stride - def forward(self, x): - x = self.deconv(x) + def forward(self, x, output_padding=None): + # Allow callers to pass a dynamic `output_padding` (some callers compute + # this to handle odd/even spatial sizes). If not provided, use the + # configured ConvTranspose2d module directly. + if output_padding is None: + x = self.deconv(x) + else: + x = F.conv_transpose2d( + x, + self.deconv.weight, + self.deconv.bias, + stride=self.stride, + padding=self.deconv.padding, + output_padding=output_padding, + dilation=self.deconv.dilation, + groups=self.deconv.groups, + ) xs = list(x.shape) return x[:, :, :(xs[2] - self.filter_size[0] + 1), int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))] @@ -191,15 +222,28 @@ def __init__(self, nr_filters_in, nr_filters_out, filter_size=(2,2), - stride=(1,1)): + stride=(1,1), + output_padding=(1,0)): super().__init__() - self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=1)) + self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding)) self.filter_size = filter_size self.stride = stride - def forward(self, x): - x = self.deconv(x) + def forward(self, x, output_padding=None): + if output_padding is None: + x = self.deconv(x) + else: + x = F.conv_transpose2d( + x, + self.deconv.weight, + self.deconv.bias, + stride=self.stride, + padding=self.deconv.padding, + output_padding=output_padding, + dilation=self.deconv.dilation, + groups=self.deconv.groups, + ) xs = list(x.shape) return x[:, :, :(xs[2] - self.filter_size[0] + 1), :(xs[3] - self.filter_size[1] + 1)] @@ -259,14 +303,14 @@ def __init__(self, self.upsize_u_1 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) self.upsize_ul_1 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - self.gated_resnet_block_u_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet)]) + self.gated_resnet_block_u_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_ul_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet + 1)]) self.upsize_u_2 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) self.upsize_ul_2 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet)]) + self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet + 1)]) num_mix = 3 if self.input_channels == 1 else 10 self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix) @@ -276,6 +320,8 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: padding = th.ones(xs[0], 1, xs[2], xs[3], device=x.device) x = th.cat((x, padding), dim=1) # add extra channel for padding + output_padding_list = [] + # UP PASS ("encoder") u_list = [self.u_init(x)] ul_list = [self.ul_init[0](x) + self.ul_init[1](x)] @@ -287,6 +333,16 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: u_list.append(self.downsize_u_1(u_list[-1])) ul_list.append(self.downsize_ul_1(ul_list[-1])) + # Handle images with odd height/width + # print(f"Before first downsize: u.shape[2]: {u_list[-1].shape[2]}, u.shape[3]: {u_list[-1].shape[3]}, u.shape[2]: {u_list[-2].shape[2]}, u.shape[3]: {u_list[-2].shape[3]}") + pad_height = 1 + pad_width = 1 + if u_list[-2].shape[2] % u_list[-1].shape[2] != 0: + pad_height = 0 + if u_list[-2].shape[3] % u_list[-1].shape[3] != 0: + pad_width = 0 + output_padding_list.append((pad_height, pad_width)) + for i in range(self.nr_resnet): u_list.append(self.gated_resnet_block_u_up_2[i](u_list[-1], a=None, h=c)) ul_list.append(self.gated_resnet_block_ul_up_2[i](ul_list[-1], a=u_list[-1], h=c)) @@ -294,10 +350,24 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: u_list.append(self.downsize_u_2(u_list[-1])) ul_list.append(self.downsize_ul_2(ul_list[-1])) + # Handle images with odd height/width + pad_height = 1 + pad_width = 1 + if u_list[-2].shape[2] % u_list[-1].shape[2] != 0: + pad_height = 0 + if u_list[-2].shape[3] % u_list[-1].shape[3] != 0: + pad_width = 0 + output_padding_list.append((pad_height, pad_width)) + for i in range(self.nr_resnet): u_list.append(self.gated_resnet_block_u_up_3[i](u_list[-1], a=None, h=c)) ul_list.append(self.gated_resnet_block_ul_up_3[i](ul_list[-1], a=u_list[-1], h=c)) + # for i, u in enumerate(u_list): + # print(f"u_list[{i}] shape: {u.shape}") + # for i, ul in enumerate(ul_list): + # print(f"ul_list[{i}] shape: {ul.shape}") + # print(f"output_padding_list: {output_padding_list}") # DOWN PASS ("decoder") u = u_list.pop() ul = ul_list.pop() @@ -306,17 +376,21 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: u = self.gated_resnet_block_u_down_1[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_1[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) - u = self.upsize_u_1(u) - ul = self.upsize_ul_1(ul) + #print(f"After first down pass: u shape: {u.shape}, ul shape: {ul.shape}") + u = self.upsize_u_1(u, output_padding=output_padding_list[-1]) + ul = self.upsize_ul_1(ul, output_padding=output_padding_list[-1]) + # print(f"After first upsize: u shape: {u.shape}, ul shape: {ul.shape}") - for i in range(self.nr_resnet): + for i in range(self.nr_resnet + 1): u = self.gated_resnet_block_u_down_2[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_2[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) - u = self.upsize_u_2(u) - ul = self.upsize_ul_2(ul) + # print(f"After second down pass: u shape: {u.shape}, ul shape: {ul.shape}") + u = self.upsize_u_2(u, output_padding=output_padding_list[-2]) + ul = self.upsize_ul_2(ul, output_padding=output_padding_list[-2]) + # print(f"After second upsize: u shape: {u.shape}, ul shape: {ul.shape}") - for i in range(self.nr_resnet): + for i in range(self.nr_resnet + 1): u = self.gated_resnet_block_u_down_3[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_3[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) @@ -361,7 +435,8 @@ def discretized_mix_logistic_loss(x, l): # for 3 channels: # coeffs = F.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) x = x.contiguous() - x = x.unsqueeze(-1) + Variable(th.zeros([*xs, nr_mix]).cuda(), requires_grad=False) + zeros = th.zeros([*xs, nr_mix], device=x.device) + x = x.unsqueeze(-1) + zeros # for 3 channels: # m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] # * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix) @@ -402,6 +477,41 @@ def discretized_mix_logistic_loss(x, l): return -th.sum(log_sum_exp(log_probs)) +def to_one_hot(tensor, n, fill_with=1.): + # we perform one hot encore with respect to the last axis + one_hot = th.zeros((*tensor.size(), n), device=tensor.device) + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot + + +def sample_from_discretized_mix_logistic(l, nr_mix): + # Pytorch ordering + l = l.permute(0, 2, 3, 1) + ls = list(l.shape) + xs = [*ls[:-1], 1] #[3] + + # unpack parameters + logit_probs = l[:, :, :, :nr_mix] + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # for mean, scale + + # sample mixture indicator from softmax + temp = th.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5) + temp = logit_probs.detach() - th.log(- th.log(temp)) + _, argmax = temp.max(dim=3) + + one_hot = to_one_hot(argmax, nr_mix) + sel = one_hot.view([*xs[:-1], 1, nr_mix]) + # select logistic parameters + means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) + log_scales = th.clamp(th.sum( + l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) + u = th.empty_like(means).uniform_(1e-5, 1. - 1e-5) + x = means + th.exp(log_scales) * (th.log(u) - th.log(1. - u)) + x0 = th.clamp(th.clamp(x[:, :, :, 0], min=-1.), max=1.) + out = x0.unsqueeze(1) + return out + + if __name__ == "__main__": args = tyro.cli(Args) @@ -436,7 +546,7 @@ def discretized_mix_logistic_loss(x, l): device = th.device("cpu") # Loss function - loss = discretized_mix_logistic_loss + loss_operator = discretized_mix_logistic_loss # Initialize model model = PixelCNNpp( @@ -449,7 +559,7 @@ def discretized_mix_logistic_loss(x, l): ) model.to(device) - loss.to(device) + # loss.to(device) # Configure data loader training_ds = problem.dataset.with_format("torch", device=device)["train"] @@ -464,26 +574,66 @@ def discretized_mix_logistic_loss(x, l): ) # Optimzer - optimizer = th.optim.Adam(model.parameters(), lr=args.lr) # add other args if necessary - - @th.no_grad() - def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: - """Sample designs from trained model.""" - # Is that needed? - # z = th.randn((n_designs, args.latent_dim), device=device, dtype=th.float) - - # Create condition grid - all_conditions = th.stack(condition_tensors, dim=1) - linspaces = [ - th.linspace(all_conditions[:, i].min(), all_conditions[:, i].max(), n_designs, device=device) - for i in range(all_conditions.shape[1]) - ] - desired_conds = th.stack(linspaces, dim=1) + optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) # add other args if necessary + + # @th.no_grad() + # def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: + # """Samples n_designs designs.""" + # model.eval() + # data = torch.zeros(n_designs, design_shape[0], design_shape[1], design_shape[2]) + # data = data.cuda() + # with torch.no_grad(): + # for i in range(design_shape[1]): + # for j in range(design_shape[2]): + # out = model(data, sample=True) + # out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + # data[:, :, i, j] = out_sample.data[:, :, i, j] + # return data - # implement sampling from model here - gen_imgs = None - return desired_conds, gen_imgs + @th.no_grad() + def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: + """Samples n_designs designs using dataset conditions. + + This builds `encoder_hidden_states` by linearly interpolating each + condition between its dataset min and max, returns the sampled + designs and the `encoder_hidden_states` used. + """ + model.eval() + device = next(model.parameters()).device + + # Build per-condition min/max from the dataset tensors (device-safe) + # `condition_tensors` is defined in the outer scope above when the + # dataset is prepared: it's a list of 1-D tensors (one per condition). + all_conditions = th.stack(condition_tensors, dim=1).to(device) # [N_examples, nr_conditions] + conds_min = all_conditions.amin(dim=0) + conds_max = all_conditions.amax(dim=0) + + # Create a sweep of condition vectors between min and max (diagonal sweep) + steps = th.linspace(0.0, 1.0, n_designs, device=device).unsqueeze(1) # [n_designs, 1] + encoder_hidden_states = conds_min.unsqueeze(0) + steps * (conds_max - conds_min).unsqueeze(0) + # reshape to [B, nr_conditions, 1, 1] as expected by the model's conditional input + encoder_hidden_states = encoder_hidden_states.view(n_designs, len(problem.conditions_keys), 1, 1).to(device) + + # Prepare an empty image batch on the same device as the model + data = th.zeros((n_designs, dim, *design_shape), device=device) + print(f"final data shape: {data.shape}") + print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}") + + # Autoregressive pixel sampling: iterate spatial positions and condition on + # previously sampled pixels and the encoder_hidden_states. + with th.no_grad(): + for i in range(design_shape[0]): + for j in range(design_shape[1]): + print(f"Sampling pixel ({i}, {j})") + #out = model(data, encoder_hidden_states) + print(f"out shape: {out.shape}") + #out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + #print(f"out_sample shape: {out_sample.shape}") + # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) + data[:, :, i, j] = None #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + + return data, encoder_hidden_states @@ -491,33 +641,36 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # Training # ---------- for epoch in tqdm.trange(args.n_epochs): + model.train() for i, data in enumerate(dataloader): designs = data[0].unsqueeze(dim=1) # add channel dim (for concat_elu) print(designs.shape) #print(data[1:]) - # reshape needed? conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) - #print(conds.shape) + print(f"conds.shape: {conds.shape}") #print(conds) + + # in PixelCNNpp.__init__ - h_lin = nn.Linear(nr_conditions, 2 * args.nr_filters) + # h_lin = nn.Linear(nr_conditions, 2 * args.nr_filters) - # in forward, when `h` is [B, nr_conditions, 1, 1] - h_flat = conds.view(conds.size(0), -1) # [B, nr_conditions] - h_proj = h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] - print(h_proj.shape) - print(h_proj) + # # in forward, when `h` is [B, nr_conditions, 1, 1] + # h_flat = conds.view(conds.size(0), -1) # [B, nr_conditions] + # h_proj = h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] + # print(h_proj.shape) + # print(h_proj) # conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) # print(conds.shape) # print(conds) batch_start_time = time.time() - # ... implement + out = model(designs, conds) + # Compute loss + loss = loss_operator(designs, out) optimizer.zero_grad() - # Backpropagation loss.backward() optimizer.step() @@ -543,7 +696,7 @@ def sample_designs(n_designs: int) -> tuple[th.Tensor, th.Tensor]: if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, hidden_states = None #sample_designs(model, 25) + designs, hidden_states = sample_designs(model, design_shape, dim=1, n_designs=25) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From 990086647b6d3cec41083e9c3bad361fad514400 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 3 Dec 2025 13:35:37 +0100 Subject: [PATCH 09/31] working pixel_cnn_pp_2d.py version --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 100 +++++++++++---------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 7e0a7b3..27aebad 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -41,13 +41,12 @@ class Args: save_model: bool = False """Saves the model to disk.""" - # CHANGE! # Algorithm specific n_epochs: int = 2 """number of epochs of training""" - sample_interval: int = 1 + sample_interval: int = 500 """interval between image samples""" - batch_size: int = 32 + batch_size: int = 8 """size of the batches""" lr: float = 0.001 """learning rate""" @@ -55,11 +54,11 @@ class Args: """decay of first order momentum of gradient""" b2: float = 0.9995 """decay of first order momentum of gradient""" - nr_resnet: int = 5 + nr_resnet: int = 2 """Number of residual blocks per stage of the model.""" - nr_filters: int = 160 + nr_filters: int = 40 """Number of filters to use across the model. Higher = larger model.""" - nr_logistic_mix: int = 10 + nr_logistic_mix: int = 5 """Number of logistic components in the mixture. Higher = more flexible model.""" resnet_nonlinearity: str = "concat_elu" """Nonlinearity to use in the ResNet blocks. One of 'concat_elu', 'elu', 'relu'.""" @@ -502,7 +501,7 @@ def sample_from_discretized_mix_logistic(l, nr_mix): one_hot = to_one_hot(argmax, nr_mix) sel = one_hot.view([*xs[:-1], 1, nr_mix]) # select logistic parameters - means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) + means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) log_scales = th.clamp(th.sum( l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) u = th.empty_like(means).uniform_(1e-5, 1. - 1e-5) @@ -519,7 +518,7 @@ def sample_from_discretized_mix_logistic(l, nr_mix): problem.reset(seed=args.seed) design_shape = problem.design_space.shape - print(f"Design shape: {design_shape}") + #print(f"Design shape: {design_shape}") conditions = problem.conditions_keys nr_conditions = len(conditions) @@ -543,7 +542,7 @@ def sample_from_discretized_mix_logistic(l, nr_mix): device = th.device("cuda") else: device = th.device("cpu") - device = th.device("cpu") + #device = th.device("cpu") # Loss function loss_operator = discretized_mix_logistic_loss @@ -593,47 +592,48 @@ def sample_from_discretized_mix_logistic(l, nr_mix): @th.no_grad() def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - """Samples n_designs designs using dataset conditions. - - This builds `encoder_hidden_states` by linearly interpolating each - condition between its dataset min and max, returns the sampled - designs and the `encoder_hidden_states` used. - """ + """Samples n_designs designs using dataset conditions.""" model.eval() device = next(model.parameters()).device - # Build per-condition min/max from the dataset tensors (device-safe) - # `condition_tensors` is defined in the outer scope above when the - # dataset is prepared: it's a list of 1-D tensors (one per condition). - all_conditions = th.stack(condition_tensors, dim=1).to(device) # [N_examples, nr_conditions] - conds_min = all_conditions.amin(dim=0) - conds_max = all_conditions.amax(dim=0) + linspaces = [ + th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) + ] + + desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) + + # # Build per-condition min/max from the dataset tensors (device-safe) + # # `condition_tensors` is defined in the outer scope above when the + # # dataset is prepared: it's a list of 1-D tensors (one per condition). + # all_conditions = th.stack(condition_tensors, dim=1).to(device) # [N_examples, nr_conditions] + # conds_min = all_conditions.amin(dim=0) + # conds_max = all_conditions.amax(dim=0) - # Create a sweep of condition vectors between min and max (diagonal sweep) - steps = th.linspace(0.0, 1.0, n_designs, device=device).unsqueeze(1) # [n_designs, 1] - encoder_hidden_states = conds_min.unsqueeze(0) + steps * (conds_max - conds_min).unsqueeze(0) - # reshape to [B, nr_conditions, 1, 1] as expected by the model's conditional input - encoder_hidden_states = encoder_hidden_states.view(n_designs, len(problem.conditions_keys), 1, 1).to(device) + # # Create a sweep of condition vectors between min and max (diagonal sweep) + # steps = th.linspace(0.0, 1.0, n_designs, device=device).unsqueeze(1) # [n_designs, 1] + # encoder_hidden_states = conds_min.unsqueeze(0) + steps * (conds_max - conds_min).unsqueeze(0) + # # reshape to [B, nr_conditions, 1, 1] as expected by the model's conditional input + # encoder_hidden_states = encoder_hidden_states.view(n_designs, len(problem.conditions_keys), 1, 1).to(device) # Prepare an empty image batch on the same device as the model data = th.zeros((n_designs, dim, *design_shape), device=device) - print(f"final data shape: {data.shape}") - print(f"encoder_hidden_states shape: {encoder_hidden_states.shape}") # Autoregressive pixel sampling: iterate spatial positions and condition on - # previously sampled pixels and the encoder_hidden_states. - with th.no_grad(): - for i in range(design_shape[0]): - for j in range(design_shape[1]): - print(f"Sampling pixel ({i}, {j})") - #out = model(data, encoder_hidden_states) - print(f"out shape: {out.shape}") - #out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - #print(f"out_sample shape: {out_sample.shape}") - # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - data[:, :, i, j] = None #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + # previously sampled pixels and the desired_conds. + for i in range(design_shape[0]): + for j in range(design_shape[1]): + # print(f"Sampling pixel ({i}, {j})") + out = model(data, desired_conds) + # print(f"out shape: {out.shape}") + out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + # print(f"out_sample shape: {out_sample.shape}") + # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) + data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + #print(f"Completed row {i+1}/{design_shape[0]}") - return data, encoder_hidden_states + #print(f"final data shape: {data.shape}") + #print(f"desired_conds shape: {desired_conds.shape}") + return data, desired_conds @@ -643,13 +643,13 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i for epoch in tqdm.trange(args.n_epochs): model.train() for i, data in enumerate(dataloader): - designs = data[0].unsqueeze(dim=1) # add channel dim (for concat_elu) + designs = data[0].unsqueeze(dim=1) # add channel dim - print(designs.shape) + #print(designs.shape) #print(data[1:]) conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) - print(f"conds.shape: {conds.shape}") + # print(f"conds.shape: {conds.shape}") #print(conds) @@ -696,18 +696,20 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, hidden_states = sample_designs(model, design_shape, dim=1, n_designs=25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=5) + fig, axes = plt.subplots(1, 5, figsize=(12, 12)) # Flatten axes for easy indexing axes = axes.flatten() # Plot the image created by each output for j, tensor in enumerate(designs): - img = tensor.cpu().numpy() # Extract x and y coordinates - dc = hidden_states[j, 0, :].cpu() - axes[j].imshow(img[0]) # image plot - title = [(problem.conditions[i][0], f"{dc[i]:.2f}") for i in range(len(problem.conditions))] + img = tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates + #print(f"img shape: {img.shape}") + dc = desired_conds[j].cpu().squeeze() # Extract design conditions + #print(f"dc shape: {dc.shape}") + axes[j].imshow(img) # image plot + title = [(problem.conditions_keys[i][0], f"{dc[i]:.2f}") for i in range(nr_conditions)] title_string = "\n ".join(f"{condition}: {value}" for condition, value in title) axes[j].title.set_text(title_string) # Set title axes[j].set_xticks([]) # Hide x ticks From 52b0caab38dd98ee9988a139d5c03de59314680e Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 3 Dec 2025 14:25:32 +0100 Subject: [PATCH 10/31] working pixel_cnn_pp_2d.py model, added function to save model every n epochs --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 27aebad..8298378 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -42,10 +42,12 @@ class Args: """Saves the model to disk.""" # Algorithm specific - n_epochs: int = 2 + n_epochs: int = 4 """number of epochs of training""" sample_interval: int = 500 """interval between image samples""" + model_storage_interval: int = 2 + """interval between model storage""" batch_size: int = 8 """size of the batches""" lr: float = 0.001 @@ -724,7 +726,8 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i # -------------- # Save models # -------------- - if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: + #if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: + if args.save_model and (((epoch + 1) % args.model_storage_interval == 0) or (epoch == args.n_epochs - 1)) and i == len(dataloader) - 1: ckpt_model = { "epoch": epoch, "batches_done": batches_done, @@ -733,10 +736,10 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i "loss": loss.item(), } - th.save(ckpt_model, "model.pth") + th.save(ckpt_model, f"model_epoch{epoch+1}.pth") if args.track: artifact_model = wandb.Artifact(f"{args.problem_id}_{args.algo}_model", type="model") - artifact_model.add_file("model.pth") + artifact_model.add_file(f"model_epoch{epoch+1}.pth") wandb.log_artifact(artifact_model, aliases=[f"seed_{args.seed}"]) From ead782b082987f965460c84339963a701b39be1d Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 5 Dec 2025 13:16:33 +0100 Subject: [PATCH 11/31] finished up evaluate_pixel_cnn_pp_2d.py and fixed parameter handling in pixel_cnn_pp_2d.py --- .../evaluate_pixel_cnn_pp_2d.py | 39 ++++++++++++++----- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 35 +++++++++-------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 487b025..63983f1 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -1,14 +1,18 @@ +"""Evaluation of the PixelCNN++.""" from dataclasses import dataclass -import pandas as pd import os + from engibench.utils.all_problems import BUILTIN_PROBLEMS import numpy as np +import pandas as pd import torch as th import tyro import wandb + from engiopt import metrics from engiopt.dataset_sample_conditions import sample_conditions from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import PixelCNNpp +from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import sample_from_discretized_mix_logistic @dataclass @@ -21,7 +25,7 @@ class Args: """Random seed to run.""" wandb_project: str = "engiopt" """Wandb project name.""" - wandb_entity: str | None = None + wandb_entity: str = "jstehlin-eth-z-rich" #| None = None """Wandb entity name.""" n_samples: int = 50 """Number of generated samples per seed.""" @@ -52,7 +56,7 @@ class Args: device = th.device("cpu") ### Set up testing conditions ### - conditions_tensor, sampled_conditions, sampled_designs_np, _ = sample_conditions( + conditions_tensor, sampled_conditions, sampled_designs_np, selected_indices = sample_conditions( problem=problem, n_samples=args.n_samples, device=device, @@ -62,6 +66,14 @@ class Args: # -------------------------------------------------------- # adapt to PixelCNN++ input shape requirements conditions_tensor = conditions_tensor.unsqueeze(-1).unsqueeze(-1) + # print(f"Conditions tensor shape: {conditions_tensor.shape}") + # print(conditions_tensor) + # print(f"Sampled conditions shape: {sampled_conditions.shape}") + # print(sampled_conditions) + # print(f"Sampled designs shape: {sampled_designs_np.shape}") + # print(sampled_designs_np) + # print(f"Selected indices: {selected_indices}") + design_shape = (problem.design_space.shape[0], problem.design_space.shape[1]) ### Set Up PixelCNN++ Model ### if args.wandb_entity is not None: @@ -81,7 +93,7 @@ def __init__(self): raise RunRetrievalError artifact_dir = artifact.download() - ckpt_path = os.path.join(artifact_dir, "model.pth") # change model.pth if necessary + ckpt_path = os.path.join(artifact_dir, "model_epoch400.pth") # change model.pth if necessary ckpt = th.load(ckpt_path, map_location=device) # or th.device(device) @@ -92,18 +104,27 @@ def __init__(self): nr_logistic_mix=run.config["nr_logistic_mix"], resnet_nonlinearity=run.config["resnet_nonlinearity"], dropout_p=run.config["dropout_p"], - input_channels=1 + input_channels=1, + nr_conditions=conditions_tensor.shape[1] ) - model.load_state_dict(ckpt["generator"]) + model.load_state_dict(ckpt["model"]) model.eval() # Set to evaluation mode model.to(device) - # Sample noise as generator input - z = th.randn((args.n_samples, run.config["latent_dim"], 1, 1), device=device, dtype=th.float) + # input + gen_designs = th.zeros((args.n_samples, 1, *design_shape), device=device) # Generate a batch of designs - gen_designs = model(z, conditions_tensor) + # Autoregressive pixel sampling: iterate spatial positions and condition on + # previously sampled pixels and the desired_conds. + for i in range(design_shape[0]): + for j in range(design_shape[1]): + out = model(gen_designs, conditions_tensor) + out_sample = sample_from_discretized_mix_logistic(out, run.config["nr_logistic_mix"]) + # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) + gen_designs[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + gen_designs_np = gen_designs.detach().cpu().numpy() gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 8298378..ca59b03 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -89,7 +89,7 @@ def forward(self, x): return out.permute(0, 3, 1, 2) # BHWC -> BCHW class GatedResnet(nn.Module): - def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_connection=0, dropout_p=0.5): + def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_connection=0, dropout_p=0.5, nr_conditions=0): super().__init__() self.skip_connection = skip_connection self.resnet_nonlinearity = resnet_nonlinearity @@ -257,7 +257,8 @@ def __init__(self, nr_logistic_mix: int, resnet_nonlinearity: str, dropout_p: float, - input_channels: int = 1): + input_channels: int = 1, + nr_conditions: int = 0): super().__init__() if resnet_nonlinearity == "concat_elu" : @@ -272,7 +273,6 @@ def __init__(self, self.nr_resnet = nr_resnet self.nr_filters = nr_filters self.nr_logistic_mix = nr_logistic_mix - self.dropout_p = dropout_p self.input_channels = input_channels @@ -281,37 +281,37 @@ def __init__(self, self.ul_init = nn.ModuleList([DownShiftedConv2d(input_channels + 1, nr_filters, filter_size=(1,3), stride=(1,1), shift_output_down=True), DownRightShiftedConv2d(input_channels + 1, nr_filters, filter_size=(2,1), stride=(1,1), shift_output_right_down=True)]) - self.gated_resnet_block_u_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + self.gated_resnet_block_u_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) self.downsize_u_1 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) self.downsize_ul_1 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - self.gated_resnet_block_u_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + self.gated_resnet_block_u_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) self.downsize_u_2 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) self.downsize_ul_2 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - self.gated_resnet_block_u_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) + self.gated_resnet_block_u_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) # DOWN PASS blocks - self.gated_resnet_block_u_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet)]) + self.gated_resnet_block_u_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) + self.gated_resnet_block_ul_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) self.upsize_u_1 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) self.upsize_ul_1 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - self.gated_resnet_block_u_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet + 1)]) - self.gated_resnet_block_ul_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_u_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_ul_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) self.upsize_u_2 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) self.upsize_ul_2 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1) for _ in range(nr_resnet + 1)]) - self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) num_mix = 3 if self.input_channels == 1 else 10 self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix) @@ -556,7 +556,8 @@ def sample_from_discretized_mix_logistic(l, nr_mix): nr_logistic_mix=args.nr_logistic_mix, resnet_nonlinearity=args.resnet_nonlinearity, dropout_p=args.dropout_p, - input_channels=1 + input_channels=1, + nr_conditions=nr_conditions ) model.to(device) @@ -732,7 +733,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i "epoch": epoch, "batches_done": batches_done, "model": model.state_dict(), - "optimizer_generator": optimizer.state_dict(), + "optimizer": optimizer.state_dict(), "loss": loss.item(), } From 7081abdeff18743152d8f7d5cc4c7507360152bd Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 5 Dec 2025 15:37:50 +0100 Subject: [PATCH 12/31] Added multiple GPU support for sampling --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 107 ++++++++++----------- 1 file changed, 53 insertions(+), 54 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index ca59b03..6cfa591 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -42,9 +42,9 @@ class Args: """Saves the model to disk.""" # Algorithm specific - n_epochs: int = 4 + n_epochs: int = 100 """number of epochs of training""" - sample_interval: int = 500 + sample_interval: int = 400 """interval between image samples""" model_storage_interval: int = 2 """interval between model storage""" @@ -578,46 +578,29 @@ def sample_from_discretized_mix_logistic(l, nr_mix): # Optimzer optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) # add other args if necessary - # @th.no_grad() - # def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - # """Samples n_designs designs.""" - # model.eval() - # data = torch.zeros(n_designs, design_shape[0], design_shape[1], design_shape[2]) - # data = data.cuda() - # with torch.no_grad(): - # for i in range(design_shape[1]): - # for j in range(design_shape[2]): - # out = model(data, sample=True) - # out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - # data[:, :, i, j] = out_sample.data[:, :, i, j] - # return data @th.no_grad() def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - """Samples n_designs designs using dataset conditions.""" + """Samples n_designs designs using dataset conditions, parallelized across up to 5 GPUs.""" model.eval() device = next(model.parameters()).device + # Wrap model with DataParallel if multiple GPUs are available + num_gpus = th.cuda.device_count() + if num_gpus > 1 and device.type == "cuda": + num_gpus = min(num_gpus, 5) # Use at most 5 GPUs + parallel_model = nn.DataParallel(model, device_ids=list(range(num_gpus))) + print(f"Using {num_gpus} GPUs for sampling") + else: + parallel_model = model + linspaces = [ th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) ] desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) - # # Build per-condition min/max from the dataset tensors (device-safe) - # # `condition_tensors` is defined in the outer scope above when the - # # dataset is prepared: it's a list of 1-D tensors (one per condition). - # all_conditions = th.stack(condition_tensors, dim=1).to(device) # [N_examples, nr_conditions] - # conds_min = all_conditions.amin(dim=0) - # conds_max = all_conditions.amax(dim=0) - - # # Create a sweep of condition vectors between min and max (diagonal sweep) - # steps = th.linspace(0.0, 1.0, n_designs, device=device).unsqueeze(1) # [n_designs, 1] - # encoder_hidden_states = conds_min.unsqueeze(0) + steps * (conds_max - conds_min).unsqueeze(0) - # # reshape to [B, nr_conditions, 1, 1] as expected by the model's conditional input - # encoder_hidden_states = encoder_hidden_states.view(n_designs, len(problem.conditions_keys), 1, 1).to(device) - # Prepare an empty image batch on the same device as the model data = th.zeros((n_designs, dim, *design_shape), device=device) @@ -625,20 +608,50 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i # previously sampled pixels and the desired_conds. for i in range(design_shape[0]): for j in range(design_shape[1]): - # print(f"Sampling pixel ({i}, {j})") - out = model(data, desired_conds) - # print(f"out shape: {out.shape}") + # Use parallel_model for forward pass (DataParallel distributes batch across GPUs) + out = parallel_model(data, desired_conds) out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - # print(f"out_sample shape: {out_sample.shape}") # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + data[:, :, i, j] = out_sample.data[:, :, i, j] #print(f"Completed row {i+1}/{design_shape[0]}") - #print(f"final data shape: {data.shape}") - #print(f"desired_conds shape: {desired_conds.shape}") return data, desired_conds + # old function without DataParallel + # @th.no_grad() + # def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: + # """Samples n_designs designs using dataset conditions.""" + # model.eval() + # device = next(model.parameters()).device + + # linspaces = [ + # th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) + # ] + + # desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) + + # # Prepare an empty image batch on the same device as the model + # data = th.zeros((n_designs, dim, *design_shape), device=device) + + # # Autoregressive pixel sampling: iterate spatial positions and condition on + # # previously sampled pixels and the desired_conds. + # for i in range(design_shape[0]): + # for j in range(design_shape[1]): + # # print(f"Sampling pixel ({i}, {j})") + # out = model(data, desired_conds) + # # print(f"out shape: {out.shape}") + # out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + # # print(f"out_sample shape: {out_sample.shape}") + # # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) + # data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + # #print(f"Completed row {i+1}/{design_shape[0]}") + + # #print(f"final data shape: {data.shape}") + # #print(f"desired_conds shape: {desired_conds.shape}") + # return data, desired_conds + + # ---------- # Training @@ -655,20 +668,6 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i # print(f"conds.shape: {conds.shape}") #print(conds) - - # in PixelCNNpp.__init__ - # h_lin = nn.Linear(nr_conditions, 2 * args.nr_filters) - - # # in forward, when `h` is [B, nr_conditions, 1, 1] - # h_flat = conds.view(conds.size(0), -1) # [B, nr_conditions] - # h_proj = h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] - # print(h_proj.shape) - # print(h_proj) - - # conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) - # print(conds.shape) - # print(conds) - batch_start_time = time.time() out = model(designs, conds) # Compute loss @@ -727,8 +726,8 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i # -------------- # Save models # -------------- - #if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: - if args.save_model and (((epoch + 1) % args.model_storage_interval == 0) or (epoch == args.n_epochs - 1)) and i == len(dataloader) - 1: + if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: + #if args.save_model and (((epoch + 1) % args.model_storage_interval == 0) or (epoch == args.n_epochs - 1)) and i == len(dataloader) - 1: ckpt_model = { "epoch": epoch, "batches_done": batches_done, @@ -737,10 +736,10 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i "loss": loss.item(), } - th.save(ckpt_model, f"model_epoch{epoch+1}.pth") + th.save(ckpt_model, "model.pth") if args.track: artifact_model = wandb.Artifact(f"{args.problem_id}_{args.algo}_model", type="model") - artifact_model.add_file(f"model_epoch{epoch+1}.pth") + artifact_model.add_file("model.pth") wandb.log_artifact(artifact_model, aliases=[f"seed_{args.seed}"]) From bd34475a0310cc1ee90a791dc315d1e62f588e74 Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 5 Dec 2025 15:45:57 +0100 Subject: [PATCH 13/31] increased number of generated samples to 25 --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 6cfa591..44f01e7 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -698,8 +698,8 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=5) - fig, axes = plt.subplots(1, 5, figsize=(12, 12)) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) + fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing axes = axes.flatten() From f5c475a64841d5658c1d5e217222c04c2c2fc10b Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 5 Dec 2025 17:01:23 +0100 Subject: [PATCH 14/31] multi GPU v2 --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 154 ++++++++++++++------- 1 file changed, 104 insertions(+), 50 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 44f01e7..8a662f6 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -13,6 +13,7 @@ import numpy as np import torch as th from torch import nn +import torch.multiprocessing as mp import torch.nn.functional as F from torch.nn.utils.parametrizations import weight_norm import tqdm @@ -46,8 +47,8 @@ class Args: """number of epochs of training""" sample_interval: int = 400 """interval between image samples""" - model_storage_interval: int = 2 - """interval between model storage""" + # model_storage_interval: int = 2 + # """interval between model storage""" batch_size: int = 8 """size of the batches""" lr: float = 0.001 @@ -579,22 +580,104 @@ def sample_from_discretized_mix_logistic(l, nr_mix): optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) # add other args if necessary + @th.no_grad() + def worker_proc(rank, ngpus, model_state, design_shape, dim, conds, n_designs_total, return_dict): + + th.cuda.set_device(rank) + device = f"cuda:{rank}" + + # Rebuild model in worker process + model = PixelCNNpp(nr_resnet=args.nr_resnet, + nr_filters=args.nr_filters, + nr_logistic_mix=args.nr_logistic_mix, + resnet_nonlinearity=args.resnet_nonlinearity, + dropout_p=args.dropout_p, + input_channels=1, + nr_conditions=conds.shape[1] + ) + model.load_state_dict(model_state) + model.to(device) + model.eval() + + # Determine this worker's batch slice + per_gpu = n_designs_total // ngpus + start = rank * per_gpu + end = (rank + 1) * per_gpu + n = per_gpu + + conds_slice = conds[start:end].to(device) + + # Build desired conditions for this worker + linspaces = [ + th.linspace(conds_slice[:, i].min(), conds_slice[:, i].max(), n, device=device) for i in range(conds_slice.shape[1]) + ] + + desired_conds = th.stack(linspaces, dim=1).reshape(n, conds.shape[1], 1, 1) + + # Allocate output batch for this worker + data = th.zeros((n, dim, *design_shape), device=device) + + # --------------------------------------------------------------------- + # Autoregressive sampling loop (runs independently on each GPU) + # --------------------------------------------------------------------- + for i in range(design_shape[0]): + for j in range(design_shape[1]): + out = model(data, desired_conds) + out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + data[:, :, i, j] = out_sample[:, :, i, j] + + # Store CPU results to parent process + return_dict[rank] = (data.cpu(), desired_conds.cpu()) + + + # ------------------------------------------------------------------------- + # Launcher: runs once in the main process + # ------------------------------------------------------------------------- + def sample_designs_multigpu(model: PixelCNNpp, design_shape, dim, conds, n_designs=25, ngpus=None): + + if ngpus is None: + ngpus = th.cuda.device_count() + assert n_designs % ngpus == 0, "n_designs must divide evenly across GPUs" + + # Share model weights with workers + model_state = model.state_dict() + + mp.set_start_method("spawn", force=True) + manager = mp.Manager() + return_dict = manager.dict() + + mp.spawn( + worker_proc, + args=(ngpus, model_state, design_shape, dim, + conds, n_designs, return_dict), + nprocs=ngpus, + join=True + ) + + # --------------------------------------------------------------------- + # Collect results from all GPUs in order + # --------------------------------------------------------------------- + all_data = [] + all_conds = [] + + for rank in range(ngpus): + data, dc = return_dict[rank] + all_data.append(data) + all_conds.append(dc) + + all_data = th.cat(all_data, dim=0) + all_conds = th.cat(all_conds, dim=0) + + return all_data, all_conds + + # old function without parallelization @th.no_grad() def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - """Samples n_designs designs using dataset conditions, parallelized across up to 5 GPUs.""" + """Samples n_designs designs using dataset conditions.""" model.eval() device = next(model.parameters()).device - # Wrap model with DataParallel if multiple GPUs are available - num_gpus = th.cuda.device_count() - if num_gpus > 1 and device.type == "cuda": - num_gpus = min(num_gpus, 5) # Use at most 5 GPUs - parallel_model = nn.DataParallel(model, device_ids=list(range(num_gpus))) - print(f"Using {num_gpus} GPUs for sampling") - else: - parallel_model = model - linspaces = [ th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) ] @@ -608,50 +691,20 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i # previously sampled pixels and the desired_conds. for i in range(design_shape[0]): for j in range(design_shape[1]): - # Use parallel_model for forward pass (DataParallel distributes batch across GPUs) - out = parallel_model(data, desired_conds) + # print(f"Sampling pixel ({i}, {j})") + out = model(data, desired_conds) + # print(f"out shape: {out.shape}") out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + # print(f"out_sample shape: {out_sample.shape}") # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - data[:, :, i, j] = out_sample.data[:, :, i, j] + data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] #print(f"Completed row {i+1}/{design_shape[0]}") + #print(f"final data shape: {data.shape}") + #print(f"desired_conds shape: {desired_conds.shape}") return data, desired_conds - # old function without DataParallel - # @th.no_grad() - # def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - # """Samples n_designs designs using dataset conditions.""" - # model.eval() - # device = next(model.parameters()).device - - # linspaces = [ - # th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) - # ] - - # desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) - - # # Prepare an empty image batch on the same device as the model - # data = th.zeros((n_designs, dim, *design_shape), device=device) - - # # Autoregressive pixel sampling: iterate spatial positions and condition on - # # previously sampled pixels and the desired_conds. - # for i in range(design_shape[0]): - # for j in range(design_shape[1]): - # # print(f"Sampling pixel ({i}, {j})") - # out = model(data, desired_conds) - # # print(f"out shape: {out.shape}") - # out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - # # print(f"out_sample shape: {out_sample.shape}") - # # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - # data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] - # #print(f"Completed row {i+1}/{design_shape[0]}") - - # #print(f"final data shape: {data.shape}") - # #print(f"desired_conds shape: {desired_conds.shape}") - # return data, desired_conds - - # ---------- # Training @@ -698,7 +751,8 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) + #designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) + designs, desired_conds = sample_designs_multigpu(model, design_shape, dim=1, conds=conds, n_designs=25, ngpus=None) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From 66a1e8f110b37f7a2cba805363ab55e010a7ea46 Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 5 Dec 2025 17:03:08 +0100 Subject: [PATCH 15/31] added info message for multi-GPU sampling --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 8a662f6..96b31ed 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -639,7 +639,7 @@ def sample_designs_multigpu(model: PixelCNNpp, design_shape, dim, conds, n_desig ngpus = th.cuda.device_count() assert n_designs % ngpus == 0, "n_designs must divide evenly across GPUs" - + print(f"Sampling {n_designs} designs across {ngpus} GPUs...") # Share model weights with workers model_state = model.state_dict() From d6907813e148963cc185ef8da1af3a1727ef65db Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 5 Dec 2025 23:43:41 +0100 Subject: [PATCH 16/31] model with max settings and 1 GPU --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 96b31ed..9d7a12c 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -47,8 +47,6 @@ class Args: """number of epochs of training""" sample_interval: int = 400 """interval between image samples""" - # model_storage_interval: int = 2 - # """interval between model storage""" batch_size: int = 8 """size of the batches""" lr: float = 0.001 @@ -57,11 +55,11 @@ class Args: """decay of first order momentum of gradient""" b2: float = 0.9995 """decay of first order momentum of gradient""" - nr_resnet: int = 2 + nr_resnet: int = 5 """Number of residual blocks per stage of the model.""" - nr_filters: int = 40 + nr_filters: int = 160 """Number of filters to use across the model. Higher = larger model.""" - nr_logistic_mix: int = 5 + nr_logistic_mix: int = 10 """Number of logistic components in the mixture. Higher = more flexible model.""" resnet_nonlinearity: str = "concat_elu" """Nonlinearity to use in the ResNet blocks. One of 'concat_elu', 'elu', 'relu'.""" @@ -751,8 +749,8 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - #designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) - designs, desired_conds = sample_designs_multigpu(model, design_shape, dim=1, conds=conds, n_designs=25, ngpus=None) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) + # designs, desired_conds = sample_designs_multigpu(model, design_shape, dim=1, conds=conds, n_designs=25, ngpus=None) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From 1ef0254ea36e9060e1675a06a2ac75e8fba3a298 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 8 Dec 2025 13:41:27 +0100 Subject: [PATCH 17/31] fixed scaling (input/output) and removed multi GPU sampling --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 119 ++------------------- 1 file changed, 11 insertions(+), 108 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 9d7a12c..ad4f780 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -13,7 +13,6 @@ import numpy as np import torch as th from torch import nn -import torch.multiprocessing as mp import torch.nn.functional as F from torch.nn.utils.parametrizations import weight_norm import tqdm @@ -43,9 +42,9 @@ class Args: """Saves the model to disk.""" # Algorithm specific - n_epochs: int = 100 + n_epochs: int = 2 """number of epochs of training""" - sample_interval: int = 400 + sample_interval: int = 2000 """interval between image samples""" batch_size: int = 8 """size of the batches""" @@ -57,7 +56,7 @@ class Args: """decay of first order momentum of gradient""" nr_resnet: int = 5 """Number of residual blocks per stage of the model.""" - nr_filters: int = 160 + nr_filters: int = 120 """Number of filters to use across the model. Higher = larger model.""" nr_logistic_mix: int = 10 """Number of logistic components in the mixture. Higher = more flexible model.""" @@ -363,11 +362,7 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: u_list.append(self.gated_resnet_block_u_up_3[i](u_list[-1], a=None, h=c)) ul_list.append(self.gated_resnet_block_ul_up_3[i](ul_list[-1], a=u_list[-1], h=c)) - # for i, u in enumerate(u_list): - # print(f"u_list[{i}] shape: {u.shape}") - # for i, ul in enumerate(ul_list): - # print(f"ul_list[{i}] shape: {ul.shape}") - # print(f"output_padding_list: {output_padding_list}") + # DOWN PASS ("decoder") u = u_list.pop() ul = ul_list.pop() @@ -429,7 +424,7 @@ def discretized_mix_logistic_loss(x, l): nr_mix = int(ls[-1] / 3) # here and below: unpacking the params of the mixture of logistics logit_probs = l[:,:,:,:nr_mix] # different for 3 channels: nr_mix * 3 #, coeff - l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # 3 for mean, scale + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # 2 for mean, scale means = l[:,:,:,:,:nr_mix] log_scales = th.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) # for 3 channels: @@ -577,99 +572,6 @@ def sample_from_discretized_mix_logistic(l, nr_mix): # Optimzer optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) # add other args if necessary - - @th.no_grad() - def worker_proc(rank, ngpus, model_state, design_shape, dim, conds, n_designs_total, return_dict): - - th.cuda.set_device(rank) - device = f"cuda:{rank}" - - # Rebuild model in worker process - model = PixelCNNpp(nr_resnet=args.nr_resnet, - nr_filters=args.nr_filters, - nr_logistic_mix=args.nr_logistic_mix, - resnet_nonlinearity=args.resnet_nonlinearity, - dropout_p=args.dropout_p, - input_channels=1, - nr_conditions=conds.shape[1] - ) - model.load_state_dict(model_state) - model.to(device) - model.eval() - - # Determine this worker's batch slice - per_gpu = n_designs_total // ngpus - start = rank * per_gpu - end = (rank + 1) * per_gpu - n = per_gpu - - conds_slice = conds[start:end].to(device) - - # Build desired conditions for this worker - linspaces = [ - th.linspace(conds_slice[:, i].min(), conds_slice[:, i].max(), n, device=device) for i in range(conds_slice.shape[1]) - ] - - desired_conds = th.stack(linspaces, dim=1).reshape(n, conds.shape[1], 1, 1) - - # Allocate output batch for this worker - data = th.zeros((n, dim, *design_shape), device=device) - - # --------------------------------------------------------------------- - # Autoregressive sampling loop (runs independently on each GPU) - # --------------------------------------------------------------------- - for i in range(design_shape[0]): - for j in range(design_shape[1]): - out = model(data, desired_conds) - out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - data[:, :, i, j] = out_sample[:, :, i, j] - - # Store CPU results to parent process - return_dict[rank] = (data.cpu(), desired_conds.cpu()) - - - # ------------------------------------------------------------------------- - # Launcher: runs once in the main process - # ------------------------------------------------------------------------- - def sample_designs_multigpu(model: PixelCNNpp, design_shape, dim, conds, n_designs=25, ngpus=None): - - if ngpus is None: - ngpus = th.cuda.device_count() - - assert n_designs % ngpus == 0, "n_designs must divide evenly across GPUs" - print(f"Sampling {n_designs} designs across {ngpus} GPUs...") - # Share model weights with workers - model_state = model.state_dict() - - mp.set_start_method("spawn", force=True) - manager = mp.Manager() - return_dict = manager.dict() - - mp.spawn( - worker_proc, - args=(ngpus, model_state, design_shape, dim, - conds, n_designs, return_dict), - nprocs=ngpus, - join=True - ) - - # --------------------------------------------------------------------- - # Collect results from all GPUs in order - # --------------------------------------------------------------------- - all_data = [] - all_conds = [] - - for rank in range(ngpus): - data, dc = return_dict[rank] - all_data.append(data) - all_conds.append(dc) - - all_data = th.cat(all_data, dim=0) - all_conds = th.cat(all_conds, dim=0) - - return all_data, all_conds - - # old function without parallelization @th.no_grad() def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: """Samples n_designs designs using dataset conditions.""" @@ -711,6 +613,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i model.train() for i, data in enumerate(dataloader): designs = data[0].unsqueeze(dim=1) # add channel dim + designs_rescaled = designs * 2. - 1. # rescale to [-1, 1] #print(designs.shape) #print(data[1:]) @@ -720,9 +623,9 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i #print(conds) batch_start_time = time.time() - out = model(designs, conds) + out = model(designs_rescaled, conds) # Compute loss - loss = loss_operator(designs, out) + loss = loss_operator(designs_rescaled, out) optimizer.zero_grad() # Backpropagation loss.backward() @@ -749,8 +652,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) - # designs, desired_conds = sample_designs_multigpu(model, design_shape, dim=1, conds=conds, n_designs=25, ngpus=None) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=1) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing @@ -758,7 +660,8 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i # Plot the image created by each output for j, tensor in enumerate(designs): - img = tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates + tensor_rescaled = (tensor + 1.) / 2. # rescale to [0, 1] + img = tensor_rescaled.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates #print(f"img shape: {img.shape}") dc = desired_conds[j].cpu().squeeze() # Extract design conditions #print(f"dc shape: {dc.shape}") From eeb181253a14fdd9baa29121574fe365400a6820 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 8 Dec 2025 14:00:50 +0100 Subject: [PATCH 18/31] sample interval 600 --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index ad4f780..19c6664 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -44,7 +44,7 @@ class Args: # Algorithm specific n_epochs: int = 2 """number of epochs of training""" - sample_interval: int = 2000 + sample_interval: int = 600 """interval between image samples""" batch_size: int = 8 """size of the batches""" @@ -652,7 +652,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=1) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From 811167a6830c65be763ba7d1a72c4252a24be29b Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 08:52:09 +0100 Subject: [PATCH 19/31] updated evaluate_pixel_cnn_pp_2d.py --- engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 63983f1..8be3db2 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -25,7 +25,7 @@ class Args: """Random seed to run.""" wandb_project: str = "engiopt" """Wandb project name.""" - wandb_entity: str = "jstehlin-eth-z-rich" #| None = None + wandb_entity: str | None = None """Wandb entity name.""" n_samples: int = 50 """Number of generated samples per seed.""" @@ -93,7 +93,7 @@ def __init__(self): raise RunRetrievalError artifact_dir = artifact.download() - ckpt_path = os.path.join(artifact_dir, "model_epoch400.pth") # change model.pth if necessary + ckpt_path = os.path.join(artifact_dir, "model.pth") # change model.pth if necessary ckpt = th.load(ckpt_path, map_location=device) # or th.device(device) From e3081c33e578b94126d7f16c93c970fd064f5818 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 08:52:44 +0100 Subject: [PATCH 20/31] added some debugging lines --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 19c6664..e432465 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -54,9 +54,9 @@ class Args: """decay of first order momentum of gradient""" b2: float = 0.9995 """decay of first order momentum of gradient""" - nr_resnet: int = 5 + nr_resnet: int = 3 """Number of residual blocks per stage of the model.""" - nr_filters: int = 120 + nr_filters: int = 40 """Number of filters to use across the model. Higher = larger model.""" nr_logistic_mix: int = 10 """Number of logistic components in the mixture. Higher = more flexible model.""" @@ -453,7 +453,7 @@ def discretized_mix_logistic_loss(x, l): cdf_delta = cdf_plus - cdf_min # probability for all other cases mid_in = inv_stdv * centered_x # log probability in the center of the bin, to be used in extreme cases - # (not actually used in this code) + # (likely not used in this code) log_pdf_mid = mid_in - log_scales - 2.*F.softplus(mid_in) # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen here) @@ -473,9 +473,11 @@ def discretized_mix_logistic_loss(x, l): def to_one_hot(tensor, n, fill_with=1.): - # we perform one hot encore with respect to the last axis + # we perform one hot encode with respect to the last axis one_hot = th.zeros((*tensor.size(), n), device=tensor.device) + print(f"one_hot.shape: {one_hot.shape}") one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + print(f"one_hot.shape after scatter: {one_hot.shape}\none_hot: {one_hot}") return one_hot @@ -491,11 +493,17 @@ def sample_from_discretized_mix_logistic(l, nr_mix): # sample mixture indicator from softmax temp = th.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5) + print(f"temp.size: {temp.size()}\ntemp: {temp}") + print(f"-th.log(- th.log(temp)).shape: {th.log(- th.log(temp)).shape}\n-th.log(- th.log(temp)): {- th.log(- th.log(temp))}") temp = logit_probs.detach() - th.log(- th.log(temp)) + print(f"temp.size: {temp.size()}\ntemp: {temp}") _, argmax = temp.max(dim=3) - + print(f"argmax.shape: {argmax.shape}\nargmax: {argmax}") one_hot = to_one_hot(argmax, nr_mix) + print(f"one_hot.shape: {one_hot.shape}\none_hot: {one_hot}") sel = one_hot.view([*xs[:-1], 1, nr_mix]) + print(f"sel.shape: {sel.shape}\nsel: {sel}") + time.sleep(100000) # select logistic parameters means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) log_scales = th.clamp(th.sum( @@ -652,7 +660,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=1) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From c5a143a9d59478f5c43f17433846b6eaea66bdac Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 10:11:08 +0100 Subject: [PATCH 21/31] added sampling in batches for pixel_cnn_pp_2d --- .../evaluate_pixel_cnn_pp_2d.py | 60 +++++++++---- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 85 ++++++++++++++----- 2 files changed, 107 insertions(+), 38 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 8be3db2..1ed8468 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -23,6 +23,8 @@ class Args: """Problem identifier.""" seed: int = 1 """Random seed to run.""" + sampling_batch_size: int = 10 + """Batch size to use during sampling.""" wandb_project: str = "engiopt" """Wandb project name.""" wandb_entity: str | None = None @@ -66,13 +68,6 @@ class Args: # -------------------------------------------------------- # adapt to PixelCNN++ input shape requirements conditions_tensor = conditions_tensor.unsqueeze(-1).unsqueeze(-1) - # print(f"Conditions tensor shape: {conditions_tensor.shape}") - # print(conditions_tensor) - # print(f"Sampled conditions shape: {sampled_conditions.shape}") - # print(sampled_conditions) - # print(f"Sampled designs shape: {sampled_designs_np.shape}") - # print(sampled_designs_np) - # print(f"Selected indices: {selected_indices}") design_shape = (problem.design_space.shape[0], problem.design_space.shape[1]) ### Set Up PixelCNN++ Model ### @@ -112,18 +107,47 @@ def __init__(self): model.eval() # Set to evaluation mode model.to(device) + + batch_size = args.sampling_batch_size + + all_batches: list[th.Tensor] = [] + + for start in range(0, args.n_samples, batch_size): + end = min(args.n_samples, start + batch_size) + b = end - start + + # prepare batch-local tensors on the same device as the model + batch_conds = conditions_tensor[start:end] + data = th.zeros((b, 1, *design_shape), device=device) + + # Autoregressive pixel sampling for this batch + for i in range(design_shape[0]): + for j in range(design_shape[1]): + out = model(data, batch_conds) + out_sample = sample_from_discretized_mix_logistic(out, run.config["nr_logistic_mix"]) + data[:, :, i, j] = out_sample.data[:, :, i, j] + + # move completed batch to CPU to free GPU memory and store + all_batches.append(data.cpu()) + + # concatenate all batches on CPU and return desired_conds on CPU as well + gen_designs = th.cat(all_batches, dim=0) + + print(gen_designs.shape) + + # old sampling # input - gen_designs = th.zeros((args.n_samples, 1, *design_shape), device=device) - - # Generate a batch of designs - # Autoregressive pixel sampling: iterate spatial positions and condition on - # previously sampled pixels and the desired_conds. - for i in range(design_shape[0]): - for j in range(design_shape[1]): - out = model(gen_designs, conditions_tensor) - out_sample = sample_from_discretized_mix_logistic(out, run.config["nr_logistic_mix"]) - # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - gen_designs[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + # gen_designs = th.zeros((args.n_samples, 1, *design_shape), device=device) + + # # Generate a batch of designs + # # Autoregressive pixel sampling: iterate spatial positions and condition on + # # previously sampled pixels and the desired_conds. + # for i in range(design_shape[0]): + # for j in range(design_shape[1]): + # out = model(gen_designs, conditions_tensor) + # out_sample = sample_from_discretized_mix_logistic(out, run.config["nr_logistic_mix"]) + # # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) + # gen_designs[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] gen_designs_np = gen_designs.detach().cpu().numpy() diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index e432465..8407050 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -48,6 +48,8 @@ class Args: """interval between image samples""" batch_size: int = 8 """size of the batches""" + sampling_batch_size: int = 10 + """Batch size to use during sampling.""" lr: float = 0.001 """learning rate""" b1: float = 0.95 @@ -581,36 +583,79 @@ def sample_from_discretized_mix_logistic(l, nr_mix): optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) # add other args if necessary @th.no_grad() - def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: + def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25, sampling_batch_size: int = 10) -> tuple[th.Tensor, th.Tensor]: """Samples n_designs designs using dataset conditions.""" model.eval() device = next(model.parameters()).device - + # Build the full list of requested condition combinations (on the model device) linspaces = [ th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) ] desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) - # Prepare an empty image batch on the same device as the model - data = th.zeros((n_designs, dim, *design_shape), device=device) + # If n_designs is large, sample in smaller batches to reduce GPU memory use. + # The default behavior previously sampled all at once; we keep that by + # allowing the caller to set `batch_size`. If batch_size >= n_designs we + # behave exactly as before. + batch_size = sampling_batch_size + + all_batches: list[th.Tensor] = [] + + for start in range(0, n_designs, batch_size): + end = min(n_designs, start + batch_size) + b = end - start + + # prepare batch-local tensors on the same device as the model + batch_conds = desired_conds[start:end] + data = th.zeros((b, dim, *design_shape), device=device) + + # Autoregressive pixel sampling for this batch + for i in range(design_shape[0]): + for j in range(design_shape[1]): + out = model(data, batch_conds) + out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + data[:, :, i, j] = out_sample.data[:, :, i, j] + + # move completed batch to CPU to free GPU memory and store + all_batches.append(data.cpu()) + + # concatenate all batches on CPU and return desired_conds on CPU as well + data_all = th.cat(all_batches, dim=0) + return data_all, desired_conds.cpu() + + # original version without batching + # @th.no_grad() + # def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: + # """Samples n_designs designs using dataset conditions.""" + # model.eval() + # device = next(model.parameters()).device + + # linspaces = [ + # th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) + # ] + + # desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) + + # # Prepare an empty image batch on the same device as the model + # data = th.zeros((n_designs, dim, *design_shape), device=device) - # Autoregressive pixel sampling: iterate spatial positions and condition on - # previously sampled pixels and the desired_conds. - for i in range(design_shape[0]): - for j in range(design_shape[1]): - # print(f"Sampling pixel ({i}, {j})") - out = model(data, desired_conds) - # print(f"out shape: {out.shape}") - out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - # print(f"out_sample shape: {out_sample.shape}") - # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] - #print(f"Completed row {i+1}/{design_shape[0]}") + # # Autoregressive pixel sampling: iterate spatial positions and condition on + # # previously sampled pixels and the desired_conds. + # for i in range(design_shape[0]): + # for j in range(design_shape[1]): + # # print(f"Sampling pixel ({i}, {j})") + # out = model(data, desired_conds) + # # print(f"out shape: {out.shape}") + # out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) + # # print(f"out_sample shape: {out_sample.shape}") + # # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) + # data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] + # #print(f"Completed row {i+1}/{design_shape[0]}") - #print(f"final data shape: {data.shape}") - #print(f"desired_conds shape: {desired_conds.shape}") - return data, desired_conds + # #print(f"final data shape: {data.shape}") + # #print(f"desired_conds shape: {desired_conds.shape}") + # return data, desired_conds @@ -660,7 +705,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=1) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=1, sampling_batch_size=args.sampling_batch_size) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From b73b42059d772e6e743221ef6e65b953e29a5b40 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 11:24:33 +0100 Subject: [PATCH 22/31] removed comments and time.sleep() --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 8407050..73495d6 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -477,9 +477,9 @@ def discretized_mix_logistic_loss(x, l): def to_one_hot(tensor, n, fill_with=1.): # we perform one hot encode with respect to the last axis one_hot = th.zeros((*tensor.size(), n), device=tensor.device) - print(f"one_hot.shape: {one_hot.shape}") + #print(f"one_hot.shape: {one_hot.shape}") one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) - print(f"one_hot.shape after scatter: {one_hot.shape}\none_hot: {one_hot}") + #print(f"one_hot.shape after scatter: {one_hot.shape}\none_hot: {one_hot}") return one_hot @@ -495,17 +495,17 @@ def sample_from_discretized_mix_logistic(l, nr_mix): # sample mixture indicator from softmax temp = th.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5) - print(f"temp.size: {temp.size()}\ntemp: {temp}") - print(f"-th.log(- th.log(temp)).shape: {th.log(- th.log(temp)).shape}\n-th.log(- th.log(temp)): {- th.log(- th.log(temp))}") + #print(f"temp.size: {temp.size()}\ntemp: {temp}") + #print(f"-th.log(- th.log(temp)).shape: {th.log(- th.log(temp)).shape}\n-th.log(- th.log(temp)): {- th.log(- th.log(temp))}") temp = logit_probs.detach() - th.log(- th.log(temp)) - print(f"temp.size: {temp.size()}\ntemp: {temp}") + #print(f"temp.size: {temp.size()}\ntemp: {temp}") _, argmax = temp.max(dim=3) - print(f"argmax.shape: {argmax.shape}\nargmax: {argmax}") + #print(f"argmax.shape: {argmax.shape}\nargmax: {argmax}") one_hot = to_one_hot(argmax, nr_mix) - print(f"one_hot.shape: {one_hot.shape}\none_hot: {one_hot}") + #print(f"one_hot.shape: {one_hot.shape}\none_hot: {one_hot}") sel = one_hot.view([*xs[:-1], 1, nr_mix]) - print(f"sel.shape: {sel.shape}\nsel: {sel}") - time.sleep(100000) + #print(f"sel.shape: {sel.shape}\nsel: {sel}") + #time.sleep(100000) # select logistic parameters means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) log_scales = th.clamp(th.sum( From 93e34ec36d7d9dd4d3e022a90cc1ddbbbf6031fd Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 11:31:21 +0100 Subject: [PATCH 23/31] added info text --- engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 1ed8468..88c7d66 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -129,6 +129,7 @@ def __init__(self): # move completed batch to CPU to free GPU memory and store all_batches.append(data.cpu()) + print(f"Sampled batch {start} to {end} / {args.n_samples}") # concatenate all batches on CPU and return desired_conds on CPU as well gen_designs = th.cat(all_batches, dim=0) From c246690792bc9af96d29a1cb75d5d1e3d4a17383 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 11:35:36 +0100 Subject: [PATCH 24/31] another print statement --- engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 88c7d66..7ac905b 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -119,6 +119,7 @@ def __init__(self): # prepare batch-local tensors on the same device as the model batch_conds = conditions_tensor[start:end] data = th.zeros((b, 1, *design_shape), device=device) + print(data.shape) # Autoregressive pixel sampling for this batch for i in range(design_shape[0]): From e8fbf365a0994a3c345e93616587e98c0814dab6 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 10 Dec 2025 11:39:00 +0100 Subject: [PATCH 25/31] lowered sampling batch size to reduce memory consumption during evaluation --- engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 7ac905b..981abeb 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -23,7 +23,7 @@ class Args: """Problem identifier.""" seed: int = 1 """Random seed to run.""" - sampling_batch_size: int = 10 + sampling_batch_size: int = 5 """Batch size to use during sampling.""" wandb_project: str = "engiopt" """Wandb project name.""" From fd2bb14a76eb5ec07b05631f4fa3ba8ddedd9874 Mon Sep 17 00:00:00 2001 From: Jonas Date: Fri, 12 Dec 2025 15:47:07 +0100 Subject: [PATCH 26/31] nr of samples changed to 25 --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 73495d6..098f7f7 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -705,7 +705,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=1, sampling_batch_size=args.sampling_batch_size) + designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25, sampling_batch_size=args.sampling_batch_size) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From b2858def7ee6b8e5fe41676cd7c6092e2529a720 Mon Sep 17 00:00:00 2001 From: Jonas Date: Mon, 15 Dec 2025 16:21:55 +0100 Subject: [PATCH 27/31] evaluation preparation and removal of some debug prints --- .../evaluate_pixel_cnn_pp_2d.py | 16 +------ engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 47 ++++++------------- 2 files changed, 16 insertions(+), 47 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index 981abeb..cd3b0aa 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -89,7 +89,7 @@ def __init__(self): artifact_dir = artifact.download() ckpt_path = os.path.join(artifact_dir, "model.pth") # change model.pth if necessary - ckpt = th.load(ckpt_path, map_location=device) # or th.device(device) + ckpt = th.load(ckpt_path, map_location=device) # Build PixelCNN++ Model @@ -137,20 +137,6 @@ def __init__(self): print(gen_designs.shape) - # old sampling - # input - # gen_designs = th.zeros((args.n_samples, 1, *design_shape), device=device) - - # # Generate a batch of designs - # # Autoregressive pixel sampling: iterate spatial positions and condition on - # # previously sampled pixels and the desired_conds. - # for i in range(design_shape[0]): - # for j in range(design_shape[1]): - # out = model(gen_designs, conditions_tensor) - # out_sample = sample_from_discretized_mix_logistic(out, run.config["nr_logistic_mix"]) - # # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - # gen_designs[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] - gen_designs_np = gen_designs.detach().cpu().numpy() gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 098f7f7..3841305 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -46,7 +46,7 @@ class Args: """number of epochs of training""" sample_interval: int = 600 """interval between image samples""" - batch_size: int = 8 + batch_size: int = 4 """size of the batches""" sampling_batch_size: int = 10 """Batch size to use during sampling.""" @@ -56,9 +56,9 @@ class Args: """decay of first order momentum of gradient""" b2: float = 0.9995 """decay of first order momentum of gradient""" - nr_resnet: int = 3 + nr_resnet: int = 2 """Number of residual blocks per stage of the model.""" - nr_filters: int = 40 + nr_filters: int = 30 """Number of filters to use across the model. Higher = larger model.""" nr_logistic_mix: int = 10 """Number of logistic components in the mixture. Higher = more flexible model.""" @@ -417,30 +417,21 @@ def log_prob_from_logits(x): def discretized_mix_logistic_loss(x, l): """Log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval.""" - # Pytorch ordering + # Tensorflow ordering x = x.permute(0, 2, 3, 1) + print(f"x shape in loss: {x.shape}") l = l.permute(0, 2, 3, 1) - xs = list(x.shape) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) - ls = list(l.shape) # predicted distribution, e.g. (B,32,32,100) - # different for 3 channels: / 10 + print(f"l shape in loss: {l.shape}") + xs = list(x.shape) # true image (i.e. labels) to regress to, e.g. (B,width,height,3) + ls = list(l.shape) # predicted distribution, e.g. (B,width,height,30) nr_mix = int(ls[-1] / 3) # here and below: unpacking the params of the mixture of logistics logit_probs = l[:,:,:,:nr_mix] - # different for 3 channels: nr_mix * 3 #, coeff l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # 2 for mean, scale means = l[:,:,:,:,:nr_mix] log_scales = th.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) - # for 3 channels: - # coeffs = F.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) x = x.contiguous() zeros = th.zeros([*xs, nr_mix], device=x.device) x = x.unsqueeze(-1) + zeros - # for 3 channels: - # m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] - # * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix) - - # m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + - # coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix) - # means = th.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3) centered_x = x - means inv_stdv = th.exp(-log_scales) @@ -468,9 +459,10 @@ def discretized_mix_logistic_loss(x, l): # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value - # log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.where(cdf_delta > 1e-5, th.log(th.clamp(cdf_delta, min=1e-12)), log_pdf_mid - np.log(127.5)))) # noqa: PLR2004 + print(f"log_probs shape before sum: {log_probs.shape}") log_probs = th.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) + time.sleep(100000) return -th.sum(log_sum_exp(log_probs)) @@ -484,10 +476,10 @@ def to_one_hot(tensor, n, fill_with=1.): def sample_from_discretized_mix_logistic(l, nr_mix): - # Pytorch ordering + # Tensorflow ordering l = l.permute(0, 2, 3, 1) ls = list(l.shape) - xs = [*ls[:-1], 1] #[3] + xs = [*ls[:-1], 1] # unpack parameters logit_probs = l[:, :, :, :nr_mix] @@ -495,17 +487,11 @@ def sample_from_discretized_mix_logistic(l, nr_mix): # sample mixture indicator from softmax temp = th.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5) - #print(f"temp.size: {temp.size()}\ntemp: {temp}") - #print(f"-th.log(- th.log(temp)).shape: {th.log(- th.log(temp)).shape}\n-th.log(- th.log(temp)): {- th.log(- th.log(temp))}") temp = logit_probs.detach() - th.log(- th.log(temp)) - #print(f"temp.size: {temp.size()}\ntemp: {temp}") _, argmax = temp.max(dim=3) - #print(f"argmax.shape: {argmax.shape}\nargmax: {argmax}") one_hot = to_one_hot(argmax, nr_mix) - #print(f"one_hot.shape: {one_hot.shape}\none_hot: {one_hot}") sel = one_hot.view([*xs[:-1], 1, nr_mix]) - #print(f"sel.shape: {sel.shape}\nsel: {sel}") - #time.sleep(100000) + # select logistic parameters means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) log_scales = th.clamp(th.sum( @@ -524,7 +510,7 @@ def sample_from_discretized_mix_logistic(l, nr_mix): problem.reset(seed=args.seed) design_shape = problem.design_space.shape - #print(f"Design shape: {design_shape}") + conditions = problem.conditions_keys nr_conditions = len(conditions) @@ -548,7 +534,6 @@ def sample_from_discretized_mix_logistic(l, nr_mix): device = th.device("cuda") else: device = th.device("cpu") - #device = th.device("cpu") # Loss function loss_operator = discretized_mix_logistic_loss @@ -595,9 +580,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) # If n_designs is large, sample in smaller batches to reduce GPU memory use. - # The default behavior previously sampled all at once; we keep that by - # allowing the caller to set `batch_size`. If batch_size >= n_designs we - # behave exactly as before. + # If batch_size >= n_designs then there is only one batch. batch_size = sampling_batch_size all_batches: list[th.Tensor] = [] From 9497b4e5cc2041a0be494eb35ddd9b3b565c94a4 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 17 Dec 2025 10:08:55 +0100 Subject: [PATCH 28/31] Final version of PixelCNN++ implementation --- .../evaluate_pixel_cnn_pp_2d.py | 31 +- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 284 ++++++++---------- 2 files changed, 129 insertions(+), 186 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index cd3b0aa..c728026 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -1,4 +1,6 @@ -"""Evaluation of the PixelCNN++.""" +"""Evaluation of the PixelCNN++ model.""" +from __future__ import annotations + from dataclasses import dataclass import os @@ -57,7 +59,7 @@ class Args: else: device = th.device("cpu") - ### Set up testing conditions ### + # Set up testing conditions conditions_tensor, sampled_conditions, sampled_designs_np, selected_indices = sample_conditions( problem=problem, n_samples=args.n_samples, @@ -65,12 +67,11 @@ class Args: seed=seed, ) - # -------------------------------------------------------- # adapt to PixelCNN++ input shape requirements conditions_tensor = conditions_tensor.unsqueeze(-1).unsqueeze(-1) design_shape = (problem.design_space.shape[0], problem.design_space.shape[1]) - ### Set Up PixelCNN++ Model ### + # Set up PixelCNN++ model if args.wandb_entity is not None: artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}" else: @@ -88,11 +89,9 @@ def __init__(self): raise RunRetrievalError artifact_dir = artifact.download() - ckpt_path = os.path.join(artifact_dir, "model.pth") # change model.pth if necessary + ckpt_path = os.path.join(artifact_dir, "model.pth") ckpt = th.load(ckpt_path, map_location=device) - - # Build PixelCNN++ Model model = PixelCNNpp( nr_resnet=run.config["nr_resnet"], nr_filters=run.config["nr_filters"], @@ -104,22 +103,20 @@ def __init__(self): ) model.load_state_dict(ckpt["model"]) - model.eval() # Set to evaluation mode + model.eval() model.to(device) - batch_size = args.sampling_batch_size - + # Sample designs in batches all_batches: list[th.Tensor] = [] - for start in range(0, args.n_samples, batch_size): - end = min(args.n_samples, start + batch_size) + for start in range(0, args.n_samples, args.sampling_batch_size): + end = min(args.n_samples, start + args.sampling_batch_size) b = end - start # prepare batch-local tensors on the same device as the model batch_conds = conditions_tensor[start:end] data = th.zeros((b, 1, *design_shape), device=device) - print(data.shape) # Autoregressive pixel sampling for this batch for i in range(design_shape[0]): @@ -130,17 +127,13 @@ def __init__(self): # move completed batch to CPU to free GPU memory and store all_batches.append(data.cpu()) - print(f"Sampled batch {start} to {end} / {args.n_samples}") - # concatenate all batches on CPU and return desired_conds on CPU as well + # concatenate all batches on CPU gen_designs = th.cat(all_batches, dim=0) - print(gen_designs.shape) - gen_designs_np = gen_designs.detach().cpu().numpy() - gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) - # Clip to boundaries for running THIS IS PROBLEM DEPENDENT + gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) # remove channel dim gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0) # Compute metrics diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 3841305..f8f717e 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -1,8 +1,13 @@ """PixelCNN++ 2D model implementation. +Based on the original Tensorflow code of OpenAI: https://github.com/openai/pixel-cnn, +and the PyTorch implementation of Lucas Caccia: https://github.com/pclucas14/pixel-cnn-pp. + Provides the model classes, shifted convolutional blocks, and the discretized mixture of logistics loss used for training and sampling. """ +from __future__ import annotations + from dataclasses import dataclass import os import random @@ -13,7 +18,7 @@ import numpy as np import torch as th from torch import nn -import torch.nn.functional as F +import torch.nn.functional as f from torch.nn.utils.parametrizations import weight_norm import tqdm import tyro @@ -42,11 +47,11 @@ class Args: """Saves the model to disk.""" # Algorithm specific - n_epochs: int = 2 + n_epochs: int = 100 """number of epochs of training""" sample_interval: int = 600 """interval between image samples""" - batch_size: int = 4 + batch_size: int = 8 """size of the batches""" sampling_batch_size: int = 10 """Batch size to use during sampling.""" @@ -56,31 +61,34 @@ class Args: """decay of first order momentum of gradient""" b2: float = 0.9995 """decay of first order momentum of gradient""" - nr_resnet: int = 2 + nr_resnet: int = 5 """Number of residual blocks per stage of the model.""" - nr_filters: int = 30 + nr_filters: int = 160 """Number of filters to use across the model. Higher = larger model.""" nr_logistic_mix: int = 10 """Number of logistic components in the mixture. Higher = more flexible model.""" resnet_nonlinearity: str = "concat_elu" - """Nonlinearity to use in the ResNet blocks. One of 'concat_elu', 'elu', 'relu'.""" + """Nonlinearity to use in the ResNet blocks.""" dropout_p: float = 0.5 """Dropout probability.""" -def concat_elu(x): + +def concat_elu(x: th.Tensor) -> th.Tensor: """Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU.""" - # Pytorch ordering + # PyTorch ordering axis = len(x.size()) - 3 - return F.elu(th.cat([x, -x], dim=axis)) + return f.elu(th.cat([x, -x], dim=axis)) + -class nin(nn.Module): - def __init__(self, nr_filters_in, nr_filters_out): + +class Nin(nn.Module): + def __init__(self, nr_filters_in: int, nr_filters_out: int): super().__init__() self.lin_a = weight_norm(nn.Linear(nr_filters_in, nr_filters_out)) self.nr_filters_out = nr_filters_out - def forward(self, x): + def forward(self, x: th.Tensor) -> th.Tensor: x = x.permute(0, 2, 3, 1) # BCHW -> BHWC xs = list(x.shape) x = x.reshape(-1, xs[3]) # -> [B*H*W, C] @@ -88,9 +96,12 @@ def forward(self, x): out = out.view(xs[0], xs[1], xs[2], self.nr_filters_out) return out.permute(0, 3, 1, 2) # BHWC -> BCHW + + class GatedResnet(nn.Module): - def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_connection=0, dropout_p=0.5, nr_conditions=0): + def __init__(self, nr_filters: int, conv_op: nn.Module, resnet_nonlinearity: callable = concat_elu, skip_connection: int = 0, dropout_p: float = 0.5, nr_conditions: int = 0): # noqa: PLR0913 super().__init__() + self.skip_connection = skip_connection self.resnet_nonlinearity = resnet_nonlinearity @@ -102,7 +113,7 @@ def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_con self.conv_input = conv_op(self.filter_doubling * nr_filters, nr_filters) if skip_connection != 0: - self.nin_skip = nin(self.filter_doubling * skip_connection * nr_filters, nr_filters) + self.nin_skip = Nin(self.filter_doubling * skip_connection * nr_filters, nr_filters) self.dropout = nn.Dropout2d(dropout_p) self.conv_out = conv_op(self.filter_doubling * nr_filters, 2 * nr_filters) # output has to be doubled for gating @@ -110,80 +121,73 @@ def __init__(self, nr_filters, conv_op, resnet_nonlinearity=concat_elu, skip_con self.h_lin = nn.Linear(nr_conditions, 2 * nr_filters) - def forward(self, x, a=None, h=None): + def forward(self, x: th.Tensor, a: th.Tensor = None, h: th.Tensor = None) -> th.Tensor: c1 = self.conv_input(self.resnet_nonlinearity(x)) - # print(f"c1 shape: {c1.shape}") if a is not None : - # print(f"a shape: {a.shape}") c1 += self.nin_skip(self.resnet_nonlinearity(a)) c1 = self.resnet_nonlinearity(c1) c1 = self.dropout(c1) c2 = self.conv_out(c1) if h is not None: - # in forward, when `h` is [B, nr_conditions, 1, 1] + # `h` is [B, nr_conditions, 1, 1] h_flat = h.view(h.size(0), -1) # [B, nr_conditions] h_proj = self.h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] c2 += h_proj a, b = th.chunk(c2, 2, dim=1) - c3 = a * F.sigmoid(b) + c3 = a * f.sigmoid(b) return x + c3 -def downShift(x, pad): + +def down_shift(x: th.Tensor, pad: nn.Module) -> th.Tensor: + """Down shift the input tensor by one row.""" xs = list(x.shape) x = x[:, :, :xs[2] - 1, :] - x = pad(x) - return x + return pad(x) + + -def downRightShift(x, pad): +def down_right_shift(x: th.Tensor, pad: nn.Module) -> th.Tensor: + """Down right shift the input tensor by one row and one column.""" xs = list(x.shape) x = x[:, :, :, :xs[3] - 1] - x = pad(x) - return x + return pad(x) + -class DownShiftedConv2d(nn.Module): - def __init__(self, - nr_filters_in, - nr_filters_out, - filter_size=(2,3), - stride=(1,1), - shift_output_down=False): +class DownShiftedConv2d(nn.Module): + def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,3), stride: tuple = (1,1), shift_output_down: bool = False): # noqa: FBT001, FBT002 super().__init__() + self.pad = nn.ZeroPad2d((int((filter_size[1] - 1) / 2), int((filter_size[1] - 1) / 2), filter_size[0]-1, 0)) # padding left, right, top, bottom self.conv = weight_norm(nn.Conv2d(nr_filters_in, nr_filters_out, filter_size, stride)) self.shift_output_down = shift_output_down - self.down_shift = downShift + self.down_shift = down_shift self.down_shift_pad = nn.ZeroPad2d((0,0,1,0)) - def forward(self, x): + def forward(self, x : th.Tensor) -> th.Tensor: x = self.pad(x) x = self.conv(x) if self.shift_output_down: x = self.down_shift(x, pad=self.down_shift_pad) return x -class DownShiftedDeconv2d(nn.Module): - def __init__(self, - nr_filters_in, - nr_filters_out, - filter_size=(2,3), - stride=(1,1), - output_padding=(0,1)): + +class DownShiftedDeconv2d(nn.Module): + def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,3), stride: tuple = (1,1), output_padding: tuple = (0,1)): super().__init__() + self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding)) self.filter_size = filter_size self.stride = stride - def forward(self, x, output_padding=None): - # Allow callers to pass a dynamic `output_padding` (some callers compute - # this to handle odd/even spatial sizes). If not provided, use the - # configured ConvTranspose2d module directly. + def forward(self, x: th.Tensor, output_padding: tuple | None = None) -> th.Tensor: + # Use output_padding if needed & provided to handle odd/even spatial sizes. if output_padding is None: x = self.deconv(x) else: - x = F.conv_transpose2d( + x = f.conv_transpose2d( x, self.deconv.weight, self.deconv.bias, @@ -196,46 +200,41 @@ def forward(self, x, output_padding=None): xs = list(x.shape) return x[:, :, :(xs[2] - self.filter_size[0] + 1), int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))] -class DownRightShiftedConv2d(nn.Module): - def __init__(self, - nr_filters_in, - nr_filters_out, - filter_size=(2,2), - stride=(1,1), - shift_output_right_down=False): + +class DownRightShiftedConv2d(nn.Module): + def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,2), stride: tuple = (1,1), shift_output_right_down: bool = False): # noqa: FBT001, FBT002 super().__init__() + self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0]-1, 0)) # padding left, right, top, bottom self.conv = weight_norm(nn.Conv2d(nr_filters_in, nr_filters_out, filter_size, stride)) self.shift_output_right_down = shift_output_right_down - self.down_right_shift = downRightShift + self.down_right_shift = down_right_shift self.down_right_shift_pad = nn.ZeroPad2d((1,0,0,0)) - def forward(self, x): + def forward(self, x: th.Tensor) -> th.Tensor: x = self.pad(x) x = self.conv(x) if self.shift_output_right_down: x = self.down_right_shift(x, pad=self.down_right_shift_pad) return x -class DownRightShiftedDeconv2d(nn.Module): - def __init__(self, - nr_filters_in, - nr_filters_out, - filter_size=(2,2), - stride=(1,1), - output_padding=(1,0)): + +class DownRightShiftedDeconv2d(nn.Module): + def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,2), stride: tuple = (1,1), output_padding: tuple = (1,0)): super().__init__() + self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding)) self.filter_size = filter_size self.stride = stride - def forward(self, x, output_padding=None): + def forward(self, x: th.Tensor, output_padding: tuple | None = None) -> th.Tensor: + # Use output_padding if needed & provided to handle odd/even spatial sizes. if output_padding is None: x = self.deconv(x) else: - x = F.conv_transpose2d( + x = f.conv_transpose2d( x, self.deconv.weight, self.deconv.bias, @@ -249,26 +248,19 @@ def forward(self, x, output_padding=None): return x[:, :, :(xs[2] - self.filter_size[0] + 1), :(xs[3] - self.filter_size[1] + 1)] -# IMPLEMENT PIXELCNN++ HERE -class PixelCNNpp(nn.Module): - def __init__(self, - nr_resnet: int, - nr_filters: int, - nr_logistic_mix: int, - resnet_nonlinearity: str, - dropout_p: float, - input_channels: int = 1, - nr_conditions: int = 0): +class PixelCNNpp(nn.Module): + def __init__(self, nr_resnet: int, nr_filters: int, nr_logistic_mix: int, resnet_nonlinearity: str, dropout_p: float, input_channels: int = 1, nr_conditions: int = 0): # noqa: PLR0913 super().__init__() + if resnet_nonlinearity == "concat_elu" : self.resnet_nonlinearity = concat_elu elif resnet_nonlinearity == "elu" : - self.resnet_nonlinearity = F.elu + self.resnet_nonlinearity = f.elu elif resnet_nonlinearity == "relu" : - self.resnet_nonlinearity = F.relu + self.resnet_nonlinearity = f.relu else: - raise Exception("Only concat elu, elu and relu are supported as resnet nonlinearity.") # noqa: TRY002 + raise Exception("Only concat_elu, elu and relu are supported as resnet_nonlinearity.") # noqa: TRY002 self.nr_resnet = nr_resnet self.nr_filters = nr_filters @@ -313,10 +305,11 @@ def __init__(self, self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) - num_mix = 3 if self.input_channels == 1 else 10 - self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix) + num_mix = 3 + self.nin_out = Nin(nr_filters, num_mix * nr_logistic_mix) + - def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: + def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 xs = list(x.shape) padding = th.ones(xs[0], 1, xs[2], xs[3], device=x.device) x = th.cat((x, padding), dim=1) # add extra channel for padding @@ -335,7 +328,6 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: ul_list.append(self.downsize_ul_1(ul_list[-1])) # Handle images with odd height/width - # print(f"Before first downsize: u.shape[2]: {u_list[-1].shape[2]}, u.shape[3]: {u_list[-1].shape[3]}, u.shape[2]: {u_list[-2].shape[2]}, u.shape[3]: {u_list[-2].shape[3]}") pad_height = 1 pad_width = 1 if u_list[-2].shape[2] % u_list[-1].shape[2] != 0: @@ -373,33 +365,31 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: u = self.gated_resnet_block_u_down_1[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_1[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) - #print(f"After first down pass: u shape: {u.shape}, ul shape: {ul.shape}") u = self.upsize_u_1(u, output_padding=output_padding_list[-1]) ul = self.upsize_ul_1(ul, output_padding=output_padding_list[-1]) - # print(f"After first upsize: u shape: {u.shape}, ul shape: {ul.shape}") for i in range(self.nr_resnet + 1): u = self.gated_resnet_block_u_down_2[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_2[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) - # print(f"After second down pass: u shape: {u.shape}, ul shape: {ul.shape}") u = self.upsize_u_2(u, output_padding=output_padding_list[-2]) ul = self.upsize_ul_2(ul, output_padding=output_padding_list[-2]) - # print(f"After second upsize: u shape: {u.shape}, ul shape: {ul.shape}") for i in range(self.nr_resnet + 1): u = self.gated_resnet_block_u_down_3[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_3[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) - x_out = self.nin_out(F.elu(ul)) + x_out = self.nin_out(f.elu(ul)) assert len(u_list) == 0 assert len(ul_list) == 0 return x_out -def log_sum_exp(x): + + +def log_sum_exp(x: th.Tensor) -> th.Tensor: """Numerically stable log_sum_exp implementation that prevents overflow.""" # TF ordering axis = len(x.size()) - 1 @@ -407,7 +397,9 @@ def log_sum_exp(x): m2, _ = th.max(x, dim=axis, keepdim=True) return m + th.log(th.sum(th.exp(x - m2), dim=axis)) -def log_prob_from_logits(x): + + +def log_prob_from_logits(x: th.Tensor) -> th.Tensor: """Numerically stable log_softmax implementation that prevents overflow.""" # TF ordering axis = len(x.size()) - 1 @@ -415,17 +407,17 @@ def log_prob_from_logits(x): return x - m - th.log(th.sum(th.exp(x - m), dim=axis, keepdim=True)) -def discretized_mix_logistic_loss(x, l): +def discretized_mix_logistic_loss(x: th.Tensor, l: th.Tensor) -> th.Tensor: """Log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval.""" # Tensorflow ordering x = x.permute(0, 2, 3, 1) - print(f"x shape in loss: {x.shape}") l = l.permute(0, 2, 3, 1) - print(f"l shape in loss: {l.shape}") - xs = list(x.shape) # true image (i.e. labels) to regress to, e.g. (B,width,height,3) - ls = list(l.shape) # predicted distribution, e.g. (B,width,height,30) - nr_mix = int(ls[-1] / 3) # here and below: unpacking the params of the mixture of logistics - logit_probs = l[:,:,:,:nr_mix] + xs = list(x.shape) # true image (i.e. labels) + ls = list(l.shape) # predicted distribution + + # unpacking the params of the mixture of logistics + nr_mix = int(ls[-1] / 3) + logit_probs = l[:,:,:,:nr_mix] # mixture probabilities l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # 2 for mean, scale means = l[:,:,:,:,:nr_mix] log_scales = th.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) @@ -436,53 +428,49 @@ def discretized_mix_logistic_loss(x, l): centered_x = x - means inv_stdv = th.exp(-log_scales) plus_in = inv_stdv * (centered_x + 1./255.) - cdf_plus = F.sigmoid(plus_in) + cdf_plus = f.sigmoid(plus_in) min_in = inv_stdv * (centered_x - 1./255.) - cdf_min = F.sigmoid(min_in) - # log probability for edge case of 0 (before scaling) - log_cdf_plus = plus_in - F.softplus(plus_in) - # log probability for edge case of 255 (before scaling) - log_one_minus_cdf_min = -F.softplus(min_in) - cdf_delta = cdf_plus - cdf_min # probability for all other cases - mid_in = inv_stdv * centered_x + cdf_min = f.sigmoid(min_in) + # log probability for edge case of 0 + log_cdf_plus = plus_in - f.softplus(plus_in) + # log probability for edge case of 255 + log_one_minus_cdf_min = -f.softplus(min_in) + # probability for all other cases + cdf_delta = cdf_plus - cdf_min # log probability in the center of the bin, to be used in extreme cases - # (likely not used in this code) - log_pdf_mid = mid_in - log_scales - 2.*F.softplus(mid_in) - - # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen here) + mid_in = inv_stdv * centered_x + log_pdf_mid = mid_in - log_scales - 2.*f.softplus(mid_in) - # this is what is really done, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() - # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) + # select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen here) - # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) - # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs - # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue - # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value + # this is what is really done, but using the robust version below for extreme cases + # log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta))) # noqa: ERA001 + # robust version, that still works if the probability is below 1e-5 + # approximation used based on the assumption that the log-density is constant in the bin of the observed sub-pixel value log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.where(cdf_delta > 1e-5, th.log(th.clamp(cdf_delta, min=1e-12)), log_pdf_mid - np.log(127.5)))) # noqa: PLR2004 - print(f"log_probs shape before sum: {log_probs.shape}") log_probs = th.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) - time.sleep(100000) return -th.sum(log_sum_exp(log_probs)) -def to_one_hot(tensor, n, fill_with=1.): - # we perform one hot encode with respect to the last axis + +def to_one_hot(tensor: th.Tensor, n: int, fill_with: float = 1.) -> th.Tensor: + """One hot encoding with respect to the last axis.""" one_hot = th.zeros((*tensor.size(), n), device=tensor.device) - #print(f"one_hot.shape: {one_hot.shape}") one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) - #print(f"one_hot.shape after scatter: {one_hot.shape}\none_hot: {one_hot}") return one_hot -def sample_from_discretized_mix_logistic(l, nr_mix): + +def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor: + """Sample from a discretized mixture of logistic distributions.""" # Tensorflow ordering l = l.permute(0, 2, 3, 1) ls = list(l.shape) xs = [*ls[:-1], 1] # unpack parameters - logit_probs = l[:, :, :, :nr_mix] + logit_probs = l[:, :, :, :nr_mix] # mixture probabilities l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # for mean, scale # sample mixture indicator from softmax @@ -499,8 +487,8 @@ def sample_from_discretized_mix_logistic(l, nr_mix): u = th.empty_like(means).uniform_(1e-5, 1. - 1e-5) x = means + th.exp(log_scales) * (th.log(u) - th.log(1. - u)) x0 = th.clamp(th.clamp(x[:, :, :, 0], min=-1.), max=1.) - out = x0.unsqueeze(1) - return out + return x0.unsqueeze(1) + if __name__ == "__main__": @@ -514,7 +502,6 @@ def sample_from_discretized_mix_logistic(l, nr_mix): conditions = problem.conditions_keys nr_conditions = len(conditions) - # Logging run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" if args.track: @@ -550,13 +537,12 @@ def sample_from_discretized_mix_logistic(l, nr_mix): ) model.to(device) - # loss.to(device) # Configure data loader training_ds = problem.dataset.with_format("torch", device=device)["train"] condition_tensors = [training_ds[key][:] for key in problem.conditions_keys] - training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:], *condition_tensors) # .flatten(1) ? + training_ds = th.utils.data.TensorDataset(training_ds["optimal_design"][:], *condition_tensors) dataloader = th.utils.data.DataLoader( training_ds, @@ -565,7 +551,9 @@ def sample_from_discretized_mix_logistic(l, nr_mix): ) # Optimzer - optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) # add other args if necessary + optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) + + @th.no_grad() def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25, sampling_batch_size: int = 10) -> tuple[th.Tensor, th.Tensor]: @@ -607,39 +595,6 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i data_all = th.cat(all_batches, dim=0) return data_all, desired_conds.cpu() - # original version without batching - # @th.no_grad() - # def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25) -> tuple[th.Tensor, th.Tensor]: - # """Samples n_designs designs using dataset conditions.""" - # model.eval() - # device = next(model.parameters()).device - - # linspaces = [ - # th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) - # ] - - # desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) - - # # Prepare an empty image batch on the same device as the model - # data = th.zeros((n_designs, dim, *design_shape), device=device) - - # # Autoregressive pixel sampling: iterate spatial positions and condition on - # # previously sampled pixels and the desired_conds. - # for i in range(design_shape[0]): - # for j in range(design_shape[1]): - # # print(f"Sampling pixel ({i}, {j})") - # out = model(data, desired_conds) - # # print(f"out shape: {out.shape}") - # out_sample = sample_from_discretized_mix_logistic(out, args.nr_logistic_mix) - # # print(f"out_sample shape: {out_sample.shape}") - # # `out_sample` has shape [B, 1, H, W]; copy the sampled value for (i,j) - # data[:, :, i, j] = out_sample.data[:, :, i, j] #out_sample[:, :, i, j] # out_sample.data[:, :, i, j] - # #print(f"Completed row {i+1}/{design_shape[0]}") - - # #print(f"final data shape: {data.shape}") - # #print(f"desired_conds shape: {desired_conds.shape}") - # return data, desired_conds - # ---------- @@ -651,12 +606,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i designs = data[0].unsqueeze(dim=1) # add channel dim designs_rescaled = designs * 2. - 1. # rescale to [-1, 1] - #print(designs.shape) - #print(data[1:]) - conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) - # print(f"conds.shape: {conds.shape}") - #print(conds) batch_start_time = time.time() out = model(designs_rescaled, conds) @@ -698,9 +648,9 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i for j, tensor in enumerate(designs): tensor_rescaled = (tensor + 1.) / 2. # rescale to [0, 1] img = tensor_rescaled.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates - #print(f"img shape: {img.shape}") + dc = desired_conds[j].cpu().squeeze() # Extract design conditions - #print(f"dc shape: {dc.shape}") + axes[j].imshow(img) # image plot title = [(problem.conditions_keys[i][0], f"{dc[i]:.2f}") for i in range(nr_conditions)] title_string = "\n ".join(f"{condition}: {value}" for condition, value in title) From f29cca1e7175c01da98c79b1b941ff208299c1a4 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 17 Dec 2025 13:23:09 +0100 Subject: [PATCH 29/31] Fixed typos --- engiopt/cgan_cnn_2d/cgan_cnn_2d.py | 2 +- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 40 ++++++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/engiopt/cgan_cnn_2d/cgan_cnn_2d.py b/engiopt/cgan_cnn_2d/cgan_cnn_2d.py index 23119df..80cf116 100644 --- a/engiopt/cgan_cnn_2d/cgan_cnn_2d.py +++ b/engiopt/cgan_cnn_2d/cgan_cnn_2d.py @@ -44,7 +44,7 @@ class Args: """Saves the model to disk.""" # Algorithm specific - n_epochs: int = 2 + n_epochs: int = 200 """number of epochs of training""" batch_size: int = 32 """size of the batches""" diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index f8f717e..6de761e 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -60,7 +60,7 @@ class Args: b1: float = 0.95 """decay of first order momentum of gradient""" b2: float = 0.9995 - """decay of first order momentum of gradient""" + """decay of second order momentum of gradient""" nr_resnet: int = 5 """Number of residual blocks per stage of the model.""" nr_filters: int = 160 @@ -123,7 +123,7 @@ def __init__(self, nr_filters: int, conv_op: nn.Module, resnet_nonlinearity: cal def forward(self, x: th.Tensor, a: th.Tensor = None, h: th.Tensor = None) -> th.Tensor: c1 = self.conv_input(self.resnet_nonlinearity(x)) - if a is not None : + if a is not None: c1 += self.nin_skip(self.resnet_nonlinearity(a)) c1 = self.resnet_nonlinearity(c1) c1 = self.dropout(c1) @@ -165,7 +165,7 @@ def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = self.down_shift = down_shift self.down_shift_pad = nn.ZeroPad2d((0,0,1,0)) - def forward(self, x : th.Tensor) -> th.Tensor: + def forward(self, x: th.Tensor) -> th.Tensor: x = self.pad(x) x = self.conv(x) if self.shift_output_down: @@ -253,11 +253,11 @@ class PixelCNNpp(nn.Module): def __init__(self, nr_resnet: int, nr_filters: int, nr_logistic_mix: int, resnet_nonlinearity: str, dropout_p: float, input_channels: int = 1, nr_conditions: int = 0): # noqa: PLR0913 super().__init__() - if resnet_nonlinearity == "concat_elu" : + if resnet_nonlinearity == "concat_elu": self.resnet_nonlinearity = concat_elu - elif resnet_nonlinearity == "elu" : + elif resnet_nonlinearity == "elu": self.resnet_nonlinearity = f.elu - elif resnet_nonlinearity == "relu" : + elif resnet_nonlinearity == "relu": self.resnet_nonlinearity = f.relu else: raise Exception("Only concat_elu, elu and relu are supported as resnet_nonlinearity.") # noqa: TRY002 @@ -391,7 +391,7 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 def log_sum_exp(x: th.Tensor) -> th.Tensor: """Numerically stable log_sum_exp implementation that prevents overflow.""" - # TF ordering + # [B, W, H, C] ordering axis = len(x.size()) - 1 m, _ = th.max(x, dim=axis) m2, _ = th.max(x, dim=axis, keepdim=True) @@ -401,7 +401,7 @@ def log_sum_exp(x: th.Tensor) -> th.Tensor: def log_prob_from_logits(x: th.Tensor) -> th.Tensor: """Numerically stable log_softmax implementation that prevents overflow.""" - # TF ordering + # [B, W, H, C] ordering axis = len(x.size()) - 1 m, _ = th.max(x, dim=axis, keepdim=True) return x - m - th.log(th.sum(th.exp(x - m), dim=axis, keepdim=True)) @@ -409,16 +409,17 @@ def log_prob_from_logits(x: th.Tensor) -> th.Tensor: def discretized_mix_logistic_loss(x: th.Tensor, l: th.Tensor) -> th.Tensor: """Log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval.""" - # Tensorflow ordering + # [B, W, H, C] ordering x = x.permute(0, 2, 3, 1) l = l.permute(0, 2, 3, 1) xs = list(x.shape) # true image (i.e. labels) ls = list(l.shape) # predicted distribution # unpacking the params of the mixture of logistics + # nr_mix = nr_logistic_mix and is multiplied by 3 (for \pi, \mu, s) nr_mix = int(ls[-1] / 3) - logit_probs = l[:,:,:,:nr_mix] # mixture probabilities - l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # 2 for mean, scale + logit_probs = l[:,:,:,:nr_mix] # mixture probabilities (\pi) + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # *2 for mean (\mu), scale (s) means = l[:,:,:,:,:nr_mix] log_scales = th.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) x = x.contiguous() @@ -464,14 +465,15 @@ def to_one_hot(tensor: th.Tensor, n: int, fill_with: float = 1.) -> th.Tensor: def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor: """Sample from a discretized mixture of logistic distributions.""" - # Tensorflow ordering + # [B, W, H, C] ordering l = l.permute(0, 2, 3, 1) ls = list(l.shape) xs = [*ls[:-1], 1] - # unpack parameters - logit_probs = l[:, :, :, :nr_mix] # mixture probabilities - l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # for mean, scale + # unpacking the params of the mixture of logistics + # nr_mix = nr_logistic_mix and is multiplied by 3 (for \pi, \mu, s) + logit_probs = l[:, :, :, :nr_mix] # mixture probabilities (\pi) + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # *2 for mean (\mu), scale (s) # sample mixture indicator from softmax temp = th.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5) @@ -550,19 +552,19 @@ def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor shuffle=True, ) - # Optimzer + # Optimizer optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) @th.no_grad() - def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: int = 1, n_designs: int = 25, sampling_batch_size: int = 10) -> tuple[th.Tensor, th.Tensor]: + def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], conditions: th.Tensor, dim: int = 1, n_designs: int = 25, sampling_batch_size: int = 10) -> tuple[th.Tensor, th.Tensor]: # noqa: PLR0913 """Samples n_designs designs using dataset conditions.""" model.eval() device = next(model.parameters()).device # Build the full list of requested condition combinations (on the model device) linspaces = [ - th.linspace(conds[:, i].min(), conds[:, i].max(), n_designs, device=device) for i in range(conds.shape[1]) + th.linspace(conditions[:, i].min(), conditions[:, i].max(), n_designs, device=device) for i in range(conditions.shape[1]) ] desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) @@ -638,7 +640,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], dim: i if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, dim=1, n_designs=25, sampling_batch_size=args.sampling_batch_size) + designs, desired_conds = sample_designs(model, design_shape, conds, dim=1, n_designs=25, sampling_batch_size=args.sampling_batch_size) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing From ea6fe9a9a39048b369519e57f049c5ec53eea86f Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 17 Dec 2025 13:58:17 +0100 Subject: [PATCH 30/31] Formatting with ruff --- .../evaluate_pixel_cnn_pp_2d.py | 7 +- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 501 ++++++++++++------ 2 files changed, 356 insertions(+), 152 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py index c728026..224340c 100644 --- a/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/evaluate_pixel_cnn_pp_2d.py @@ -1,4 +1,5 @@ """Evaluation of the PixelCNN++ model.""" + from __future__ import annotations from dataclasses import dataclass @@ -99,14 +100,13 @@ def __init__(self): resnet_nonlinearity=run.config["resnet_nonlinearity"], dropout_p=run.config["dropout_p"], input_channels=1, - nr_conditions=conditions_tensor.shape[1] + nr_conditions=conditions_tensor.shape[1], ) model.load_state_dict(ckpt["model"]) model.eval() model.to(device) - # Sample designs in batches all_batches: list[th.Tensor] = [] @@ -131,9 +131,8 @@ def __init__(self): # concatenate all batches on CPU gen_designs = th.cat(all_batches, dim=0) - gen_designs_np = gen_designs.detach().cpu().numpy() - gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) # remove channel dim + gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) # remove channel dim gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0) # Compute metrics diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 6de761e..9aae3fc 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -6,12 +6,14 @@ Provides the model classes, shifted convolutional blocks, and the discretized mixture of logistics loss used for training and sampling. """ + from __future__ import annotations from dataclasses import dataclass import os import random import time +import typing from engibench.utils.all_problems import BUILTIN_PROBLEMS import matplotlib.pyplot as plt @@ -73,7 +75,6 @@ class Args: """Dropout probability.""" - def concat_elu(x: th.Tensor) -> th.Tensor: """Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU.""" # PyTorch ordering @@ -81,8 +82,7 @@ def concat_elu(x: th.Tensor) -> th.Tensor: return f.elu(th.cat([x, -x], dim=axis)) - -class Nin(nn.Module): +class NIN(nn.Module): def __init__(self, nr_filters_in: int, nr_filters_out: int): super().__init__() self.lin_a = weight_norm(nn.Linear(nr_filters_in, nr_filters_out)) @@ -97,9 +97,16 @@ def forward(self, x: th.Tensor) -> th.Tensor: return out.permute(0, 3, 1, 2) # BHWC -> BCHW - class GatedResnet(nn.Module): - def __init__(self, nr_filters: int, conv_op: nn.Module, resnet_nonlinearity: callable = concat_elu, skip_connection: int = 0, dropout_p: float = 0.5, nr_conditions: int = 0): # noqa: PLR0913 + def __init__( # noqa: PLR0913 + self, + nr_filters: int, + conv_op: nn.Module, + resnet_nonlinearity: typing.Callable = concat_elu, + skip_connection: int = 0, + dropout_p: float = 0.5, + nr_conditions: int = 0, + ): super().__init__() self.skip_connection = skip_connection @@ -113,14 +120,13 @@ def __init__(self, nr_filters: int, conv_op: nn.Module, resnet_nonlinearity: cal self.conv_input = conv_op(self.filter_doubling * nr_filters, nr_filters) if skip_connection != 0: - self.nin_skip = Nin(self.filter_doubling * skip_connection * nr_filters, nr_filters) + self.nin_skip = NIN(self.filter_doubling * skip_connection * nr_filters, nr_filters) self.dropout = nn.Dropout2d(dropout_p) - self.conv_out = conv_op(self.filter_doubling * nr_filters, 2 * nr_filters) # output has to be doubled for gating + self.conv_out = conv_op(self.filter_doubling * nr_filters, 2 * nr_filters) # output has to be doubled for gating self.h_lin = nn.Linear(nr_conditions, 2 * nr_filters) - def forward(self, x: th.Tensor, a: th.Tensor = None, h: th.Tensor = None) -> th.Tensor: c1 = self.conv_input(self.resnet_nonlinearity(x)) if a is not None: @@ -130,7 +136,7 @@ def forward(self, x: th.Tensor, a: th.Tensor = None, h: th.Tensor = None) -> th. c2 = self.conv_out(c1) if h is not None: # `h` is [B, nr_conditions, 1, 1] - h_flat = h.view(h.size(0), -1) # [B, nr_conditions] + h_flat = h.view(h.size(0), -1) # [B, nr_conditions] h_proj = self.h_lin(h_flat).unsqueeze(-1).unsqueeze(-1) # [B, 2*nr_filters, 1, 1] c2 += h_proj a, b = th.chunk(c2, 2, dim=1) @@ -138,32 +144,38 @@ def forward(self, x: th.Tensor, a: th.Tensor = None, h: th.Tensor = None) -> th. return x + c3 - def down_shift(x: th.Tensor, pad: nn.Module) -> th.Tensor: """Down shift the input tensor by one row.""" xs = list(x.shape) - x = x[:, :, :xs[2] - 1, :] + x = x[:, :, : xs[2] - 1, :] return pad(x) - def down_right_shift(x: th.Tensor, pad: nn.Module) -> th.Tensor: """Down right shift the input tensor by one row and one column.""" xs = list(x.shape) - x = x[:, :, :, :xs[3] - 1] + x = x[:, :, :, : xs[3] - 1] return pad(x) - class DownShiftedConv2d(nn.Module): - def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,3), stride: tuple = (1,1), shift_output_down: bool = False): # noqa: FBT001, FBT002 + def __init__( + self, + nr_filters_in: int, + nr_filters_out: int, + filter_size: tuple = (2, 3), + stride: tuple = (1, 1), + shift_output_down: bool = False, # noqa: FBT001, FBT002 + ): super().__init__() - self.pad = nn.ZeroPad2d((int((filter_size[1] - 1) / 2), int((filter_size[1] - 1) / 2), filter_size[0]-1, 0)) # padding left, right, top, bottom + self.pad = nn.ZeroPad2d( + (int((filter_size[1] - 1) / 2), int((filter_size[1] - 1) / 2), filter_size[0] - 1, 0) + ) # padding left, right, top, bottom self.conv = weight_norm(nn.Conv2d(nr_filters_in, nr_filters_out, filter_size, stride)) self.shift_output_down = shift_output_down self.down_shift = down_shift - self.down_shift_pad = nn.ZeroPad2d((0,0,1,0)) + self.down_shift_pad = nn.ZeroPad2d((0, 0, 1, 0)) def forward(self, x: th.Tensor) -> th.Tensor: x = self.pad(x) @@ -173,12 +185,20 @@ def forward(self, x: th.Tensor) -> th.Tensor: return x - class DownShiftedDeconv2d(nn.Module): - def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,3), stride: tuple = (1,1), output_padding: tuple = (0,1)): + def __init__( + self, + nr_filters_in: int, + nr_filters_out: int, + filter_size: tuple = (2, 3), + stride: tuple = (1, 1), + output_padding: tuple = (0, 1), + ): super().__init__() - self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding)) + self.deconv = weight_norm( + nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding) + ) self.filter_size = filter_size self.stride = stride @@ -198,19 +218,30 @@ def forward(self, x: th.Tensor, output_padding: tuple | None = None) -> th.Tenso groups=self.deconv.groups, ) xs = list(x.shape) - return x[:, :, :(xs[2] - self.filter_size[0] + 1), int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))] - + return x[ + :, + :, + : (xs[2] - self.filter_size[0] + 1), + int((self.filter_size[1] - 1) / 2) : (xs[3] - int((self.filter_size[1] - 1) / 2)), + ] class DownRightShiftedConv2d(nn.Module): - def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,2), stride: tuple = (1,1), shift_output_right_down: bool = False): # noqa: FBT001, FBT002 + def __init__( + self, + nr_filters_in: int, + nr_filters_out: int, + filter_size: tuple = (2, 2), + stride: tuple = (1, 1), + shift_output_right_down: bool = False, # noqa: FBT001, FBT002 + ): super().__init__() - self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0]-1, 0)) # padding left, right, top, bottom + self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0] - 1, 0)) # padding left, right, top, bottom self.conv = weight_norm(nn.Conv2d(nr_filters_in, nr_filters_out, filter_size, stride)) self.shift_output_right_down = shift_output_right_down self.down_right_shift = down_right_shift - self.down_right_shift_pad = nn.ZeroPad2d((1,0,0,0)) + self.down_right_shift_pad = nn.ZeroPad2d((1, 0, 0, 0)) def forward(self, x: th.Tensor) -> th.Tensor: x = self.pad(x) @@ -220,12 +251,20 @@ def forward(self, x: th.Tensor) -> th.Tensor: return x - class DownRightShiftedDeconv2d(nn.Module): - def __init__(self, nr_filters_in: int, nr_filters_out: int, filter_size: tuple = (2,2), stride: tuple = (1,1), output_padding: tuple = (1,0)): + def __init__( + self, + nr_filters_in: int, + nr_filters_out: int, + filter_size: tuple = (2, 2), + stride: tuple = (1, 1), + output_padding: tuple = (1, 0), + ): super().__init__() - self.deconv = weight_norm(nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding)) + self.deconv = weight_norm( + nn.ConvTranspose2d(nr_filters_in, nr_filters_out, filter_size, stride, output_padding=output_padding) + ) self.filter_size = filter_size self.stride = stride @@ -245,12 +284,20 @@ def forward(self, x: th.Tensor, output_padding: tuple | None = None) -> th.Tenso groups=self.deconv.groups, ) xs = list(x.shape) - return x[:, :, :(xs[2] - self.filter_size[0] + 1), :(xs[3] - self.filter_size[1] + 1)] - + return x[:, :, : (xs[2] - self.filter_size[0] + 1), : (xs[3] - self.filter_size[1] + 1)] class PixelCNNpp(nn.Module): - def __init__(self, nr_resnet: int, nr_filters: int, nr_logistic_mix: int, resnet_nonlinearity: str, dropout_p: float, input_channels: int = 1, nr_conditions: int = 0): # noqa: PLR0913 + def __init__( # noqa: PLR0913 + self, + nr_resnet: int, + nr_filters: int, + nr_logistic_mix: int, + resnet_nonlinearity: str, + dropout_p: float, + input_channels: int = 1, + nr_conditions: int = 0, + ): super().__init__() if resnet_nonlinearity == "concat_elu": @@ -267,47 +314,198 @@ def __init__(self, nr_resnet: int, nr_filters: int, nr_logistic_mix: int, resnet self.nr_logistic_mix = nr_logistic_mix self.input_channels = input_channels - # UP PASS blocks - self.u_init = DownShiftedConv2d(input_channels + 1, nr_filters, filter_size=(2,3), stride=(1,1), shift_output_down=True) - self.ul_init = nn.ModuleList([DownShiftedConv2d(input_channels + 1, nr_filters, filter_size=(1,3), stride=(1,1), shift_output_down=True), - DownRightShiftedConv2d(input_channels + 1, nr_filters, filter_size=(2,1), stride=(1,1), shift_output_right_down=True)]) - - self.gated_resnet_block_u_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_up_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - - self.downsize_u_1 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) - self.downsize_ul_1 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - - self.gated_resnet_block_u_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_up_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - - self.downsize_u_2 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) - self.downsize_ul_2 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - - self.gated_resnet_block_u_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=0, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_up_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - + self.u_init = DownShiftedConv2d( + input_channels + 1, nr_filters, filter_size=(2, 3), stride=(1, 1), shift_output_down=True + ) + self.ul_init = nn.ModuleList( + [ + DownShiftedConv2d( + input_channels + 1, nr_filters, filter_size=(1, 3), stride=(1, 1), shift_output_down=True + ), + DownRightShiftedConv2d( + input_channels + 1, nr_filters, filter_size=(2, 1), stride=(1, 1), shift_output_right_down=True + ), + ] + ) + + self.gated_resnet_block_u_up_1 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=0, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + self.gated_resnet_block_ul_up_1 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownRightShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=1, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + + self.downsize_u_1 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2, 3), stride=(2, 2)) + self.downsize_ul_1 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2, 2), stride=(2, 2)) + + self.gated_resnet_block_u_up_2 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=0, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + self.gated_resnet_block_ul_up_2 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownRightShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=1, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + + self.downsize_u_2 = DownShiftedConv2d(nr_filters, nr_filters, filter_size=(2, 3), stride=(2, 2)) + self.downsize_ul_2 = DownRightShiftedConv2d(nr_filters, nr_filters, filter_size=(2, 2), stride=(2, 2)) + + self.gated_resnet_block_u_up_3 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=0, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + self.gated_resnet_block_ul_up_3 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownRightShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=1, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) # DOWN PASS blocks - self.gated_resnet_block_u_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - self.gated_resnet_block_ul_down_1 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet)]) - - self.upsize_u_1 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) - self.upsize_ul_1 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - - self.gated_resnet_block_u_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) - self.gated_resnet_block_ul_down_2 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) - - self.upsize_u_2 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,3), stride=(2,2)) - self.upsize_ul_2 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2,2), stride=(2,2)) - - self.gated_resnet_block_u_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownShiftedConv2d, self.resnet_nonlinearity, skip_connection=1, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) - self.gated_resnet_block_ul_down_3 = nn.ModuleList([GatedResnet(nr_filters, DownRightShiftedConv2d, self.resnet_nonlinearity, skip_connection=2, dropout_p=dropout_p, nr_conditions=nr_conditions) for _ in range(nr_resnet + 1)]) + self.gated_resnet_block_u_down_1 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=1, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + self.gated_resnet_block_ul_down_1 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownRightShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=2, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet) + ] + ) + + self.upsize_u_1 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2, 3), stride=(2, 2)) + self.upsize_ul_1 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2, 2), stride=(2, 2)) + + self.gated_resnet_block_u_down_2 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=1, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet + 1) + ] + ) + self.gated_resnet_block_ul_down_2 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownRightShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=2, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet + 1) + ] + ) + + self.upsize_u_2 = DownShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2, 3), stride=(2, 2)) + self.upsize_ul_2 = DownRightShiftedDeconv2d(nr_filters, nr_filters, filter_size=(2, 2), stride=(2, 2)) + + self.gated_resnet_block_u_down_3 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=1, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet + 1) + ] + ) + self.gated_resnet_block_ul_down_3 = nn.ModuleList( + [ + GatedResnet( + nr_filters, + DownRightShiftedConv2d, + self.resnet_nonlinearity, + skip_connection=2, + dropout_p=dropout_p, + nr_conditions=nr_conditions, + ) + for _ in range(nr_resnet + 1) + ] + ) num_mix = 3 - self.nin_out = Nin(nr_filters, num_mix * nr_logistic_mix) - + self.nin_out = NIN(nr_filters, num_mix * nr_logistic_mix) def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 xs = list(x.shape) @@ -356,9 +554,8 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 u_list.append(self.gated_resnet_block_u_up_3[i](u_list[-1], a=None, h=c)) ul_list.append(self.gated_resnet_block_ul_up_3[i](ul_list[-1], a=u_list[-1], h=c)) - # DOWN PASS ("decoder") - u = u_list.pop() + u = u_list.pop() ul = ul_list.pop() for i in range(self.nr_resnet): @@ -379,7 +576,6 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 u = self.gated_resnet_block_u_down_3[i](u, a=u_list.pop(), h=c) ul = self.gated_resnet_block_ul_down_3[i](ul, a=th.cat((u, ul_list.pop()), dim=1), h=c) - x_out = self.nin_out(f.elu(ul)) assert len(u_list) == 0 @@ -388,17 +584,15 @@ def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 return x_out - def log_sum_exp(x: th.Tensor) -> th.Tensor: """Numerically stable log_sum_exp implementation that prevents overflow.""" # [B, W, H, C] ordering - axis = len(x.size()) - 1 - m, _ = th.max(x, dim=axis) + axis = len(x.size()) - 1 + m, _ = th.max(x, dim=axis) m2, _ = th.max(x, dim=axis, keepdim=True) return m + th.log(th.sum(th.exp(x - m2), dim=axis)) - def log_prob_from_logits(x: th.Tensor) -> th.Tensor: """Numerically stable log_softmax implementation that prevents overflow.""" # [B, W, H, C] ordering @@ -408,61 +602,67 @@ def log_prob_from_logits(x: th.Tensor) -> th.Tensor: def discretized_mix_logistic_loss(x: th.Tensor, l: th.Tensor) -> th.Tensor: - """Log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval.""" - # [B, W, H, C] ordering - x = x.permute(0, 2, 3, 1) - l = l.permute(0, 2, 3, 1) - xs = list(x.shape) # true image (i.e. labels) - ls = list(l.shape) # predicted distribution - - # unpacking the params of the mixture of logistics - # nr_mix = nr_logistic_mix and is multiplied by 3 (for \pi, \mu, s) - nr_mix = int(ls[-1] / 3) - logit_probs = l[:,:,:,:nr_mix] # mixture probabilities (\pi) - l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # *2 for mean (\mu), scale (s) - means = l[:,:,:,:,:nr_mix] - log_scales = th.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) - x = x.contiguous() - zeros = th.zeros([*xs, nr_mix], device=x.device) - x = x.unsqueeze(-1) + zeros - - centered_x = x - means - inv_stdv = th.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 1./255.) - cdf_plus = f.sigmoid(plus_in) - min_in = inv_stdv * (centered_x - 1./255.) - cdf_min = f.sigmoid(min_in) - # log probability for edge case of 0 - log_cdf_plus = plus_in - f.softplus(plus_in) - # log probability for edge case of 255 - log_one_minus_cdf_min = -f.softplus(min_in) - # probability for all other cases - cdf_delta = cdf_plus - cdf_min - # log probability in the center of the bin, to be used in extreme cases - mid_in = inv_stdv * centered_x - log_pdf_mid = mid_in - log_scales - 2.*f.softplus(mid_in) - - # select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen here) - - # this is what is really done, but using the robust version below for extreme cases - # log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta))) # noqa: ERA001 - - # robust version, that still works if the probability is below 1e-5 - # approximation used based on the assumption that the log-density is constant in the bin of the observed sub-pixel value - log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.where(cdf_delta > 1e-5, th.log(th.clamp(cdf_delta, min=1e-12)), log_pdf_mid - np.log(127.5)))) # noqa: PLR2004 - log_probs = th.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) - return -th.sum(log_sum_exp(log_probs)) - - - -def to_one_hot(tensor: th.Tensor, n: int, fill_with: float = 1.) -> th.Tensor: + """Log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval.""" + # [B, W, H, C] ordering + x = x.permute(0, 2, 3, 1) + l = l.permute(0, 2, 3, 1) + xs = list(x.shape) # true image (i.e. labels) + ls = list(l.shape) # predicted distribution + + # unpacking the params of the mixture of logistics + # nr_mix = nr_logistic_mix and is multiplied by 3 (for \pi, \mu, s) + nr_mix = int(ls[-1] / 3) + logit_probs = l[:, :, :, :nr_mix] # mixture probabilities (\pi) + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # *2 for mean (\mu), scale (s) + means = l[:, :, :, :, :nr_mix] + log_scales = th.clamp(l[:, :, :, :, nr_mix : 2 * nr_mix], min=-7.0) + x = x.contiguous() + zeros = th.zeros([*xs, nr_mix], device=x.device) + x = x.unsqueeze(-1) + zeros + + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = f.sigmoid(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = f.sigmoid(min_in) + # log probability for edge case of 0 + log_cdf_plus = plus_in - f.softplus(plus_in) + # log probability for edge case of 255 + log_one_minus_cdf_min = -f.softplus(min_in) + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + # log probability in the center of the bin, to be used in extreme cases + mid_in = inv_stdv * centered_x + log_pdf_mid = mid_in - log_scales - 2.0 * f.softplus(mid_in) + + # select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen here) + + # this is what is really done, but using the robust version below for extreme cases + # log_probs = th.where(x < -0.999, log_cdf_plus, th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta))) # noqa: ERA001 + + # robust version, that still works if the probability is below 1e-5 + # approximation used based on the assumption that the log-density is constant in the bin of the observed sub-pixel value + log_probs = th.where( + x < -0.999, # noqa: PLR2004 + log_cdf_plus, + th.where( + x > 0.999, # noqa: PLR2004 + log_one_minus_cdf_min, + th.where(cdf_delta > 1e-5, th.log(th.clamp(cdf_delta, min=1e-12)), log_pdf_mid - np.log(127.5)), # noqa: PLR2004 + ), + ) + log_probs = th.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs) + return -th.sum(log_sum_exp(log_probs)) + + +def to_one_hot(tensor: th.Tensor, n: int, fill_with: float = 1.0) -> th.Tensor: """One hot encoding with respect to the last axis.""" one_hot = th.zeros((*tensor.size(), n), device=tensor.device) one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) return one_hot - def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor: """Sample from a discretized mixture of logistic distributions.""" # [B, W, H, C] ordering @@ -472,27 +672,25 @@ def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor # unpacking the params of the mixture of logistics # nr_mix = nr_logistic_mix and is multiplied by 3 (for \pi, \mu, s) - logit_probs = l[:, :, :, :nr_mix] # mixture probabilities (\pi) - l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # *2 for mean (\mu), scale (s) + logit_probs = l[:, :, :, :nr_mix] # mixture probabilities (\pi) + l = l[:, :, :, nr_mix:].contiguous().view([*xs, nr_mix * 2]) # *2 for mean (\mu), scale (s) # sample mixture indicator from softmax - temp = th.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5) - temp = logit_probs.detach() - th.log(- th.log(temp)) + temp = th.empty_like(logit_probs).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.detach() - th.log(-th.log(temp)) _, argmax = temp.max(dim=3) one_hot = to_one_hot(argmax, nr_mix) sel = one_hot.view([*xs[:-1], 1, nr_mix]) # select logistic parameters means = th.sum(l[:, :, :, :, :nr_mix] * sel, dim=4) - log_scales = th.clamp(th.sum( - l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) - u = th.empty_like(means).uniform_(1e-5, 1. - 1e-5) - x = means + th.exp(log_scales) * (th.log(u) - th.log(1. - u)) - x0 = th.clamp(th.clamp(x[:, :, :, 0], min=-1.), max=1.) + log_scales = th.clamp(th.sum(l[:, :, :, :, nr_mix : 2 * nr_mix] * sel, dim=4), min=-7.0) + u = th.empty_like(means).uniform_(1e-5, 1.0 - 1e-5) + x = means + th.exp(log_scales) * (th.log(u) - th.log(1.0 - u)) + x0 = th.clamp(th.clamp(x[:, :, :, 0], min=-1.0), max=1.0) return x0.unsqueeze(1) - if __name__ == "__main__": args = tyro.cli(Args) @@ -535,7 +733,7 @@ def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor resnet_nonlinearity=args.resnet_nonlinearity, dropout_p=args.dropout_p, input_channels=1, - nr_conditions=nr_conditions + nr_conditions=nr_conditions, ) model.to(device) @@ -555,16 +753,22 @@ def sample_from_discretized_mix_logistic(l: th.Tensor, nr_mix: int) -> th.Tensor # Optimizer optimizer = th.optim.Adam(model.parameters(), lr=args.lr, betas=(args.b1, args.b2)) - - @th.no_grad() - def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], conditions: th.Tensor, dim: int = 1, n_designs: int = 25, sampling_batch_size: int = 10) -> tuple[th.Tensor, th.Tensor]: # noqa: PLR0913 + def sample_designs( # noqa: PLR0913 + model: PixelCNNpp, + design_shape: tuple[int, int, int], + conditions: th.Tensor, + dim: int = 1, + n_designs: int = 25, + sampling_batch_size: int = 10, + ) -> tuple[th.Tensor, th.Tensor]: """Samples n_designs designs using dataset conditions.""" model.eval() device = next(model.parameters()).device # Build the full list of requested condition combinations (on the model device) linspaces = [ - th.linspace(conditions[:, i].min(), conditions[:, i].max(), n_designs, device=device) for i in range(conditions.shape[1]) + th.linspace(conditions[:, i].min(), conditions[:, i].max(), n_designs, device=device) + for i in range(conditions.shape[1]) ] desired_conds = th.stack(linspaces, dim=1).reshape(-1, nr_conditions, 1, 1) @@ -597,16 +801,14 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], condit data_all = th.cat(all_batches, dim=0) return data_all, desired_conds.cpu() - - # ---------- # Training # ---------- for epoch in tqdm.trange(args.n_epochs): model.train() for i, data in enumerate(dataloader): - designs = data[0].unsqueeze(dim=1) # add channel dim - designs_rescaled = designs * 2. - 1. # rescale to [-1, 1] + designs = data[0].unsqueeze(dim=1) # add channel dim + designs_rescaled = designs * 2.0 - 1.0 # rescale to [-1, 1] conds = th.stack((data[1:]), dim=1).reshape(-1, nr_conditions, 1, 1) @@ -619,7 +821,6 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], condit loss.backward() optimizer.step() - # ---------- # Logging # ---------- @@ -640,7 +841,9 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], condit if batches_done % args.sample_interval == 0: # Extract 25 designs - designs, desired_conds = sample_designs(model, design_shape, conds, dim=1, n_designs=25, sampling_batch_size=args.sampling_batch_size) + designs, desired_conds = sample_designs( + model, design_shape, conds, dim=1, n_designs=25, sampling_batch_size=args.sampling_batch_size + ) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing @@ -648,10 +851,12 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], condit # Plot the image created by each output for j, tensor in enumerate(designs): - tensor_rescaled = (tensor + 1.) / 2. # rescale to [0, 1] - img = tensor_rescaled.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates + tensor_rescaled = (tensor + 1.0) / 2.0 # rescale to [0, 1] + img = ( + tensor_rescaled.cpu().numpy().reshape(design_shape[0], design_shape[1]) + ) # Extract x and y coordinates - dc = desired_conds[j].cpu().squeeze() # Extract design conditions + dc = desired_conds[j].cpu().squeeze() # Extract design conditions axes[j].imshow(img) # image plot title = [(problem.conditions_keys[i][0], f"{dc[i]:.2f}") for i in range(nr_conditions)] @@ -670,7 +875,7 @@ def sample_designs(model: PixelCNNpp, design_shape: tuple[int, int, int], condit # Save models # -------------- if args.save_model and epoch == args.n_epochs - 1 and i == len(dataloader) - 1: - #if args.save_model and (((epoch + 1) % args.model_storage_interval == 0) or (epoch == args.n_epochs - 1)) and i == len(dataloader) - 1: + # if args.save_model and (((epoch + 1) % args.model_storage_interval == 0) or (epoch == args.n_epochs - 1)) and i == len(dataloader) - 1: ckpt_model = { "epoch": epoch, "batches_done": batches_done, From c2f726b4ef5d4b4de69cb42f62094f35b3422dd1 Mon Sep 17 00:00:00 2001 From: Jonas Date: Wed, 17 Dec 2025 14:08:21 +0100 Subject: [PATCH 31/31] replaced NIN by NetworkInNetwork --- engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py index 9aae3fc..4fe6a46 100644 --- a/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py +++ b/engiopt/pixel_cnn_pp_2d/pixel_cnn_pp_2d.py @@ -82,7 +82,7 @@ def concat_elu(x: th.Tensor) -> th.Tensor: return f.elu(th.cat([x, -x], dim=axis)) -class NIN(nn.Module): +class NetworkInNetwork(nn.Module): def __init__(self, nr_filters_in: int, nr_filters_out: int): super().__init__() self.lin_a = weight_norm(nn.Linear(nr_filters_in, nr_filters_out)) @@ -120,7 +120,7 @@ def __init__( # noqa: PLR0913 self.conv_input = conv_op(self.filter_doubling * nr_filters, nr_filters) if skip_connection != 0: - self.nin_skip = NIN(self.filter_doubling * skip_connection * nr_filters, nr_filters) + self.nin_skip = NetworkInNetwork(self.filter_doubling * skip_connection * nr_filters, nr_filters) self.dropout = nn.Dropout2d(dropout_p) self.conv_out = conv_op(self.filter_doubling * nr_filters, 2 * nr_filters) # output has to be doubled for gating @@ -505,7 +505,7 @@ def __init__( # noqa: PLR0913 ) num_mix = 3 - self.nin_out = NIN(nr_filters, num_mix * nr_logistic_mix) + self.nin_out = NetworkInNetwork(nr_filters, num_mix * nr_logistic_mix) def forward(self, x: th.Tensor, c: th.Tensor) -> th.Tensor: # noqa: C901 xs = list(x.shape)