From 482c85dda1a0dee2e14086f7866b6afcb846e1a4 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 23 Sep 2025 13:07:21 +0200 Subject: [PATCH 01/22] add initial vqgan files --- .gitignore | 4 ++++ engiopt/vqgan/__init__.py | 0 engiopt/vqgan/evaluate_vqgan.py | 0 engiopt/vqgan/vqgan.py | 0 4 files changed, 4 insertions(+) create mode 100644 engiopt/vqgan/__init__.py create mode 100644 engiopt/vqgan/evaluate_vqgan.py create mode 100644 engiopt/vqgan/vqgan.py diff --git a/.gitignore b/.gitignore index 92e8956..93b33db 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,10 @@ share/python-wheels/ *.egg MANIFEST +# Windows local environment files +pyvenv.cfg +Scripts/ + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. diff --git a/engiopt/vqgan/__init__.py b/engiopt/vqgan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py new file mode 100644 index 0000000..e69de29 diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py new file mode 100644 index 0000000..e69de29 From be14a1772d63a71390b852a4bf083919a6955ed7 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 23 Sep 2025 13:55:08 +0200 Subject: [PATCH 02/22] add windows support and add vqgan description --- .gitattributes | 16 ++++++++++++++++ engiopt/vqgan/vqgan.py | 16 ++++++++++++++++ setup.py | 7 ++++++- 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..9137b2f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,16 @@ +*.py text eol=lf +*.json text eol=lf +*.yml text eol=lf +*.yaml text eol=lf +*.sh text eol=lf +*.md text eol=lf +*.txt text eol=lf +*.ipynb text eol=lf + +*.png binary +*.jpg binary +*.jpeg binary +*.gif binary +*.pdf binary +*.pkl binary +*.npy binary \ No newline at end of file diff --git 
a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index e69de29..f86b235 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -0,0 +1,16 @@ +"""Vector Quantized Generative Adversarial Network (VQGAN). + +Based on https://github.com/dome272/VQGAN-pytorch. + +VQGAN is composed of two primary Stages: + - Stage 1 is similar to an autoencoder (AE) but with a discrete latent space represented by a codebook. + - Stage 2 is a generative model (a transformer in this case) trained on the latent space of Stage 1. + +The transformer now uses nanoGPT (https://github.com/karpathy/nanoGPT) instead of minGPT (https://github.com/karpathy/minGPT) as in the original implementation. + +For Stage 2, we take the indices of the codebook vectors and flatten them into a 1D sequence, treating them as training tokens. +The transformer is then trained to autoregressively predict each token in the sequence, after which it is reshaped back to the original 2D latent space and passed through the decoder of Stage 1 to generate an image. +To make VQGAN conditional, we train a separate VQGAN on the conditions only (CVQGAN) and replace the start-of-sequence tokens of the transformer with the CVQGAN latent tokens. + +We have updated the transformer architecture, converted VQGAN from a two-stage to a single-stage approach, added several new arguments, switched to wandb for logging, added greyscale support to the perceptual loss, and more. 
+""" diff --git a/setup.py b/setup.py index aea1f69..875a6b2 100644 --- a/setup.py +++ b/setup.py @@ -15,4 +15,9 @@ def get_version(): raise RuntimeError("bad version data in __init__.py") -setup(name="engiopt", version=get_version(), long_description=open("README.md").read()) +setup( + name="engiopt", + version=get_version(), + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown" +) From 1398f41ed3b0dcf35d1bd8828cf7c211026add8e Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 23 Sep 2025 15:55:46 +0200 Subject: [PATCH 03/22] add vqgan args and some new image transform functions --- engiopt/transforms.py | 40 ++++++++++ engiopt/vqgan/evaluate_vqgan.py | 1 + engiopt/vqgan/vqgan.py | 127 +++++++++++++++++++++++++++++++- 3 files changed, 167 insertions(+), 1 deletion(-) diff --git a/engiopt/transforms.py b/engiopt/transforms.py index 3c48e6b..edfd27d 100644 --- a/engiopt/transforms.py +++ b/engiopt/transforms.py @@ -1,10 +1,12 @@ """Transformations for the data.""" from collections.abc import Callable +import math from engibench.core import Problem from gymnasium import spaces import torch as th +import torch.nn.functional as f def flatten_dict_factory(problem: Problem, device: th.device) -> Callable: @@ -21,3 +23,41 @@ def flatten_dict(x): return th.stack(flattened) return flatten_dict + + +def _nearest_power_of_two(x: int) -> int: + """Round x to the nearest power of 2.""" + lower = 2 ** math.floor(math.log2(x)) + upper = 2 ** math.ceil(math.log2(x)) + return upper if abs(x - upper) < abs(x - lower) else lower + + +def upsample_nearest(data: th.Tensor, mode: str="bicubic") -> th.Tensor: + """Upsample 2D data to the nearest 2^n dimensions. 
Data should be a Tensor in the format (B, C, H, W).""" + _, _, h, w = data.shape + target_h = _nearest_power_of_two(h) + target_w = _nearest_power_of_two(w) + # If nearest power of two is smaller, multiply it by 2 + if target_h < h: + target_h *= 2 + if target_w < w: + target_w *= 2 + return f.interpolate(data, size=(target_h, target_w), mode=mode) + + +def downsample_nearest(data: th.Tensor, mode: str="bicubic") -> th.Tensor: + """Downsample 2D data to the nearest 2^n dimensions. Data should be a Tensor in the format (B, C, H, W).""" + _, _, h, w = data.shape + target_h = _nearest_power_of_two(h) + target_w = _nearest_power_of_two(w) + # If nearest power of two is larger, divide it by 2 + if target_h > h: + target_h //= 2 + if target_w > w: + target_w //= 2 + return f.interpolate(data, size=(target_h, target_w), mode=mode) + + +def resize_to(data: th.Tensor, h: int, w: int, mode: str = "bicubic") -> th.Tensor: + """Resize 2D data back to any desired (h, w). Data should be a Tensor in the format (B, C, H, W).""" + return f.interpolate(data, size=(h, w), mode=mode) diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index e69de29..5ca6dc0 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -0,0 +1 @@ +"""Evaluation for the VQGAN.""" diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index f86b235..7240658 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -1,6 +1,7 @@ +# ruff: noqa: F401 # REMOVE THIS LATER """Vector Quantized Generative Adversarial Network (VQGAN). -Based on https://github.com/dome272/VQGAN-pytorch. +Based on https://github.com/dome272/VQGAN-pytorch with an "Online Clustered Codebook" for better codebook usage from https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py VQGAN is composed of two primary Stages: - Stage 1 is similar to an autoencoder (AE) but with a discrete latent space represented by a codebook. 
@@ -14,3 +15,127 @@ We have updated the transformer architecture, converted VQGAN from a two-stage to a single-stage approach, added several new arguments, switched to wandb for logging, added greyscale support to the perceptual loss, and more. """ + +from __future__ import annotations + +from dataclasses import dataclass +import os +import random +import time + +from engibench.utils.all_problems import BUILTIN_PROBLEMS +import matplotlib.pyplot as plt +import numpy as np +import torch as th +from torch import autograd +from torch import nn +from torch.nn import functional +import tqdm +import tyro +import wandb + +from engiopt.metrics import dpp_diversity +from engiopt.metrics import mmd +from engiopt.transforms import resize_to +from engiopt.transforms import upsample_nearest + + +@dataclass +class Args: + """Command-line arguments for VQGAN.""" + + problem_id: str = "beams2d" + """Problem identifier for 2D engineering design.""" + algo: str = os.path.basename(__file__)[: -len(".py")] + """The name of this algorithm.""" + + # Tracking + track: bool = True + """Track the experiment with wandb.""" + wandb_project: str = "engiopt" + """Wandb project name.""" + wandb_entity: str | None = None + """Wandb entity name.""" + seed: int = 1 + """Random seed.""" + save_model: bool = False + """Saves the model to disk.""" + + # Algorithm-specific: General + conditional: bool = True + """whether the model is conditional or not""" + + # Algorithm-specific: Stage 1 (AE) + # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically + n_epochs_1: int = 100 + """number of epochs of training""" + batch_size_1: int = 16 + """size of the batches""" + lr_1: float = 2e-4 + """learning rate for Stage 1""" + beta: float = 0.25 + """beta hyperparameter for the codebook commitment loss""" + b1: float = 0.5 + """decay of first order momentum of gradient""" + b2: float 
= 0.9 + """decay of first order momentum of gradient""" + n_cpu: int = 8 + """number of cpu threads to use during batch generation""" + latent_dim: int = 16 + """dimensionality of the latent space""" + codebook_vectors: int = 256 + """number of vectors in the codebook""" + disc_start: int = 0 + """epoch to start discriminator training""" + disc_factor: float = 1.0 + """weighting factor for the adversarial loss from the discriminator""" + rec_loss_factor: float = 1.0 + """weighting factor for the reconstruction loss""" + perceptual_loss_factor: float = 1.0 + """weighting factor for the perceptual loss""" + encoder_channels: tuple = (128, 128, 128, 256, 256, 512) + """list of channel sizes for each encoder layer""" + encoder_attn_resolutions: tuple = (16,) + """list of resolutions at which to apply attention in the encoder""" + encoder_num_res_blocks: int = 2 + """number of residual blocks per encoder layer""" + encoder_start_resolution: int = 256 + """starting resolution for the encoder""" + decoder_channels: tuple = (512, 256, 256, 128, 128) + """list of channel sizes for each decoder layer""" + decoder_attn_resolutions: tuple = (16,) + """list of resolutions at which to apply attention in the decoder""" + decoder_num_res_blocks: int = 3 + """number of residual blocks per decoder layer""" + sample_interval: int = 1600 + """interval between image samples""" + + # Algorithm-specific: Stage 1 (Conditional AE if the model is conditional) + cond_dim: int = 3 + """dimensionality of the condition space""" + cond_hidden_dim: int = 256 + """hidden dimension of the CVQGAN MLP""" + cond_latent_dim: int = 4 + "individual code dimension for CVQGAN" + cond_codebook_vectors: int = 256 + """number of vectors in the CVQGAN codebook""" + cond_feature_map_dim: int = 4 + """feature map dimension for the CVQGAN encoder output""" + + + # Algorithm-specific: Stage 2 (Transformer) + # From original implementation: assume pkeep=1.0, sos_token=0, bias=True + n_epochs_2: int = 100 + 
"""number of epochs of training""" + batch_size_2: int = 16 + """size of the batches""" + lr_2: float = 6e-4 + """learning rate for Stage 2""" + n_layer: int = 12 + """number of layers in the transformer""" + n_head: int = 12 + """number of attention heads""" + n_embd: int = 768 + """transformer embedding dimension""" + dropout: float = 0.3 + """dropout rate in the transformer""" From e33490088e2fff36bca8c0caad08c7dceae02037 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 23 Sep 2025 17:16:31 +0200 Subject: [PATCH 04/22] add stage 1 submodules --- engiopt/vqgan/vqgan.py | 364 ++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 + 2 files changed, 364 insertions(+), 2 deletions(-) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 7240658..99661a6 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -1,7 +1,7 @@ # ruff: noqa: F401 # REMOVE THIS LATER """Vector Quantized Generative Adversarial Network (VQGAN). -Based on https://github.com/dome272/VQGAN-pytorch with an "Online Clustered Codebook" for better codebook usage from https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py +Based on https://github.com/dome272/VQGAN-pyth with an "Online Clustered Codebook" for better codebook usage from https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py VQGAN is composed of two primary Stages: - Stage 1 is similar to an autoencoder (AE) but with a discrete latent space represented by a codebook. 
@@ -23,13 +23,14 @@ import random import time +from einops import rearrange from engibench.utils.all_problems import BUILTIN_PROBLEMS import matplotlib.pyplot as plt import numpy as np import torch as th from torch import autograd from torch import nn -from torch.nn import functional +from torch.nn import functional as f import tqdm import tyro import wandb @@ -139,3 +140,362 @@ class Args: """transformer embedding dimension""" dropout: float = 0.3 """dropout rate in the transformer""" + + +class Codebook(nn.Module): + """Improved version over vector quantizer, with the dynamic initialization for the unoptimized "dead" vectors. + + num_embed: number of codebook entry + embed_dim: dimensionality of codebook entry + beta: weight for the commitment loss + distance: distance for looking up the closest code + anchor: anchor sampled methods + first_batch: if true, the offline version of our model + contras_loss: if true, use the contras_loss to further improve the performance + """ + def __init__(self, args): + super().__init__() + + self.num_embed = args.c_num_codebook_vectors if args.is_c else args.num_codebook_vectors + self.embed_dim = args.c_latent_dim if args.is_c else args.latent_dim + self.beta = args.beta + + # Fixed parameters from the original implementation + self.distance = "cos" + self.anchor = "probrandom" + self.first_batch = False + self.contras_loss = False + self.decay = 0.99 + self.init = False + + self.pool = FeaturePool(self.num_embed, self.embed_dim) + self.embedding = nn.Embedding(self.num_embed, self.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed) + self.register_buffer("embed_prob", th.zeros(self.num_embed)) + + + def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.embed_dim) + + # clculate the distance + if 
self.distance == "l2": + # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = - th.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \ + th.sum(self.embedding.weight ** 2, dim=1) + \ + 2 * th.einsum("bd, dn-> bn", z_flattened.detach(), rearrange(self.embedding.weight, "n d-> d n")) + elif self.distance == "cos": + # cosine distances from z to embeddings e_j + normed_z_flattened = f.normalize(z_flattened, dim=1).detach() + normed_codebook = f.normalize(self.embedding.weight, dim=1) + d = th.einsum("bd,dn->bn", normed_z_flattened, rearrange(normed_codebook, "n d -> d n")) + + # encoding + sort_distance, indices = d.sort(dim=1) + # look up the closest point for the indices + encoding_indices = indices[:,-1] + encodings = th.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) + + # quantise and unflatten + z_q = th.matmul(encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = self.beta * th.mean((z_q.detach()-z)**2) + th.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + # count + avg_probs = th.mean(encodings, dim=0) + perplexity = th.exp(-th.sum(avg_probs * th.log(avg_probs + 1e-10))) + min_encodings = encodings + + # online clustered reinitialisation for unoptimized points + if self.training: + # calculate the average usage of code entries + self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay) + # running average updates + if self.anchor in ["closest", "random", "probrandom"] and (not self.init): + # closest sampling + if self.anchor == "closest": + sort_distance, indices = d.sort(dim=0) + random_feat = z_flattened.detach()[indices[-1,:]] + # feature pool based random sampling + elif self.anchor == "random": + random_feat = self.pool.query(z_flattened.detach()) + # 
probabilitical based random sampling + elif self.anchor == "probrandom": + norm_distance = f.softmax(d.t(), dim=1) + prob = th.multinomial(norm_distance, num_samples=1).view(-1) + random_feat = z_flattened.detach()[prob] + # decay parameter based on the average usage + decay = th.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim) + self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay + if self.first_batch: + self.init = True + # contrastive loss + if self.contras_loss: + sort_distance, indices = d.sort(dim=0) + dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True) + dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:] + dis = th.cat([dis_pos, dis_neg], dim=0).t() / 0.07 + contra_loss = f.cross_entropy(dis, th.zeros((dis.size(0),), dtype=th.long, device=dis.device)) + loss += contra_loss + + return z_q, encoding_indices, loss, min_encodings, perplexity + + +class FeaturePool: + """Implements a feature buffer that stores previously encoded features. + + This buffer enables us to initialize the codebook using a history of generated features rather than the ones produced by the latest encoders + """ + def __init__(self, pool_size, dim=64): + """Initialize the FeaturePool class. 
+ + Parameters: + pool_size(int) -- the size of featue buffer + """ + self.pool_size = pool_size + if self.pool_size > 0: + self.nums_features = 0 + self.features = (th.rand((pool_size, dim)) * 2 - 1)/ pool_size + + def query(self, features: th.Tensor) -> th.Tensor: + """Return features from the pool.""" + self.features = self.features.to(features.device) + if self.nums_features < self.pool_size: + if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook + random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) + self.features = features[random_feat_id] + self.nums_features = self.pool_size + else: + # if the mini-batch is not large nuough, just store it for the next update + num = self.nums_features + features.size(0) + self.features[self.nums_features:num] = features + self.nums_features = num + elif features.size(0) > int(self.pool_size): + random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) + self.features = features[random_feat_id] + else: + random_id = th.randperm(self.pool_size) + self.features[random_id[:features.size(0)]] = features + + return self.features + + +class GroupNorm(nn.Module): + """Group Normalization block to be used in VQGAN Encoder and Decoder.""" + def __init__(self, channels): + super().__init__() + self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.gn(x) + + +class Swish(nn.Module): + """Swish activation function to be used in VQGAN Encoder and Decoder.""" + def forward(self, x: th.Tensor) -> th.Tensor: + return x * th.sigmoid(x) + + +class ResidualBlock(nn.Module): + """Residual block to be used in VQGAN Encoder and Decoder.""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.block = nn.Sequential( + GroupNorm(in_channels), + Swish(), + 
nn.Conv2d(in_channels, out_channels, 3, 1, 1), + GroupNorm(out_channels), + Swish(), + nn.Conv2d(out_channels, out_channels, 3, 1, 1) + ) + + if in_channels != out_channels: + self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0) + + def forward(self, x: th.Tensor) -> th.Tensor: + if self.in_channels != self.out_channels: + return self.channel_up(x) + self.block(x) + return x + self.block(x) + + +class UpSampleBlock(nn.Module): + """Up-sampling block to be used in VQGAN Decoder.""" + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv2d(channels, channels, 3, 1, 1) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = f.interpolate(x, scale_factor=2.0) + return self.conv(x) + + +class DownSampleBlock(nn.Module): + """Down-sampling block to be used in VQGAN Encoder.""" + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv2d(channels, channels, 3, 2, 0) + + def forward(self, x: th.Tensor) -> th.Tensor: + pad = (0, 1, 0, 1) + x = f.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class NonLocalBlock(nn.Module): + """Non-local attention block to be used in VQGAN Encoder and Decoder.""" + def __init__(self, channels): + super().__init__() + self.in_channels = channels + + self.gn = GroupNorm(channels) + self.q = nn.Conv2d(channels, channels, 1, 1, 0) + self.k = nn.Conv2d(channels, channels, 1, 1, 0) + self.v = nn.Conv2d(channels, channels, 1, 1, 0) + self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0) + + def forward(self, x: th.Tensor) -> th.Tensor: + h_ = self.gn(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + v = v.reshape(b, c, h*w) + + attn = th.bmm(q, k) + attn = attn * (int(c)**(-0.5)) + attn = f.softmax(attn, dim=2) + attn = attn.permute(0, 2, 1) + + a = th.bmm(v, attn) + a = a.reshape(b, c, h, w) + + return x + a + + +class LinearCombo(nn.Module): + """Regular fully connected 
layer combo for the CVQGAN if enabled.""" + def __init__(self, in_features, out_features, alpha=0.2): + super().__init__() + self.model = nn.Sequential( + nn.Linear(in_features, out_features), + nn.LeakyReLU(alpha) + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x) + + +class Encoder(nn.Module): + """Encoder module for VQGAN. + + Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. + The number of downsample blocks is determined by the length of the encoder channels list minus two. + For example, if encoder_channels=(128, 128, 128, 128) and the starting resolution is 128, the encoder will downsample the input image twice, from 128x128 to 32x32. + """ + def __init__(self, args): + super().__init__() + channels = args.encoder_channels + resolution = args.encoder_start_resolution + layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)] + for i in range(len(channels)-1): + in_channels = channels[i] + out_channels = channels[i + 1] + for _ in range(args.encoder_num_res_blocks): + layers.append(ResidualBlock(in_channels, out_channels)) + in_channels = out_channels + if resolution in args.encoder_attn_resolutions: + layers.append(NonLocalBlock(in_channels)) + if i != len(channels)-2: + layers.append(DownSampleBlock(channels[i+1])) + resolution //= 2 + layers.append(ResidualBlock(channels[-1], channels[-1])) + layers.append(NonLocalBlock(channels[-1])) + layers.append(ResidualBlock(channels[-1], channels[-1])) + layers.append(GroupNorm(channels[-1])) + layers.append(Swish()) + layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1)) + self.model = nn.Sequential(*layers) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x) + + +class CondEncoder(nn.Module): + """Simpler MLP-based encoder for the CVQGAN if enabled.""" + def __init__(self, args): + super().__init__() + self.c_fmap_dim = args.c_fmap_dim + self.model = nn.Sequential( + 
LinearCombo(args.c_input_dim, args.c_hidden_dim), + LinearCombo(args.c_hidden_dim, args.c_hidden_dim), + nn.Linear(args.c_hidden_dim, args.c_latent_dim*args.c_fmap_dim**2) + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + encoded = self.model(x) + s = encoded.shape + return encoded.view(s[0], s[1]//self.c_fmap_dim**2, self.c_fmap_dim, self.c_fmap_dim) + + +class Decoder(nn.Module): + """Decoder module for VQGAN. + + Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. + The number of upsample blocks is determined by the length of the decoder channels list minus one. + For example, if decoder_channels=(128, 128, 128) and the starting resolution is 32, the decoder will upsample the input image twice, from 32x32 to 128x128. + """ + def __init__(self, args): + super().__init__() + in_channels = args.decoder_channels[0] + resolution = args.decoder_start_resolution + layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1), + ResidualBlock(in_channels, in_channels), + NonLocalBlock(in_channels), + ResidualBlock(in_channels, in_channels)] + + for i in range(len(args.decoder_channels)): + out_channels = args.decoder_channels[i] + for _ in range(args.decoder_num_res_blocks): + layers.append(ResidualBlock(in_channels, out_channels)) + in_channels = out_channels + if resolution in args.decoder_attn_resolutions: + layers.append(NonLocalBlock(in_channels)) + + if i != 0: + layers.append(UpSampleBlock(in_channels)) + resolution *= 2 + + layers.append(GroupNorm(in_channels)) + layers.append(Swish()) + layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1)) + self.model = nn.Sequential(*layers) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x) + + +class CondDecoder(nn.Module): + """Simpler MLP-based decoder for the CVQGAN if enabled.""" + def __init__(self, args): + super().__init__() + + self.model = nn.Sequential( + LinearCombo(args.c_latent_dim*args.c_fmap_dim**2, 
args.c_hidden_dim), + LinearCombo(args.c_hidden_dim, args.c_hidden_dim), + nn.Linear(args.c_hidden_dim, args.c_input_dim) + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x.contiguous().view(len(x), -1)) diff --git a/pyproject.toml b/pyproject.toml index d93efba..0ce82a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -291,5 +291,7 @@ module = [ "hyppo.*", "plotly", "plotly.*", + "einops", + "einops.*", ] ignore_missing_imports = true From 8df08dbefd5a3aea50c80de24e9d88480cc84a88 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Wed, 24 Sep 2025 16:56:55 +0200 Subject: [PATCH 05/22] clean up stage 1 modules --- engiopt/vqgan/vqgan.py | 265 ++++++++++++++++++++++++++++++----------- 1 file changed, 193 insertions(+), 72 deletions(-) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 99661a6..93301bf 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -94,17 +94,17 @@ class Args: """weighting factor for the reconstruction loss""" perceptual_loss_factor: float = 1.0 """weighting factor for the perceptual loss""" - encoder_channels: tuple = (128, 128, 128, 256, 256, 512) + encoder_channels: tuple[int, ...] = (128, 128, 128, 256, 256, 512) """list of channel sizes for each encoder layer""" - encoder_attn_resolutions: tuple = (16,) + encoder_attn_resolutions: tuple[int, ...] = (16,) """list of resolutions at which to apply attention in the encoder""" encoder_num_res_blocks: int = 2 """number of residual blocks per encoder layer""" encoder_start_resolution: int = 256 """starting resolution for the encoder""" - decoder_channels: tuple = (512, 256, 256, 128, 128) + decoder_channels: tuple[int, ...] = (512, 256, 256, 128, 128) """list of channel sizes for each decoder layer""" - decoder_attn_resolutions: tuple = (16,) + decoder_attn_resolutions: tuple[int, ...] 
= (16,) """list of resolutions at which to apply attention in the decoder""" decoder_num_res_blocks: int = 3 """number of residual blocks per decoder layer""" @@ -145,28 +145,40 @@ class Args: class Codebook(nn.Module): """Improved version over vector quantizer, with the dynamic initialization for the unoptimized "dead" vectors. - num_embed: number of codebook entry - embed_dim: dimensionality of codebook entry - beta: weight for the commitment loss - distance: distance for looking up the closest code - anchor: anchor sampled methods - first_batch: if true, the offline version of our model - contras_loss: if true, use the contras_loss to further improve the performance + Parameters: + num_codebook_vectors (int): number of codebook entries + latent_dim (int): dimensionality of codebook entries + beta (float): weight for the commitment loss + decay (float): decay for the moving average of code usage + distance (str): distance type for looking up the closest code + anchor (str): anchor sampling methods + first_batch (bool): if true, the offline version of the model + contras_loss (bool): if true, use the contras_loss to further improve the performance + init (bool): if true, the codebook has been initialized """ - def __init__(self, args): + def __init__( # noqa: PLR0913 + self, *, + num_codebook_vectors: int, + latent_dim: int, + beta: float = 0.25, + decay: float = 0.99, + distance: str = "cos", + anchor: str = "probrandom", + first_batch: bool = False, + contras_loss: bool = False, + init: bool = False, + ): super().__init__() - self.num_embed = args.c_num_codebook_vectors if args.is_c else args.num_codebook_vectors - self.embed_dim = args.c_latent_dim if args.is_c else args.latent_dim - self.beta = args.beta - - # Fixed parameters from the original implementation - self.distance = "cos" - self.anchor = "probrandom" - self.first_batch = False - self.contras_loss = False - self.decay = 0.99 - self.init = False + self.num_embed = num_codebook_vectors + self.embed_dim 
= latent_dim + self.beta = beta + self.decay = decay + self.distance = distance + self.anchor = anchor + self.first_batch = first_batch + self.contras_loss = contras_loss + self.init = init self.pool = FeaturePool(self.num_embed, self.embed_dim) self.embedding = nn.Embedding(self.num_embed, self.embed_dim) @@ -198,7 +210,7 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten encodings = th.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device) encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) - # quantise and unflatten + # quantize and unflatten z_q = th.matmul(encodings, self.embedding.weight).view(z.shape) # compute loss for embedding loss = self.beta * th.mean((z_q.detach()-z)**2) + th.mean((z_q - z.detach()) ** 2) @@ -211,7 +223,7 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten perplexity = th.exp(-th.sum(avg_probs * th.log(avg_probs + 1e-10))) min_encodings = encodings - # online clustered reinitialisation for unoptimized points + # online clustered reinitialization for unoptimized points if self.training: # calculate the average usage of code entries self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay) @@ -249,18 +261,21 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten class FeaturePool: """Implements a feature buffer that stores previously encoded features. - This buffer enables us to initialize the codebook using a history of generated features rather than the ones produced by the latest encoders - """ - def __init__(self, pool_size, dim=64): - """Initialize the FeaturePool class. + This buffer enables us to initialize the codebook using a history of generated features rather than the ones produced by the latest encoders. 
- Parameters: - pool_size(int) -- the size of featue buffer - """ + Parameters: + pool_size (int): the size of featue buffer + dim (int): the dimension of each feature + """ + def __init__( + self, + pool_size: int, + dim: int = 64 + ): self.pool_size = pool_size if self.pool_size > 0: self.nums_features = 0 - self.features = (th.rand((pool_size, dim)) * 2 - 1)/ pool_size + self.features = (th.rand((pool_size, dim)) * 2 - 1) / pool_size def query(self, features: th.Tensor) -> th.Tensor: """Return features from the pool.""" @@ -286,8 +301,15 @@ def query(self, features: th.Tensor) -> th.Tensor: class GroupNorm(nn.Module): - """Group Normalization block to be used in VQGAN Encoder and Decoder.""" - def __init__(self, channels): + """Group Normalization block to be used in VQGAN Encoder and Decoder. + + Parameters: + channels (int): number of channels in the input feature map + """ + def __init__( + self, + channels: int + ): super().__init__() self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) @@ -302,8 +324,17 @@ def forward(self, x: th.Tensor) -> th.Tensor: class ResidualBlock(nn.Module): - """Residual block to be used in VQGAN Encoder and Decoder.""" - def __init__(self, in_channels, out_channels): + """Residual block to be used in VQGAN Encoder and Decoder. + + Parameters: + in_channels (int): number of input channels + out_channels (int): number of output channels + """ + def __init__( + self, + in_channels: int, + out_channels: int + ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -326,8 +357,15 @@ def forward(self, x: th.Tensor) -> th.Tensor: class UpSampleBlock(nn.Module): - """Up-sampling block to be used in VQGAN Decoder.""" - def __init__(self, channels): + """Up-sampling block to be used in VQGAN Decoder. 
+ + Parameters: + channels (int): number of channels in the input feature map + """ + def __init__( + self, + channels: int + ): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, 1, 1) @@ -337,8 +375,15 @@ def forward(self, x: th.Tensor) -> th.Tensor: class DownSampleBlock(nn.Module): - """Down-sampling block to be used in VQGAN Encoder.""" - def __init__(self, channels): + """Down-sampling block to be used in VQGAN Encoder. + + Parameters: + channels (int): number of channels in the input feature map + """ + def __init__( + self, + channels: int + ): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, 2, 0) @@ -349,8 +394,15 @@ def forward(self, x: th.Tensor) -> th.Tensor: class NonLocalBlock(nn.Module): - """Non-local attention block to be used in VQGAN Encoder and Decoder.""" - def __init__(self, channels): + """Non-local attention block to be used in VQGAN Encoder and Decoder. + + Parameters: + channels (int): number of channels in the input feature map + """ + def __init__( + self, + channels: int + ): super().__init__() self.in_channels = channels @@ -385,8 +437,19 @@ def forward(self, x: th.Tensor) -> th.Tensor: class LinearCombo(nn.Module): - """Regular fully connected layer combo for the CVQGAN if enabled.""" - def __init__(self, in_features, out_features, alpha=0.2): + """Regular fully connected layer combo for the CVQGAN if enabled. + + Parameters: + in_features (int): number of input features + out_features (int): number of output features + alpha (float): negative slope for LeakyReLU + """ + def __init__( + self, + in_features: int, + out_features: int, + alpha: float = 0.2 + ): super().__init__() self.model = nn.Sequential( nn.Linear(in_features, out_features), @@ -403,19 +466,35 @@ class Encoder(nn.Module): Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. The number of downsample blocks is determined by the length of the encoder channels list minus two. 
For example, if encoder_channels=(128, 128, 128, 128) and the starting resolution is 128, the encoder will downsample the input image twice, from 128x128 to 32x32. + + Parameters: + encoder_channels (tuple[int, ...]): list of channel sizes for each encoder layer + encoder_start_resolution (int): starting resolution for the encoder + encoder_attn_resolutions (tuple[int, ...]): list of resolutions at which to apply attention in the encoder + encoder_num_res_blocks (int): number of residual blocks per encoder layer + image_channels (int): number of channels in the input image + latent_dim (int): dimensionality of the latent space """ - def __init__(self, args): + def __init__( # noqa: PLR0913 + self, + encoder_channels: tuple[int, ...], + encoder_start_resolution: int, + encoder_attn_resolutions: tuple[int, ...], + encoder_num_res_blocks: int, + image_channels: int, + latent_dim: int, + ): super().__init__() - channels = args.encoder_channels - resolution = args.encoder_start_resolution - layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)] + channels = encoder_channels + resolution = encoder_start_resolution + layers = [nn.Conv2d(image_channels, channels[0], 3, 1, 1)] for i in range(len(channels)-1): in_channels = channels[i] out_channels = channels[i + 1] - for _ in range(args.encoder_num_res_blocks): + for _ in range(encoder_num_res_blocks): layers.append(ResidualBlock(in_channels, out_channels)) in_channels = out_channels - if resolution in args.encoder_attn_resolutions: + if resolution in encoder_attn_resolutions: layers.append(NonLocalBlock(in_channels)) if i != len(channels)-2: layers.append(DownSampleBlock(channels[i+1])) @@ -425,7 +504,7 @@ def __init__(self, args): layers.append(ResidualBlock(channels[-1], channels[-1])) layers.append(GroupNorm(channels[-1])) layers.append(Swish()) - layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1)) + layers.append(nn.Conv2d(channels[-1], latent_dim, 3, 1, 1)) self.model = nn.Sequential(*layers) def 
forward(self, x: th.Tensor) -> th.Tensor: @@ -433,14 +512,27 @@ def forward(self, x: th.Tensor) -> th.Tensor: class CondEncoder(nn.Module): - """Simpler MLP-based encoder for the CVQGAN if enabled.""" - def __init__(self, args): + """Simpler MLP-based encoder for the CVQGAN if enabled. + + Parameters: + c_fmap_dim (int): feature map dimension for the CVQGAN encoder output + c_input_dim (int): number of input features + c_hidden_dim (int): hidden dimension of the CVQGAN MLP + c_latent_dim (int): individual code dimension for CVQGAN + """ + def __init__( + self, + c_fmap_dim: int, + c_input_dim: int, + c_hidden_dim: int, + c_latent_dim: int + ): super().__init__() - self.c_fmap_dim = args.c_fmap_dim + self.c_fmap_dim = c_fmap_dim self.model = nn.Sequential( - LinearCombo(args.c_input_dim, args.c_hidden_dim), - LinearCombo(args.c_hidden_dim, args.c_hidden_dim), - nn.Linear(args.c_hidden_dim, args.c_latent_dim*args.c_fmap_dim**2) + LinearCombo(c_input_dim, c_hidden_dim), + LinearCombo(c_hidden_dim, c_hidden_dim), + nn.Linear(c_hidden_dim, c_latent_dim*c_fmap_dim**2) ) def forward(self, x: th.Tensor) -> th.Tensor: @@ -455,22 +547,38 @@ class Decoder(nn.Module): Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. The number of upsample blocks is determined by the length of the decoder channels list minus one. For example, if decoder_channels=(128, 128, 128) and the starting resolution is 32, the decoder will upsample the input image twice, from 32x32 to 128x128. 
+ + Parameters: + decoder_channels (tuple[int, ...]): list of channel sizes for each decoder layer + decoder_start_resolution (int): starting resolution for the decoder + decoder_attn_resolutions (tuple[int, ...]): list of resolutions at which to apply attention in the decoder + decoder_num_res_blocks (int): number of residual blocks per decoder layer + image_channels (int): number of channels in the output image + latent_dim (int): dimensionality of the latent space """ - def __init__(self, args): + def __init__( # noqa: PLR0913 + self, + decoder_channels: tuple[int, ...], + decoder_start_resolution: int, + decoder_attn_resolutions: tuple[int, ...], + decoder_num_res_blocks: int, + image_channels: int, + latent_dim: int + ): super().__init__() - in_channels = args.decoder_channels[0] - resolution = args.decoder_start_resolution - layers = [nn.Conv2d(args.latent_dim, in_channels, 3, 1, 1), + in_channels = decoder_channels[0] + resolution = decoder_start_resolution + layers = [nn.Conv2d(latent_dim, in_channels, 3, 1, 1), ResidualBlock(in_channels, in_channels), NonLocalBlock(in_channels), ResidualBlock(in_channels, in_channels)] - for i in range(len(args.decoder_channels)): - out_channels = args.decoder_channels[i] - for _ in range(args.decoder_num_res_blocks): + for i in range(len(decoder_channels)): + out_channels = decoder_channels[i] + for _ in range(decoder_num_res_blocks): layers.append(ResidualBlock(in_channels, out_channels)) in_channels = out_channels - if resolution in args.decoder_attn_resolutions: + if resolution in decoder_attn_resolutions: layers.append(NonLocalBlock(in_channels)) if i != 0: @@ -479,7 +587,7 @@ def __init__(self, args): layers.append(GroupNorm(in_channels)) layers.append(Swish()) - layers.append(nn.Conv2d(in_channels, args.image_channels, 3, 1, 1)) + layers.append(nn.Conv2d(in_channels, image_channels, 3, 1, 1)) self.model = nn.Sequential(*layers) def forward(self, x: th.Tensor) -> th.Tensor: @@ -487,14 +595,27 @@ def forward(self, x: 
th.Tensor) -> th.Tensor: class CondDecoder(nn.Module): - """Simpler MLP-based decoder for the CVQGAN if enabled.""" - def __init__(self, args): + """Simpler MLP-based decoder for the CVQGAN if enabled. + + Parameters: + c_fmap_dim (int): feature map dimension for the CVQGAN encoder output + c_input_dim (int): number of input features + c_hidden_dim (int): hidden dimension of the CVQGAN MLP + c_latent_dim (int): individual code dimension for CVQGAN + """ + def __init__( + self, + c_latent_dim: int, + c_input_dim: int, + c_hidden_dim: int, + c_fmap_dim: int + ): super().__init__() self.model = nn.Sequential( - LinearCombo(args.c_latent_dim*args.c_fmap_dim**2, args.c_hidden_dim), - LinearCombo(args.c_hidden_dim, args.c_hidden_dim), - nn.Linear(args.c_hidden_dim, args.c_input_dim) + LinearCombo(c_latent_dim*c_fmap_dim**2, c_hidden_dim), + LinearCombo(c_hidden_dim, c_hidden_dim), + nn.Linear(c_hidden_dim, c_input_dim) ) def forward(self, x: th.Tensor) -> th.Tensor: From e415de749aa3cfb241fd2a6ca00b087e223eb8f1 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Wed, 24 Sep 2025 17:59:14 +0200 Subject: [PATCH 06/22] add remaining cleaned up stage 1 modules --- engiopt/vqgan/vqgan.py | 296 +++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 2 + 2 files changed, 286 insertions(+), 12 deletions(-) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 93301bf..ce1f330 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -18,19 +18,24 @@ from __future__ import annotations +from collections import namedtuple from dataclasses import dataclass import os import random import time +from typing import Optional, TYPE_CHECKING from einops import rearrange from engibench.utils.all_problems import BUILTIN_PROBLEMS import matplotlib.pyplot as plt import numpy as np +import requests import torch as th from torch import autograd from torch import nn from torch.nn import functional as f +from torchvision.models import vgg16 +from torchvision.models import 
VGG16_Weights import tqdm import tyro import wandb @@ -40,6 +45,18 @@ from engiopt.transforms import resize_to from engiopt.transforms import upsample_nearest +if TYPE_CHECKING: + import logging + +# URL and checkpoint for LPIPS model +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + @dataclass class Args: @@ -95,17 +112,17 @@ class Args: perceptual_loss_factor: float = 1.0 """weighting factor for the perceptual loss""" encoder_channels: tuple[int, ...] = (128, 128, 128, 256, 256, 512) - """list of channel sizes for each encoder layer""" + """tuple of channel sizes for each encoder layer""" encoder_attn_resolutions: tuple[int, ...] = (16,) - """list of resolutions at which to apply attention in the encoder""" + """tuple of resolutions at which to apply attention in the encoder""" encoder_num_res_blocks: int = 2 """number of residual blocks per encoder layer""" encoder_start_resolution: int = 256 """starting resolution for the encoder""" decoder_channels: tuple[int, ...] = (512, 256, 256, 128, 128) - """list of channel sizes for each decoder layer""" + """tuple of channel sizes for each decoder layer""" decoder_attn_resolutions: tuple[int, ...] = (16,) - """list of resolutions at which to apply attention in the decoder""" + """tuple of resolutions at which to apply attention in the decoder""" decoder_num_res_blocks: int = 3 """number of residual blocks per decoder layer""" sample_interval: int = 1600 @@ -461,16 +478,16 @@ def forward(self, x: th.Tensor) -> th.Tensor: class Encoder(nn.Module): - """Encoder module for VQGAN. + """Encoder module for VQGAN Stage 1. Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. - The number of downsample blocks is determined by the length of the encoder channels list minus two. + The number of downsample blocks is determined by the length of the encoder channels tuple minus two. 
For example, if encoder_channels=(128, 128, 128, 128) and the starting resolution is 128, the encoder will downsample the input image twice, from 128x128 to 32x32. Parameters: - encoder_channels (tuple[int, ...]): list of channel sizes for each encoder layer + encoder_channels (tuple[int, ...]): tuple of channel sizes for each encoder layer encoder_start_resolution (int): starting resolution for the encoder - encoder_attn_resolutions (tuple[int, ...]): list of resolutions at which to apply attention in the encoder + encoder_attn_resolutions (tuple[int, ...]): tuple of resolutions at which to apply attention in the encoder encoder_num_res_blocks (int): number of residual blocks per encoder layer image_channels (int): number of channels in the input image latent_dim (int): dimensionality of the latent space @@ -542,16 +559,16 @@ def forward(self, x: th.Tensor) -> th.Tensor: class Decoder(nn.Module): - """Decoder module for VQGAN. + """Decoder module for VQGAN Stage 1. Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. - The number of upsample blocks is determined by the length of the decoder channels list minus one. + The number of upsample blocks is determined by the length of the decoder channels tuple minus one. For example, if decoder_channels=(128, 128, 128) and the starting resolution is 32, the decoder will upsample the input image twice, from 32x32 to 128x128. 
Parameters: - decoder_channels (tuple[int, ...]): list of channel sizes for each decoder layer + decoder_channels (tuple[int, ...]): tuple of channel sizes for each decoder layer decoder_start_resolution (int): starting resolution for the decoder - decoder_attn_resolutions (tuple[int, ...]): list of resolutions at which to apply attention in the decoder + decoder_attn_resolutions (tuple[int, ...]): tuple of resolutions at which to apply attention in the decoder decoder_num_res_blocks (int): number of residual blocks per decoder layer image_channels (int): number of channels in the output image latent_dim (int): dimensionality of the latent space @@ -620,3 +637,258 @@ def __init__( def forward(self, x: th.Tensor) -> th.Tensor: return self.model(x.contiguous().view(len(x), -1)) + + +class Discriminator(nn.Module): + """PatchGAN-style discriminator. + + Adapted from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538 + + Parameters: + num_filters_last: Number of filters in the last conv layer. + n_layers: Number of convolutional layers. + image_channels: Number of channels in the input image. + image_size: Spatial size (H=W) of the input image. 
+ """ + + def __init__( + self, + *, + num_filters_last: int = 64, + n_layers: int = 3, + image_channels: int = 1, + image_size: int = 128, + ) -> None: + super().__init__() + + # Convolutional backbone (PatchGAN) + layers: list[nn.Module] = [ + nn.Conv2d(image_channels, num_filters_last, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ] + num_filters_mult = 1 + + for i in range(1, n_layers + 1): + num_filters_mult_last = num_filters_mult + num_filters_mult = min(2**i, 8) + layers += [ + nn.Conv2d( + num_filters_last * num_filters_mult_last, + num_filters_last * num_filters_mult, + kernel_size=4, + stride=2 if i < n_layers else 1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(num_filters_last * num_filters_mult), + nn.LeakyReLU(0.2, inplace=True), + ] + + layers.append( + nn.Conv2d(num_filters_last * num_filters_mult, 1, kernel_size=4, stride=1, padding=1) + ) + self.model = nn.Sequential(*layers) + + # Adapter for CVQGAN latent vectors → image + self.cvqgan_adapter = nn.Sequential( + nn.Linear(image_channels, image_size), + nn.ReLU(inplace=True), + nn.Linear(image_size, image_channels * image_size**2), + nn.ReLU(inplace=True), + nn.Unflatten(1, (image_channels, image_size, image_size)), + ) + + # Initialize weights + self.apply(self._weights_init) + + + @staticmethod + def _weights_init(m: nn.Module) -> None: + """Custom weight initialization (DCGAN-style).""" + classname = m.__class__.__name__ + if "Conv" in classname: + nn.init.normal_(m.weight.data, mean=0.0, std=0.02) + elif "BatchNorm" in classname: + nn.init.normal_(m.weight.data, mean=1.0, std=0.02) + nn.init.constant_(m.bias.data, 0.0) + + + def forward(self, x: th.Tensor, *, is_cvqgan: bool = False) -> th.Tensor: + """Forward pass with optional CVQGAN adapter.""" + if is_cvqgan: + x = self.cvqgan_adapter(x) + return self.model(x) + + +class ScalingLayer(nn.Module): + """Channel-wise affine normalization used by LPIPS.""" + + def __init__(self) -> None: + super().__init__() + 
self.register_buffer("shift", th.tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer("scale", th.tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, x: th.Tensor) -> th.Tensor: + return (x - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """1x1 conv with dropout (per-layer LPIPS linear head).""" + + def __init__(self, in_channels: int, out_channels: int = 1) -> None: + super().__init__() + self.model = nn.Sequential( + nn.Dropout(), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x) + + +class VGG16(nn.Module): + """Torchvision VGG16 feature extractor sliced at LPIPS tap points.""" + + def __init__(self) -> None: + super().__init__() + vgg_feats = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features + blocks = [vgg_feats[i] for i in range(30)] + self.slice1 = nn.Sequential(*blocks[0:4]) # relu1_2 + self.slice2 = nn.Sequential(*blocks[4:9]) # relu2_2 + self.slice3 = nn.Sequential(*blocks[9:16]) # relu3_3 + self.slice4 = nn.Sequential(*blocks[16:23]) # relu4_3 + self.slice5 = nn.Sequential(*blocks[23:30]) # relu5_3 + self.requires_grad_(requires_grad=False) + + def forward(self, x: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: + h1 = self.slice1(x) + h2 = self.slice2(h1) + h3 = self.slice3(h2) + h4 = self.slice4(h3) + h5 = self.slice5(h4) + return (h1, h2, h3, h4, h5) + + +class GreyscaleLPIPS(nn.Module): + """LPIPS for greyscale/topological data with optional 'raw' aggregation. + + ``use_raw=True`` is often preferable for non-natural images since learned + linear heads are tuned on natural RGB photos. + + Parameters: + use_raw: If True, average raw per-layer squared diffs (no linear heads). + clamp_output: Clamp the final loss to ``>= 0``. + robust_clamp: Clamp inputs to [0, 1] before feature extraction. 
+ warn_on_clamp: If True, log warnings when inputs fall outside [0, 1]. + freeze: If True, disables grads on all params. + ckpt_name: Key in URL_MAP/CKPT_MAP for loading LPIPS heads. + logger: Optional logger for non-intrusive messages/warnings. + """ + + def __init__( # noqa: PLR0913 + self, + *, + use_raw: bool = True, + clamp_output: bool = False, + robust_clamp: bool = True, + warn_on_clamp: bool = False, + freeze: bool = True, + ckpt_name: str = "vgg_lpips", + logger: logging.Logger | None = None, + ) -> None: + super().__init__() + self.use_raw = use_raw + self.clamp_output = clamp_output + self.robust_clamp = robust_clamp + self.warn_on_clamp = warn_on_clamp + self._logger = logger + + self.scaling_layer = ScalingLayer() + self.channels = (64, 128, 256, 512, 512) + self.vgg = VGG16() + self.lins = nn.ModuleList([NetLinLayer(c) for c in self.channels]) + + self._load_from_pretrained(name=ckpt_name) + if freeze: + self.requires_grad_(requires_grad=False) + + + def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: + """Compute greyscale-aware LPIPS distance between two batches.""" + if self.warn_on_clamp and self._logger is not None: + with th.no_grad(): + if (fake_x < 0).any() or (fake_x > 1).any(): + self._logger.warning( + "GreyscaleLPIPS: generated input outside [0,1]: [%.4f, %.4f]", + float(fake_x.min().item()), float(fake_x.max().item()), + ) + if (real_x < 0).any() or (real_x > 1).any(): + self._logger.warning( + "GreyscaleLPIPS: reference input outside [0,1]: [%.4f, %.4f]", + float(real_x.min().item()), float(real_x.max().item()), + ) + + if self.robust_clamp: + real_x = th.clamp(real_x, 0.0, 1.0) + fake_x = th.clamp(fake_x, 0.0, 1.0) + + # Promote greyscale → RGB for VGG features + if real_x.shape[1] == 1: + real_x = real_x.repeat(1, 3, 1, 1) + if fake_x.shape[1] == 1: + fake_x = fake_x.repeat(1, 3, 1, 1) + + fr = self.vgg(self.scaling_layer(real_x)) + ff = self.vgg(self.scaling_layer(fake_x)) + diffs = [(self._norm_tensor(a) - 
self._norm_tensor(b)) ** 2 for a, b in zip(fr, ff)] + + if self.use_raw: + parts = [self._spatial_average(d).mean(dim=1, keepdim=True) for d in diffs] + else: + parts = [self._spatial_average(self.lins[i](d)) for i, d in enumerate(diffs)] + + loss = th.stack(parts, dim=0).sum() + if self.clamp_output: + loss = th.clamp(loss, min=0.0) + return loss + + # Helpers + @staticmethod + def _norm_tensor(x: th.Tensor) -> th.Tensor: + """L2-normalize channels per spatial location: BxCxHxW → BxCxHxW.""" + norm = th.sqrt(th.sum(x**2, dim=1, keepdim=True)) + return x / (norm + 1e-10) + + @staticmethod + def _spatial_average(x: th.Tensor) -> th.Tensor: + """Average over spatial dimensions with dims kept: BxCxHxW → BxCx1x1.""" + return x.mean(dim=(2, 3), keepdim=True) + + def _load_from_pretrained(self, *, name: str) -> None: + """Load LPIPS linear heads (and any required buffers) from a checkpoint.""" + ckpt = self._get_ckpt_path(name, "vgg_lpips") + state_dict = th.load(ckpt, map_location=th.device("cpu"), weights_only=True) + self.load_state_dict(state_dict, strict=False) + + @staticmethod + def _download(url: str, local_path: str, *, chunk_size: int = 1024) -> None: + """Stream a file to disk with a progress bar.""" + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True, timeout=10) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as pbar, open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(len(data)) + + def _get_ckpt_path(self, name: str, root: str) -> str: + """Return local path to a pretrained LPIPS checkpoint; download if missing.""" + assert name in URL_MAP, f"Unknown LPIPS checkpoint name: {name!r}" + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path): + if self._logger is not None: + self._logger.info("Downloading LPIPS weights '%s' from %s to %s", name, 
URL_MAP[name], path) + self._download(URL_MAP[name], path) + return path + diff --git a/pyproject.toml b/pyproject.toml index 0ce82a7..e79ad93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -293,5 +293,7 @@ module = [ "plotly.*", "einops", "einops.*", + "torchvision.*", + "requests", ] ignore_missing_imports = true From 97efc89b84cf8dfb12983fe3b30a4798b09ba49c Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Wed, 24 Sep 2025 18:57:21 +0200 Subject: [PATCH 07/22] add vqgan stage 1 main class and resolve arg names --- engiopt/vqgan/vqgan.py | 195 ++++++++++++++++++++++++++++++++++------- 1 file changed, 164 insertions(+), 31 deletions(-) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index ce1f330..f59d6d9 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -24,6 +24,7 @@ import random import time from typing import Optional, TYPE_CHECKING +import warnings from einops import rearrange from engibench.utils.all_problems import BUILTIN_PROBLEMS @@ -101,7 +102,7 @@ class Args: """number of cpu threads to use during batch generation""" latent_dim: int = 16 """dimensionality of the latent space""" - codebook_vectors: int = 256 + num_codebook_vectors: int = 256 """number of vectors in the codebook""" disc_start: int = 0 """epoch to start discriminator training""" @@ -135,7 +136,7 @@ class Args: """hidden dimension of the CVQGAN MLP""" cond_latent_dim: int = 4 "individual code dimension for CVQGAN" - cond_codebook_vectors: int = 256 + cond_codebook_vectors: int = 64 """number of vectors in the CVQGAN codebook""" cond_feature_map_dim: int = 4 """feature map dimension for the CVQGAN encoder output""" @@ -202,7 +203,6 @@ def __init__( # noqa: PLR0913 self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed) self.register_buffer("embed_prob", th.zeros(self.num_embed)) - def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: # reshape z -> (batch, height, width, channel) and 
flatten z = rearrange(z, "b c h w -> b h w c").contiguous() @@ -532,30 +532,30 @@ class CondEncoder(nn.Module): """Simpler MLP-based encoder for the CVQGAN if enabled. Parameters: - c_fmap_dim (int): feature map dimension for the CVQGAN encoder output - c_input_dim (int): number of input features - c_hidden_dim (int): hidden dimension of the CVQGAN MLP - c_latent_dim (int): individual code dimension for CVQGAN + cond_feature_map_dim (int): feature map dimension for the CVQGAN encoder output + cond_dim (int): number of input features + cond_hidden_dim (int): hidden dimension of the CVQGAN MLP + cond_latent_dim (int): individual code dimension for CVQGAN """ def __init__( self, - c_fmap_dim: int, - c_input_dim: int, - c_hidden_dim: int, - c_latent_dim: int + cond_feature_map_dim: int, + cond_dim: int, + cond_hidden_dim: int, + cond_latent_dim: int ): super().__init__() - self.c_fmap_dim = c_fmap_dim + self.c_feature_map_dim = cond_feature_map_dim self.model = nn.Sequential( - LinearCombo(c_input_dim, c_hidden_dim), - LinearCombo(c_hidden_dim, c_hidden_dim), - nn.Linear(c_hidden_dim, c_latent_dim*c_fmap_dim**2) + LinearCombo(cond_dim, cond_hidden_dim), + LinearCombo(cond_hidden_dim, cond_hidden_dim), + nn.Linear(cond_hidden_dim, cond_latent_dim*cond_feature_map_dim**2) ) def forward(self, x: th.Tensor) -> th.Tensor: encoded = self.model(x) s = encoded.shape - return encoded.view(s[0], s[1]//self.c_fmap_dim**2, self.c_fmap_dim, self.c_fmap_dim) + return encoded.view(s[0], s[1]//self.c_feature_map_dim**2, self.c_feature_map_dim, self.c_feature_map_dim) class Decoder(nn.Module): @@ -615,24 +615,24 @@ class CondDecoder(nn.Module): """Simpler MLP-based decoder for the CVQGAN if enabled. 
Parameters: - c_fmap_dim (int): feature map dimension for the CVQGAN encoder output - c_input_dim (int): number of input features - c_hidden_dim (int): hidden dimension of the CVQGAN MLP - c_latent_dim (int): individual code dimension for CVQGAN + cond_feature_map_dim (int): feature map dimension for the CVQGAN encoder output + cond_dim (int): number of input features + cond_hidden_dim (int): hidden dimension of the CVQGAN MLP + cond_latent_dim (int): individual code dimension for CVQGAN """ def __init__( self, - c_latent_dim: int, - c_input_dim: int, - c_hidden_dim: int, - c_fmap_dim: int + cond_latent_dim: int, + cond_dim: int, + cond_hidden_dim: int, + cond_feature_map_dim: int ): super().__init__() self.model = nn.Sequential( - LinearCombo(c_latent_dim*c_fmap_dim**2, c_hidden_dim), - LinearCombo(c_hidden_dim, c_hidden_dim), - nn.Linear(c_hidden_dim, c_input_dim) + LinearCombo(cond_latent_dim*cond_feature_map_dim**2, cond_hidden_dim), + LinearCombo(cond_hidden_dim, cond_hidden_dim), + nn.Linear(cond_hidden_dim, cond_dim) ) def forward(self, x: th.Tensor) -> th.Tensor: @@ -689,7 +689,7 @@ def __init__( ) self.model = nn.Sequential(*layers) - # Adapter for CVQGAN latent vectors → image + # Adapter for CVQGAN latent vectors -> image self.cvqgan_adapter = nn.Sequential( nn.Linear(image_channels, image_size), nn.ReLU(inplace=True), @@ -832,7 +832,7 @@ def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: real_x = th.clamp(real_x, 0.0, 1.0) fake_x = th.clamp(fake_x, 0.0, 1.0) - # Promote greyscale → RGB for VGG features + # Promote greyscale -> RGB for VGG features if real_x.shape[1] == 1: real_x = real_x.repeat(1, 3, 1, 1) if fake_x.shape[1] == 1: @@ -855,13 +855,13 @@ def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: # Helpers @staticmethod def _norm_tensor(x: th.Tensor) -> th.Tensor: - """L2-normalize channels per spatial location: BxCxHxW → BxCxHxW.""" + """L2-normalize channels per spatial location: BxCxHxW -> BxCxHxW.""" norm 
= th.sqrt(th.sum(x**2, dim=1, keepdim=True)) return x / (norm + 1e-10) @staticmethod def _spatial_average(x: th.Tensor) -> th.Tensor: - """Average over spatial dimensions with dims kept: BxCxHxW → BxCx1x1.""" + """Average over spatial dimensions with dims kept: BxCxHxW -> BxCx1x1.""" return x.mean(dim=(2, 3), keepdim=True) def _load_from_pretrained(self, *, name: str) -> None: @@ -892,3 +892,136 @@ def _get_ckpt_path(self, name: str, root: str) -> str: self._download(URL_MAP[name], path) return path + +class VQGAN(nn.Module): + """VQGAN model for Stage 1. + + Can be configured as a CVQGAN if desired. + + Parameters: + device (th.device): torch device to use + + **CVQGAN params** + is_c (bool): If True, use CVQGAN architecture (MLP-based encoder/decoder). + cond_feature_map_dim (int): Feature map dimension for the CVQGAN encoder output. + cond_dim (int): Number of input features for the CVQGAN encoder. + cond_hidden_dim (int): Hidden dimension of the CVQGAN MLP. + cond_latent_dim (int): Individual code dimension for CVQGAN. + cond_codebook_vectors (int): Number of codebook vectors for CVQGAN. + + **VQGAN params** + encoder_channels (tuple[int, ...]): Tuple of channel sizes for each encoder layer. + encoder_start_resolution (int): Starting resolution for the encoder. + encoder_attn_resolutions (tuple[int, ...]): Tuple of resolutions at which to apply attention in the encoder. + encoder_num_res_blocks (int): Number of residual blocks per encoder layer. + decoder_channels (tuple[int, ...]): Tuple of channel sizes for each decoder layer. + decoder_start_resolution (int): Starting resolution for the decoder. + decoder_attn_resolutions (tuple[int, ...]): Tuple of resolutions at which to apply attention in the decoder. + decoder_num_res_blocks (int): Number of residual blocks per decoder layer. + image_channels (int): Number of channels in the input/output image. + latent_dim (int): Dimensionality of the latent space. 
+ num_codebook_vectors (int): Number of codebook vectors. + """ + def __init__( # noqa: PLR0913 + self, *, + device: th.device, + + # CVQGAN parameters + is_c: bool = False, + cond_feature_map_dim: int = 4, + cond_dim: int = 3, + cond_hidden_dim: int = 256, + cond_latent_dim: int = 4, + cond_codebook_vectors: int = 64, + + # VQGAN + Codebook parameters + encoder_channels: tuple[int, ...], + encoder_start_resolution: int, + encoder_attn_resolutions: tuple[int, ...], + encoder_num_res_blocks: int, + decoder_channels: tuple[int, ...], + decoder_start_resolution: int, + decoder_attn_resolutions: tuple[int, ...], + decoder_num_res_blocks: int, + image_channels: int = 1, + latent_dim: int = 16, + num_codebook_vectors: int = 256, + + ): + super().__init__() + if is_c: + self.encoder = CondEncoder( + cond_feature_map_dim, + cond_dim, + cond_hidden_dim, + cond_latent_dim + ).to(device=device) + + self.decoder = CondDecoder( + cond_latent_dim, + cond_dim, + cond_hidden_dim, + cond_feature_map_dim + ).to(device=device) + + self.quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, 1).to(device=device) + self.post_quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, 1).to(device=device) + else: + self.encoder = Encoder( + encoder_channels, + encoder_start_resolution, + encoder_attn_resolutions, + encoder_num_res_blocks, + image_channels, + latent_dim + ).to(device=device) + + self.decoder = Decoder( + decoder_channels, + decoder_start_resolution, + decoder_attn_resolutions, + decoder_num_res_blocks, + image_channels, + latent_dim + ).to(device=device) + + self.quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device) + self.post_quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device) + + self.codebook = Codebook( + num_codebook_vectors = cond_codebook_vectors if is_c else num_codebook_vectors, + latent_dim = cond_latent_dim if is_c else latent_dim + ).to(device=device) + + def forward(self, imgs: th.Tensor) -> tuple[th.Tensor, th.Tensor, 
th.Tensor]: + """Full VQGAN forward pass.""" + encoded = self.encoder(imgs) + quant_encoded = self.quant_conv(encoded) + quant, indices, q_loss = self.codebook(quant_encoded) + post_quant = self.post_quant_conv(quant) + decoded = self.decoder(post_quant) + return decoded, indices, q_loss + + def encode(self, imgs: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor]: + """Encode image batch into quantized latent representation.""" + encoded = self.encoder(imgs) + quant_encoded = self.quant_conv(encoded) + return self.codebook(quant_encoded) + + def decode(self, z: th.Tensor) -> th.Tensor: + """Decode quantized latent representation back to image space.""" + return self.decoder(self.post_quant_conv(z)) + + def calculate_lambda(self, perceptual_loss: th.Tensor, gan_loss: th.Tensor) -> th.Tensor: + """Compute balancing factor λ between discriminator loss and the remaining loss terms.""" + last_layer = self.decoder.model[-1] + last_weight = last_layer.weight + grad_perc = th.autograd.grad(perceptual_loss, last_weight, retain_graph=True)[0] + grad_gan = th.autograd.grad(gan_loss, last_weight, retain_graph=True)[0] + lamb = th.norm(grad_perc) / (th.norm(grad_gan) + 1e-4) + return 0.8 * th.clamp(lamb, 0.0, 1e4).detach() + + @staticmethod + def adopt_weight(disc_factor: float, i: int, threshold: int, value: float = 0.0) -> float: + """Adopt weight scheduling: zero out `disc_factor` before threshold.""" + return value if i < threshold else disc_factor From f9d565d23dba118dfb960c6a2da551a7de828d4c Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Thu, 25 Sep 2025 11:37:44 +0200 Subject: [PATCH 08/22] add vqgan stage 2 modules --- .pre-commit-config.yaml | 2 +- engiopt/vqgan/vqgan.py | 539 ++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 3 +- 3 files changed, 515 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fa1f04f..23a0869 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ 
repos: name: pyright entry: pyright language: node - pass_filenames: false + pass_filenames: true types: [python] additional_dependencies: ["pyright@1.1.347"] args: diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index f59d6d9..b850c3a 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -20,6 +20,8 @@ from collections import namedtuple from dataclasses import dataclass +import inspect +import math import os import random import time @@ -38,6 +40,7 @@ from torchvision.models import vgg16 from torchvision.models import VGG16_Weights import tqdm +from transformers import GPT2LMHeadModel import tyro import wandb @@ -129,18 +132,24 @@ class Args: sample_interval: int = 1600 """interval between image samples""" - # Algorithm-specific: Stage 1 (Conditional AE if the model is conditional) + # Algorithm-specific: Stage 1 (Conditional AE or "CVQGAN" if the model is conditional) + # Note that a Discriminator is not used for CVQGAN, as it is generally a much simpler model. cond_dim: int = 3 """dimensionality of the condition space""" cond_hidden_dim: int = 256 """hidden dimension of the CVQGAN MLP""" cond_latent_dim: int = 4 - "individual code dimension for CVQGAN" + """individual code dimension for CVQGAN""" cond_codebook_vectors: int = 64 """number of vectors in the CVQGAN codebook""" cond_feature_map_dim: int = 4 """feature map dimension for the CVQGAN encoder output""" - + cond_epochs: int = 100 + """number of epochs of CVQGAN training""" + cond_lr: float = 2e-4 + """learning rate for CVQGAN""" + cond_sample_interval: int = 1600 + """interval between CVQGAN image samples""" # Algorithm-specific: Stage 2 (Transformer) # From original implementation: assume pkeep=1.0, sos_token=0, bias=True @@ -281,7 +290,7 @@ class FeaturePool: This buffer enables us to initialize the codebook using a history of generated features rather than the ones produced by the latest encoders. 
Parameters: - pool_size (int): the size of featue buffer + pool_size (int): the size of feature buffer dim (int): the dimension of each feature """ def __init__( @@ -643,22 +652,20 @@ class Discriminator(nn.Module): """PatchGAN-style discriminator. Adapted from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538 + This assumes we never use a discriminator for the CVQGAN, since it is generally a much simpler model. Parameters: num_filters_last: Number of filters in the last conv layer. n_layers: Number of convolutional layers. image_channels: Number of channels in the input image. - image_size: Spatial size (H=W) of the input image. """ def __init__( self, - *, num_filters_last: int = 64, n_layers: int = 3, - image_channels: int = 1, - image_size: int = 128, - ) -> None: + image_channels: int = 1 + ): super().__init__() # Convolutional backbone (PatchGAN) @@ -689,15 +696,6 @@ def __init__( ) self.model = nn.Sequential(*layers) - # Adapter for CVQGAN latent vectors -> image - self.cvqgan_adapter = nn.Sequential( - nn.Linear(image_channels, image_size), - nn.ReLU(inplace=True), - nn.Linear(image_size, image_channels * image_size**2), - nn.ReLU(inplace=True), - nn.Unflatten(1, (image_channels, image_size, image_size)), - ) - # Initialize weights self.apply(self._weights_init) @@ -713,17 +711,15 @@ def _weights_init(m: nn.Module) -> None: nn.init.constant_(m.bias.data, 0.0) - def forward(self, x: th.Tensor, *, is_cvqgan: bool = False) -> th.Tensor: + def forward(self, x: th.Tensor) -> th.Tensor: """Forward pass with optional CVQGAN adapter.""" - if is_cvqgan: - x = self.cvqgan_adapter(x) return self.model(x) class ScalingLayer(nn.Module): """Channel-wise affine normalization used by LPIPS.""" - def __init__(self) -> None: + def __init__(self): super().__init__() self.register_buffer("shift", th.tensor([-.030, -.088, -.188])[None, :, None, None]) self.register_buffer("scale", th.tensor([.458, .448, .450])[None, :, None, None]) @@ 
-735,7 +731,7 @@ def forward(self, x: th.Tensor) -> th.Tensor: class NetLinLayer(nn.Module): """1x1 conv with dropout (per-layer LPIPS linear head).""" - def __init__(self, in_channels: int, out_channels: int = 1) -> None: + def __init__(self, in_channels: int, out_channels: int = 1): super().__init__() self.model = nn.Sequential( nn.Dropout(), @@ -749,7 +745,7 @@ def forward(self, x: th.Tensor) -> th.Tensor: class VGG16(nn.Module): """Torchvision VGG16 feature extractor sliced at LPIPS tap points.""" - def __init__(self) -> None: + def __init__(self): super().__init__() vgg_feats = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features blocks = [vgg_feats[i] for i in range(30)] @@ -795,7 +791,7 @@ def __init__( # noqa: PLR0913 freeze: bool = True, ckpt_name: str = "vgg_lpips", logger: logging.Logger | None = None, - ) -> None: + ): super().__init__() self.use_raw = use_raw self.clamp_output = clamp_output @@ -806,7 +802,7 @@ def __init__( # noqa: PLR0913 self.scaling_layer = ScalingLayer() self.channels = (64, 128, 256, 512, 512) self.vgg = VGG16() - self.lins = nn.ModuleList([NetLinLayer(c) for c in self.channels]) + self.linears = nn.ModuleList([NetLinLayer(c) for c in self.channels]) self._load_from_pretrained(name=ckpt_name) if freeze: @@ -845,7 +841,7 @@ def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: if self.use_raw: parts = [self._spatial_average(d).mean(dim=1, keepdim=True) for d in diffs] else: - parts = [self._spatial_average(self.lins[i](d)) for i, d in enumerate(diffs)] + parts = [self._spatial_average(self.linears[i](d)) for i, d in enumerate(diffs)] loss = th.stack(parts, dim=0).sum() if self.clamp_output: @@ -1025,3 +1021,492 @@ def calculate_lambda(self, perceptual_loss: th.Tensor, gan_loss: th.Tensor) -> t def adopt_weight(disc_factor: float, i: int, threshold: int, value: float = 0.0) -> float: """Adopt weight scheduling: zero out `disc_factor` before threshold.""" return value if i < threshold else disc_factor + + 
+########################################### +########## GPT-2 BASE CODE BELOW ########## +########################################### +class LayerNorm(nn.Module): + """LayerNorm with optional bias (PyTorch lacks bias=False support).""" + + def __init__(self, ndim: int, *, bias: bool): + super().__init__() + self.weight = nn.Parameter(th.ones(ndim)) + self.bias = nn.Parameter(th.zeros(ndim)) if bias else None + + def forward(self, x: th.Tensor) -> th.Tensor: + return f.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + """Causal self-attention with FlashAttention fallback when unavailable.""" + + def __init__(self, config: GPTConfig): + super().__init__() + assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head" + + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + + self.flash = hasattr(f, "scaled_dot_product_attention") + if not self.flash: + warnings.warn( + "Falling back to non-flash attention; PyTorch >= 2.0 enables FlashAttention.", + stacklevel=2, + ) + self.register_buffer( + "bias", + th.tril(th.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + b, t, c = x.size() + + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) + q = q.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) + v = v.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) + + if self.flash: + y = f.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) + else: + att = 
(q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float("-inf")) + att = f.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v + + y = y.transpose(1, 2).contiguous().view(b, t, c) + return self.resid_dropout(self.c_proj(y)) + + +class MLP(nn.Module): + """Feed-forward block used inside Transformer blocks.""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + return self.dropout(x) + + +class Block(nn.Module): + """Transformer block: LayerNorm -> Self-Attn -> residual; LayerNorm -> MLP -> residual.""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = x + self.attn(self.ln_1(x)) + return x + self.mlp(self.ln_2(x)) + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 uses 50257; padded to multiple of 64 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # GPT-2 uses biases in Linear/LayerNorm + + +class GPT(nn.Module): + """Minimal GPT-2 style Transformer with HF weight import.""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.config = config + + self.transformer = nn.ModuleDict( + { + "wte": nn.Embedding(config.vocab_size, config.n_embd), + "wpe": nn.Embedding(config.block_size, config.n_embd), + "drop": nn.Dropout(config.dropout), + "h": nn.ModuleList([Block(config) for _ in 
range(config.n_layer)]), + "ln_f": LayerNorm(config.n_embd, bias=config.bias), + } + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.transformer["wte"].weight = self.lm_head.weight # weight tying + + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + th.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def get_num_params(self, *, non_embedding: bool = True) -> int: + """Return total parameter count (optionally excluding position embeddings).""" + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer["wpe"].weight.numel() + return n_params + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + th.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + th.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + th.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward( + self, + idx: th.Tensor, + targets: th.Tensor | None = None, + ) -> tuple[th.Tensor, th.Tensor | None]: + """Forward pass returning logits and optional cross-entropy loss.""" + device = idx.device + _, t = idx.size() + assert t <= self.config.block_size, ( + f"Cannot forward sequence of length {t}; block size is {self.config.block_size}" + ) + pos = th.arange(0, t, dtype=th.long, device=device) + + tok_emb = self.transformer["wte"](idx) + pos_emb = self.transformer["wpe"](pos) + x = self.transformer["drop"](tok_emb + pos_emb) + for block in self.transformer["h"]: + x = block(x) + x = self.transformer["ln_f"](x) + + logits = self.lm_head(x) + loss: th.Tensor | None + if targets is not None: + loss = f.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + else: + loss = None + return logits, loss + + def crop_block_size(self, block_size: int) -> None: + """Reduce maximum context length and trim position embeddings.""" + 
assert block_size <= self.config.block_size + self.config.block_size = block_size + self.transformer["wpe"].weight = nn.Parameter(self.transformer["wpe"].weight[:block_size]) + for block in self.transformer["h"]: + attn = block.attn + if hasattr(attn, "bias"): + attn.bias = attn.bias[:, :, :block_size, :block_size] + + @classmethod + def from_pretrained( + cls, + model_type: str, + override_args: dict[str, float] | None = None, + ) -> GPT: + """Load HF GPT-2 weights into this minimal GPT implementation.""" + assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} + override_args = override_args or {} + assert all(k == "dropout" for k in override_args), "Only 'dropout' can be overridden" + + cfg_map: dict[str, dict[str, int]] = { + "gpt2": {"n_layer": 12, "n_head": 12, "n_embd": 768}, + "gpt2-medium": {"n_layer": 24, "n_head": 16, "n_embd": 1024}, + "gpt2-large": {"n_layer": 36, "n_head": 20, "n_embd": 1280}, + "gpt2-xl": {"n_layer": 48, "n_head": 25, "n_embd": 1600}, + } + + # Use object so we can mix int, float, and bool + config_args: dict[str, object] = dict(cfg_map[model_type]) + config_args.update({"vocab_size": 50257, "block_size": 1024, "bias": True}) + + if "dropout" in override_args: + config_args["dropout"] = float(override_args["dropout"]) + + config = GPTConfig(**config_args) # type: ignore[arg-type] + model = GPT(config) + + sd = model.state_dict() + sd_keys = [k for k in sd if not k.endswith(".attn.bias")] + + hf: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = hf.state_dict() + sd_keys_hf = [ + k + for k in sd_hf + if not (k.endswith((".attn.masked_bias", ".attn.bias"))) + ] + + transposed = {"attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"} + assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" + + for k in sd_keys_hf: + if any(k.endswith(w) for w in transposed): + assert sd_hf[k].shape[::-1] == sd[k].shape + with th.no_grad(): + 
sd[k].copy_(sd_hf[k].t()) + else: + assert sd_hf[k].shape == sd[k].shape + with th.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers( + self, + weight_decay: float, + learning_rate: float, + betas: tuple[float, float], + device_type: str, + ) -> th.optim.Optimizer: + """Create AdamW with decoupled weight decay for matrix weights only.""" + param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} + dim_threshold = 2 + decay_params = [p for p in param_dict.values() if p.dim() >= dim_threshold] + nodecay_params = [p for p in param_dict.values() if p.dim() < dim_threshold] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + + fused_available = "fused" in inspect.signature(th.optim.AdamW).parameters + use_fused = bool(fused_available and device_type == "cuda") + extra_args: dict[str, object] = {"fused": True} if use_fused else {} + return th.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + + def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float: + """Estimate model FLOPS utilization relative to A100 bf16 peak (312 TFLOPS).""" + n = self.get_num_params() + cfg = self.config + l, h, q, t = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size + flops_per_token = 6 * n + 12 * l * h * q * t + flops_per_fwdbwd = flops_per_token * t + flops_per_iter = flops_per_fwdbwd * float(fwdbwd_per_iter) + flops_achieved = flops_per_iter * (1.0 / dt) + flops_peak = 312e12 + return float(flops_achieved / flops_peak) + + @th.no_grad() + def generate( + self, + idx: th.Tensor, + max_new_tokens: int, + *, + temperature: float = 1.0, + top_k: int | None = None, + ) -> th.Tensor: + """Autoregressively sample tokens conditioned on idx.""" + for _ in range(max_new_tokens): + idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] + logits, _ = self(idx_cond) + logits = logits[:, 
-1, :] / temperature + if top_k is not None: + v, _ = th.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("inf") + probs = f.softmax(logits, dim=-1) + idx_next = th.multinomial(probs, num_samples=1) + idx = th.cat((idx, idx_next), dim=1) + return idx +########################################### +########## GPT-2 BASE CODE ABOVE ########## +########################################### + + +class VQGANTransformer(nn.Module): + """Wrapper for VQGAN Stage 2: Transformer. + + Generative component of VQGAN trained on the Stage 1 discrete latent space. + + Parameters: + conditional (bool): If True, use CVQGAN for conditioning. + vqgan (VQGAN): Pretrained VQGAN model for primary image encoding/decoding. + cvqgan (VQGAN): Pretrained CVQGAN model for conditional encoding (if conditional=True). + image_size (int): Input image size (assumed square). + decoder_channels (tuple[int, ...]): Decoder channels from the VQGAN model. + cond_fmap_dim (int): Feature map dimension from the CVQGAN encoder (if conditional=True). + num_codebook_vectors (int): Number of codebook vectors from the VQGAN model. + n_layer (int): Number of Transformer layers. + n_head (int): Number of attention heads in the Transformer. + n_embd (int): Embedding dimension in the Transformer. + dropout (float): Dropout rate in the Transformer. + bias (bool): If True, use bias terms in the Transformer layers. 
+ """ + def __init__( # noqa: PLR0913 + self, *, + conditional: bool = True, + vqgan: VQGAN, + cvqgan: VQGAN, + image_size: int, + decoder_channels: tuple[int, ...], + cond_fmap_dim: int, + num_codebook_vectors: int, + n_layer: int, + n_head: int, + n_embd: int, + dropout: int, + bias: bool = True + ): + super().__init__() + self.sos_token = 0 + self.vqgan = vqgan.eval() + for param in self.vqgan.parameters(): + param.requires_grad = False + + if conditional: + self.cvqgan = cvqgan + for param in self.cvqgan.parameters(): + param.requires_grad = False + + # block_size is automatically set to the combined sequence length of the VQGAN and CVQGAN + block_size = (image_size // (2 ** (len(decoder_channels) - 1))) ** 2 + if conditional: + block_size += cond_fmap_dim ** 2 + + # Create config object for NanoGPT + transformer_config = GPTConfig( + vocab_size=num_codebook_vectors, + block_size=block_size, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + dropout=dropout, # Add dropout parameter (default in nanoGPT) + bias=bias # Add bias parameter (default in nanoGPT) + ) + self.transformer = GPT(transformer_config) + self.conditional = conditional + self.sidelen = image_size // (2 ** (len(decoder_channels) - 1)) # Note: assumes square image + + @th.no_grad() + def encode_to_z(self, *, x: th.Tensor, is_c: bool = False) -> tuple[th.Tensor, th.Tensor]: + """Encode images to quantized latent vectors (z) and their indices.""" + if is_c: # For the conditional tokens, use the CVQGAN encoder + quant_z, indices, _ = self.cvqgan.encode(x) + else: + quant_z, indices, _ = self.vqgan.encode(x) + indices = indices.view(quant_z.shape[0], -1) + return quant_z, indices + + @th.no_grad() + def z_to_image(self, indices: th.Tensor) -> th.Tensor: + """Convert quantized latent indices back to image space.""" + ix_to_vectors = self.vqgan.codebook.embedding(indices).reshape(indices.shape[0], self.sidelen, self.sidelen, -1) + ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2) + return 
self.vqgan.decode(ix_to_vectors) + + def forward(self, x: th.Tensor, c: th.Tensor, pkeep: float = 1.0) -> tuple[th.Tensor, th.Tensor]: + """Forward pass through the Transformer. Returns logits and targets for loss computation.""" + _, indices = self.encode_to_z(x=x) + + # Replace the start token with the encoded conditional input if using CVQGAN + if self.conditional: + _, sos_tokens = self.encode_to_z(x=c, is_c=True) + else: + sos_tokens = th.ones(x.shape[0], 1) * self.sos_token + sos_tokens = sos_tokens.long().to(x.device) + + if pkeep < 1.0: + mask = th.bernoulli(pkeep * th.ones(indices.shape, device=indices.device)) + mask = mask.round().to(dtype=th.int64) + random_indices = th.randint_like(indices, self.transformer.config.vocab_size) + new_indices = mask * indices + (1 - mask) * random_indices + else: + new_indices = indices + + new_indices = th.cat((sos_tokens, new_indices), dim=1) + + target = indices + + # NanoGPT forward doesn't use embeddings parameter, but takes targets + # We're ignoring the loss returned by NanoGPT + logits, _ = self.transformer(new_indices[:, :-1], None) + logits = logits[:, -indices.shape[1]:] # Always predict the last 256 tokens + + return logits, target + + def top_k_logits(self, logits: th.Tensor, k: int) -> th.Tensor: + """Zero out all logits that are not in the top-k.""" + v, _ = th.topk(logits, k) + out = logits.clone() + out[out < v[..., [-1]]] = -float("inf") + return out + + @th.no_grad() + def sample(self, x: th.Tensor, c: th.Tensor, steps: int, temperature: float = 1.0, top_k: int | None = None) -> th.Tensor: + """Autoregressively sample from the model given initial context x and conditional c.""" + x = th.cat((c, x), dim=1) + + # Keep the original sampling logic for compatibility + for _ in range(steps): + logits, _ = self.transformer(x, None) + logits = logits[:, -1, :] / temperature + + if top_k is not None: + # Determine the actual vocabulary size for this batch + # Count non-negative infinity values in the logits + 
n_tokens = th.sum(th.isfinite(logits), dim=-1).min().item() + + # Use the minimum of top_k and the actual number of tokens + effective_top_k = min(top_k, n_tokens) + + # Apply top_k with the effective value + if effective_top_k > 0: # Ensure we have at least one token to sample + logits = self.top_k_logits(logits, effective_top_k) + else: + # Fallback if all logits are -inf (shouldn't happen, but just in case) + print("Warning: No finite logits found for sampling") + # Make all logits equal (uniform distribution) + logits = th.zeros_like(logits) + + probs = f.softmax(logits, dim=-1) + + # In the VQGAN paper we use multinomial sampling (top_k=None, greedy=False) + ix = th.multinomial(probs, num_samples=1) + + x = th.cat((x, ix), dim=1) + + return x[:, c.shape[1]:] + + @th.no_grad() + def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tuple[dict[str, th.Tensor], th.Tensor]: + """Generate reconstructions and samples from the model for logging.""" + log = {} + + _, indices = self.encode_to_z(x=x) + # Replace the start token with the encoded conditional input if using CVQGAN + if self.conditional: + _, sos_tokens = self.encode_to_z(x=c, is_c=True) + else: + sos_tokens = th.ones(x.shape[0], 1) * self.sos_token + sos_tokens = sos_tokens.long().to(x.device) + + start_indices = indices[:, :indices.shape[1] // 2] + sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1], top_k=top_k) + half_sample = self.z_to_image(sample_indices) + + start_indices = indices[:, :0] + sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1], top_k=top_k) + full_sample = self.z_to_image(sample_indices) + + x_rec = self.z_to_image(indices) + + log["input"] = x + log["rec"] = x_rec + log["half_sample"] = half_sample + log["full_sample"] = full_sample + + return log, th.concat((x, x_rec, half_sample, full_sample)) diff --git a/pyproject.toml b/pyproject.toml index e79ad93..d5b8118 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -223,7 +223,7 @@ strict = [] typeCheckingMode = "basic" pythonVersion = "3.9" pythonPlatform = "All" -typeshedPath = "typeshed" +# typeshedPath = "typeshed" -> commented out may lead to precommit out of memory error enableTypeIgnoreComments = true # This is required as the CI pre-commit does not download the module (i.e. numpy, pygame) @@ -295,5 +295,6 @@ module = [ "einops.*", "torchvision.*", "requests", + "transformers", ] ignore_missing_imports = true From 5b9625f46b56feedafe0e64961f246ac7d1e0acd Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Sun, 28 Sep 2025 01:40:40 +0200 Subject: [PATCH 09/22] add initial training loops and conditions augmentations --- engiopt/metrics.py | 7 +- engiopt/share/man/man1/isympy.1 | 188 ++++++++++++++++ engiopt/transforms.py | 95 ++++++-- engiopt/vqgan/vqgan.py | 374 ++++++++++++++++++++++++++------ pyproject.toml | 1 + 5 files changed, 577 insertions(+), 88 deletions(-) create mode 100644 engiopt/share/man/man1/isympy.1 diff --git a/engiopt/metrics.py b/engiopt/metrics.py index c2723fa..cc273cf 100644 --- a/engiopt/metrics.py +++ b/engiopt/metrics.py @@ -8,6 +8,7 @@ import multiprocessing import os +import sys import traceback from typing import Any, TYPE_CHECKING @@ -21,7 +22,11 @@ from engibench import OptiStep from engibench.core import Problem -multiprocessing.set_start_method("fork") + +if sys.platform != "win32": # only set fork on non-Windows + multiprocessing.set_start_method("fork", force=True) +else: + multiprocessing.set_start_method("spawn", force=True) def mmd(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float: diff --git a/engiopt/share/man/man1/isympy.1 b/engiopt/share/man/man1/isympy.1 new file mode 100644 index 0000000..0ff9661 --- /dev/null +++ b/engiopt/share/man/man1/isympy.1 @@ -0,0 +1,188 @@ +'\" -*- coding: us-ascii -*- +.if \n(.g .ds T< \\FC +.if \n(.g .ds T> \\F[\n[.fam]] +.de URL +\\$2 \(la\\$1\(ra\\$3 +.. 
+.if \n(.g .mso www.tmac +.TH isympy 1 2007-10-8 "" "" +.SH NAME +isympy \- interactive shell for SymPy +.SH SYNOPSIS +'nh +.fi +.ad l +\fBisympy\fR \kx +.if (\nx>(\n(.l/2)) .nr x (\n(.l/5) +'in \n(.iu+\nxu +[\fB-c\fR | \fB--console\fR] [\fB-p\fR ENCODING | \fB--pretty\fR ENCODING] [\fB-t\fR TYPE | \fB--types\fR TYPE] [\fB-o\fR ORDER | \fB--order\fR ORDER] [\fB-q\fR | \fB--quiet\fR] [\fB-d\fR | \fB--doctest\fR] [\fB-C\fR | \fB--no-cache\fR] [\fB-a\fR | \fB--auto\fR] [\fB-D\fR | \fB--debug\fR] [ +-- | PYTHONOPTIONS] +'in \n(.iu-\nxu +.ad b +'hy +'nh +.fi +.ad l +\fBisympy\fR \kx +.if (\nx>(\n(.l/2)) .nr x (\n(.l/5) +'in \n(.iu+\nxu +[ +{\fB-h\fR | \fB--help\fR} +| +{\fB-v\fR | \fB--version\fR} +] +'in \n(.iu-\nxu +.ad b +'hy +.SH DESCRIPTION +isympy is a Python shell for SymPy. It is just a normal python shell +(ipython shell if you have the ipython package installed) that executes +the following commands so that you don't have to: +.PP +.nf +\*(T< +>>> from __future__ import division +>>> from sympy import * +>>> x, y, z = symbols("x,y,z") +>>> k, m, n = symbols("k,m,n", integer=True) + \*(T> +.fi +.PP +So starting isympy is equivalent to starting python (or ipython) and +executing the above commands by hand. It is intended for easy and quick +experimentation with SymPy. For more complicated programs, it is recommended +to write a script and import things explicitly (using the "from sympy +import sin, log, Symbol, ..." idiom). +.SH OPTIONS +.TP +\*(T<\fB\-c \fR\*(T>\fISHELL\fR, \*(T<\fB\-\-console=\fR\*(T>\fISHELL\fR +Use the specified shell (python or ipython) as +console backend instead of the default one (ipython +if present or python otherwise). + +Example: isympy -c python + +\fISHELL\fR could be either +\&'ipython' or 'python' +.TP +\*(T<\fB\-p \fR\*(T>\fIENCODING\fR, \*(T<\fB\-\-pretty=\fR\*(T>\fIENCODING\fR +Setup pretty printing in SymPy. By default, the most pretty, unicode +printing is enabled (if the terminal supports it). 
You can use less
+pretty ASCII printing instead or no pretty printing at all.
+
+Example: isympy -p no
+
+\fIENCODING\fR must be one of 'unicode',
+\&'ascii' or 'no'.
+.TP
+\*(T<\fB\-t \fR\*(T>\fITYPE\fR, \*(T<\fB\-\-types=\fR\*(T>\fITYPE\fR
+Setup the ground types for the polys. By default, gmpy ground types
+are used if gmpy2 or gmpy is installed, otherwise it falls back to python
+ground types, which are a little bit slower. You can manually
+choose python ground types even if gmpy is installed (e.g., for testing purposes).
+
+Note that sympy ground types are not supported, and should be used
+only for experimental purposes.
+
+Note that the gmpy1 ground type is primarily intended for testing; it forces the
+use of gmpy even if gmpy2 is available.
+
+This is the same as setting the environment variable
+SYMPY_GROUND_TYPES to the given ground type (e.g.,
+SYMPY_GROUND_TYPES='gmpy')
+
+The ground types can be determined interactively from the variable
+sympy.polys.domains.GROUND_TYPES inside the isympy shell itself.
+
+Example: isympy -t python
+
+\fITYPE\fR must be one of 'gmpy',
+\&'gmpy1' or 'python'.
+.TP
+\*(T<\fB\-o \fR\*(T>\fIORDER\fR, \*(T<\fB\-\-order=\fR\*(T>\fIORDER\fR
+Setup the ordering of terms for printing. The default is lex, which
+orders terms lexicographically (e.g., x**2 + x + 1). You can choose
+other orderings, such as rev-lex, which will use reverse
+lexicographic ordering (e.g., 1 + x + x**2).
+
+Note that for very large expressions, ORDER='none' may speed up
+printing considerably, with the tradeoff that the order of the terms
+in the printed expression will have no canonical order.
+
+Example: isympy -o rev-lex
+
+\fIORDER\fR must be one of 'lex', 'rev-lex', 'grlex',
+\&'rev-grlex', 'grevlex', 'rev-grevlex', 'old', or 'none'.
+.TP
+\*(T<\fB\-q\fR\*(T>, \*(T<\fB\-\-quiet\fR\*(T>
+Print only Python's and SymPy's versions to stdout at startup, and nothing else.
+.TP +\*(T<\fB\-d\fR\*(T>, \*(T<\fB\-\-doctest\fR\*(T> +Use the same format that should be used for doctests. This is +equivalent to '\fIisympy -c python -p no\fR'. +.TP +\*(T<\fB\-C\fR\*(T>, \*(T<\fB\-\-no\-cache\fR\*(T> +Disable the caching mechanism. Disabling the cache may slow certain +operations down considerably. This is useful for testing the cache, +or for benchmarking, as the cache can result in deceptive benchmark timings. + +This is the same as setting the environment variable SYMPY_USE_CACHE +to 'no'. +.TP +\*(T<\fB\-a\fR\*(T>, \*(T<\fB\-\-auto\fR\*(T> +Automatically create missing symbols. Normally, typing a name of a +Symbol that has not been instantiated first would raise NameError, +but with this option enabled, any undefined name will be +automatically created as a Symbol. This only works in IPython 0.11. + +Note that this is intended only for interactive, calculator style +usage. In a script that uses SymPy, Symbols should be instantiated +at the top, so that it's clear what they are. + +This will not override any names that are already defined, which +includes the single character letters represented by the mnemonic +QCOSINE (see the "Gotchas and Pitfalls" document in the +documentation). You can delete existing names by executing "del +name" in the shell itself. You can see if a name is defined by typing +"'name' in globals()". + +The Symbols that are created using this have default assumptions. +If you want to place assumptions on symbols, you should create them +using symbols() or var(). + +Finally, this only works in the top level namespace. So, for +example, if you define a function in isympy with an undefined +Symbol, it will not work. +.TP +\*(T<\fB\-D\fR\*(T>, \*(T<\fB\-\-debug\fR\*(T> +Enable debugging output. This is the same as setting the +environment variable SYMPY_DEBUG to 'True'. The debug status is set +in the variable SYMPY_DEBUG within isympy. 
+.TP +-- \fIPYTHONOPTIONS\fR +These options will be passed on to \fIipython (1)\fR shell. +Only supported when ipython is being used (standard python shell not supported). + +Two dashes (--) are required to separate \fIPYTHONOPTIONS\fR +from the other isympy options. + +For example, to run iSymPy without startup banner and colors: + +isympy -q -c ipython -- --colors=NoColor +.TP +\*(T<\fB\-h\fR\*(T>, \*(T<\fB\-\-help\fR\*(T> +Print help output and exit. +.TP +\*(T<\fB\-v\fR\*(T>, \*(T<\fB\-\-version\fR\*(T> +Print isympy version information and exit. +.SH FILES +.TP +\*(T<\fI${HOME}/.sympy\-history\fR\*(T> +Saves the history of commands when using the python +shell as backend. +.SH BUGS +The upstreams BTS can be found at \(lahttps://github.com/sympy/sympy/issues\(ra +Please report all bugs that you find in there, this will help improve +the overall quality of SymPy. +.SH "SEE ALSO" +\fBipython\fR(1), \fBpython\fR(1) diff --git a/engiopt/transforms.py b/engiopt/transforms.py index edfd27d..205ca72 100644 --- a/engiopt/transforms.py +++ b/engiopt/transforms.py @@ -3,6 +3,7 @@ from collections.abc import Callable import math +from datasets import Dataset from engibench.core import Problem from gymnasium import spaces import torch as th @@ -32,32 +33,82 @@ def _nearest_power_of_two(x: int) -> int: return upper if abs(x - upper) < abs(x - lower) else lower -def upsample_nearest(data: th.Tensor, mode: str="bicubic") -> th.Tensor: - """Upsample 2D data to the nearest 2^n dimensions. Data should be a Tensor in the format (B, C, H, W).""" +def upsample_nearest(data: th.Tensor, mode: str = "bicubic") -> th.Tensor: + """Upsample 2D data to the nearest square 2^n based on the maximum dimension. + + Accepts input of shape (B, H, W) or (B, C, H, W). 
+ """ + low_dim = 3 + if data.ndim == low_dim: + data = data.unsqueeze(1) # (B, 1, H, W) _, _, h, w = data.shape - target_h = _nearest_power_of_two(h) - target_w = _nearest_power_of_two(w) - # If nearest power of two is smaller, multiply it by 2 - if target_h < h: - target_h *= 2 - if target_w < w: - target_w *= 2 - return f.interpolate(data, size=(target_h, target_w), mode=mode) - - -def downsample_nearest(data: th.Tensor, mode: str="bicubic") -> th.Tensor: - """Downsample 2D data to the nearest 2^n dimensions. Data should be a Tensor in the format (B, C, H, W).""" + + max_dim = max(h, w) + target = _nearest_power_of_two(max_dim) + if target < max_dim: + target *= 2 + + return f.interpolate(data, size=(target, target), mode=mode) + + +def downsample_nearest(data: th.Tensor, mode: str = "bicubic") -> th.Tensor: + """Downsample 2D data to the nearest square 2^n based on the maximum dimension. + + Accepts input of shape (B, H, W) or (B, C, H, W). + """ + low_dim = 3 + if data.ndim == low_dim: + data = data.unsqueeze(1) # (B, 1, H, W) _, _, h, w = data.shape - target_h = _nearest_power_of_two(h) - target_w = _nearest_power_of_two(w) - # If nearest power of two is larger, divide it by 2 - if target_h > h: - target_h //= 2 - if target_w > w: - target_w //= 2 - return f.interpolate(data, size=(target_h, target_w), mode=mode) + + max_dim = max(h, w) + target = _nearest_power_of_two(max_dim) + if target > max_dim: + target //= 2 + + return f.interpolate(data, size=(target, target), mode=mode) def resize_to(data: th.Tensor, h: int, w: int, mode: str = "bicubic") -> th.Tensor: """Resize 2D data back to any desired (h, w). 
Data should be a Tensor in the format (B, C, H, W).""" return f.interpolate(data, size=(h, w), mode=mode) + + +def normalize( + ds: Dataset, condition_names: list[str] +) -> tuple[Dataset, th.Tensor, th.Tensor]: + """Normalize specified condition columns with global mean/std (torch version, CPU).""" + # stack condition columns into a single tensor (N, C) on CPU + conds = th.stack([th.as_tensor(ds[c]).float() for c in condition_names], dim=1) + mean = conds.mean(dim=0) + std = conds.std(dim=0).clamp(min=1e-8) + + # normalize each condition column (HF expects numpy back) + ds = ds.map( + lambda batch: { + c: ((th.as_tensor(batch[c]).float() - mean[i]) / std[i]).numpy() + for i, c in enumerate(condition_names) + }, + batched=True, + ) + + return ds, mean, std + + +def drop_constant( + ds: Dataset, condition_names: list[str] +) -> tuple[Dataset, list[str]]: + """Drop constant condition columns (std=0) from dataset.""" + conds = th.stack([th.as_tensor(ds[c]).float() for c in condition_names], dim=1) + std = conds.std(dim=0) + + kept = [c for i, c in enumerate(condition_names) if std[i] > 0] + dropped = [c for i, c in enumerate(condition_names) if std[i] == 0] + + if dropped: + print(f"Warning: Dropping constant condition columns (std=0): {dropped}") + + # remove dropped columns from dataset + ds = ds.remove_columns(dropped) + + return ds, kept diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index b850c3a..5cd3e3e 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -1,4 +1,3 @@ -# ruff: noqa: F401 # REMOVE THIS LATER """Vector Quantized Generative Adversarial Network (VQGAN). 
Based on https://github.com/dome272/VQGAN-pyth with an "Online Clustered Codebook" for better codebook usage from https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py @@ -18,23 +17,20 @@ from __future__ import annotations -from collections import namedtuple from dataclasses import dataclass import inspect import math import os import random import time -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import warnings from einops import rearrange from engibench.utils.all_problems import BUILTIN_PROBLEMS -import matplotlib.pyplot as plt import numpy as np import requests import torch as th -from torch import autograd from torch import nn from torch.nn import functional as f from torchvision.models import vgg16 @@ -44,9 +40,8 @@ import tyro import wandb -from engiopt.metrics import dpp_diversity -from engiopt.metrics import mmd -from engiopt.transforms import resize_to +from engiopt.transforms import drop_constant +from engiopt.transforms import normalize from engiopt.transforms import upsample_nearest if TYPE_CHECKING: @@ -72,7 +67,7 @@ class Args: """The name of this algorithm.""" # Tracking - track: bool = True + track: bool = False """Track the experiment with wandb.""" wandb_project: str = "engiopt" """Wandb project name.""" @@ -86,13 +81,44 @@ class Args: # Algorithm-specific: General conditional: bool = True """whether the model is conditional or not""" + normalize_conditions: bool = True + """whether to normalize the condition columns to zero mean and unit std""" + drop_constant_conditions: bool = True + """whether to drop constant condition columns (i.e., overhang_constraint in beams2d)""" + image_size: int = 128 + """size of each image dimension (determined automatically later)""" + image_channels: int = 1 + """number of channels in the input image (determined automatically later)""" + latent_size: int = 16 + """size of each latent feature map dimension (determined automatically later)""" + + # Algorithm-specific: Stage 1 
Conditional AE or "CVQGAN" if the model is specified as conditional + # Note that a Discriminator is not used for CVQGAN, as it is generally a much simpler model. + cond_dim: int = 3 + """dimensionality of the condition space""" + cond_hidden_dim: int = 256 + """hidden dimension of the CVQGAN MLP""" + cond_latent_dim: int = 4 + """individual code dimension for CVQGAN""" + cond_codebook_vectors: int = 64 + """number of vectors in the CVQGAN codebook""" + cond_feature_map_dim: int = 4 + """feature map dimension for the CVQGAN encoder output""" + batch_size_0: int = 16 + """size of the batches for CVQGAN""" + n_epochs_0: int = 1000 # Default: 1000 + """number of epochs of CVQGAN training""" + cond_lr: float = 2e-4 + """learning rate for CVQGAN""" + cond_sample_interval: int = 1600 + """interval between CVQGAN image samples""" # Algorithm-specific: Stage 1 (AE) # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically - n_epochs_1: int = 100 + n_epochs_1: int = 100 # Default: 100 """number of epochs of training""" batch_size_1: int = 16 - """size of the batches""" + """size of the batches for Stage 1""" lr_1: float = 2e-4 """learning rate for Stage 1""" beta: float = 0.25 @@ -115,15 +141,13 @@ class Args: """weighting factor for the reconstruction loss""" perceptual_loss_factor: float = 1.0 """weighting factor for the perceptual loss""" - encoder_channels: tuple[int, ...] = (128, 128, 128, 256, 256, 512) + encoder_channels: tuple[int, ...] = (64, 64, 128, 128, 256) """tuple of channel sizes for each encoder layer""" encoder_attn_resolutions: tuple[int, ...] = (16,) """tuple of resolutions at which to apply attention in the encoder""" encoder_num_res_blocks: int = 2 """number of residual blocks per encoder layer""" - encoder_start_resolution: int = 256 - """starting resolution for the encoder""" - decoder_channels: tuple[int, ...] 
= (512, 256, 256, 128, 128) + decoder_channels: tuple[int, ...] = (256, 128, 128, 64) """tuple of channel sizes for each decoder layer""" decoder_attn_resolutions: tuple[int, ...] = (16,) """tuple of resolutions at which to apply attention in the decoder""" @@ -132,31 +156,12 @@ class Args: sample_interval: int = 1600 """interval between image samples""" - # Algorithm-specific: Stage 1 (Conditional AE or "CVQGAN" if the model is conditional) - # Note that a Discriminator is not used for CVQGAN, as it is generally a much simpler model. - cond_dim: int = 3 - """dimensionality of the condition space""" - cond_hidden_dim: int = 256 - """hidden dimension of the CVQGAN MLP""" - cond_latent_dim: int = 4 - """individual code dimension for CVQGAN""" - cond_codebook_vectors: int = 64 - """number of vectors in the CVQGAN codebook""" - cond_feature_map_dim: int = 4 - """feature map dimension for the CVQGAN encoder output""" - cond_epochs: int = 100 - """number of epochs of CVQGAN training""" - cond_lr: float = 2e-4 - """learning rate for CVQGAN""" - cond_sample_interval: int = 1600 - """interval between CVQGAN image samples""" - # Algorithm-specific: Stage 2 (Transformer) # From original implementation: assume pkeep=1.0, sos_token=0, bias=True - n_epochs_2: int = 100 + n_epochs_2: int = 100 # Default: 100 """number of epochs of training""" batch_size_2: int = 16 - """size of the batches""" + """size of the batches for Stage 2""" lr_2: float = 6e-4 """learning rate for Stage 2""" n_layer: int = 12 @@ -931,18 +936,17 @@ def __init__( # noqa: PLR0913 cond_codebook_vectors: int = 64, # VQGAN + Codebook parameters - encoder_channels: tuple[int, ...], - encoder_start_resolution: int, - encoder_attn_resolutions: tuple[int, ...], - encoder_num_res_blocks: int, - decoder_channels: tuple[int, ...], - decoder_start_resolution: int, - decoder_attn_resolutions: tuple[int, ...], - decoder_num_res_blocks: int, + encoder_channels: tuple[int, ...] 
= (64, 64, 128, 128, 256), + encoder_start_resolution: int = 128, + encoder_attn_resolutions: tuple[int, ...] = (16,), + encoder_num_res_blocks: int = 2, + decoder_channels: tuple[int, ...] = (256, 128, 128, 64), + decoder_start_resolution: int = 16, + decoder_attn_resolutions: tuple[int, ...] = (16,), + decoder_num_res_blocks: int = 3, image_channels: int = 1, latent_dim: int = 16, - num_codebook_vectors: int = 256, - + num_codebook_vectors: int = 256 ): super().__init__() if is_c: @@ -989,18 +993,18 @@ def __init__( # noqa: PLR0913 latent_dim = cond_latent_dim if is_c else latent_dim ).to(device=device) - def forward(self, imgs: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor]: + def forward(self, designs: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor]: """Full VQGAN forward pass.""" - encoded = self.encoder(imgs) + encoded = self.encoder(designs) quant_encoded = self.quant_conv(encoded) - quant, indices, q_loss = self.codebook(quant_encoded) + quant, indices, q_loss, _, _ = self.codebook(quant_encoded) post_quant = self.post_quant_conv(quant) decoded = self.decoder(post_quant) return decoded, indices, q_loss - def encode(self, imgs: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor]: + def encode(self, designs: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: """Encode image batch into quantized latent representation.""" - encoded = self.encoder(imgs) + encoded = self.encoder(designs) quant_encoded = self.quant_conv(encoded) return self.codebook(quant_encoded) @@ -1338,7 +1342,7 @@ class VQGANTransformer(nn.Module): cvqgan (VQGAN): Pretrained CVQGAN model for conditional encoding (if conditional=True). image_size (int): Input image size (assumed square). decoder_channels (tuple[int, ...]): Decoder channels from the VQGAN model. - cond_fmap_dim (int): Feature map dimension from the CVQGAN encoder (if conditional=True). + cond_feature_map_dim (int): Feature map dimension from the CVQGAN encoder (if conditional=True). 
num_codebook_vectors (int): Number of codebook vectors from the VQGAN model. n_layer (int): Number of Transformer layers. n_head (int): Number of attention heads in the Transformer. @@ -1353,7 +1357,7 @@ def __init__( # noqa: PLR0913 cvqgan: VQGAN, image_size: int, decoder_channels: tuple[int, ...], - cond_fmap_dim: int, + cond_feature_map_dim: int, num_codebook_vectors: int, n_layer: int, n_head: int, @@ -1363,19 +1367,13 @@ def __init__( # noqa: PLR0913 ): super().__init__() self.sos_token = 0 - self.vqgan = vqgan.eval() - for param in self.vqgan.parameters(): - param.requires_grad = False - - if conditional: - self.cvqgan = cvqgan - for param in self.cvqgan.parameters(): - param.requires_grad = False + self.vqgan = vqgan + self.cvqgan = cvqgan # block_size is automatically set to the combined sequence length of the VQGAN and CVQGAN block_size = (image_size // (2 ** (len(decoder_channels) - 1))) ** 2 if conditional: - block_size += cond_fmap_dim ** 2 + block_size += cond_feature_map_dim ** 2 # Create config object for NanoGPT transformer_config = GPTConfig( @@ -1395,9 +1393,9 @@ def __init__( # noqa: PLR0913 def encode_to_z(self, *, x: th.Tensor, is_c: bool = False) -> tuple[th.Tensor, th.Tensor]: """Encode images to quantized latent vectors (z) and their indices.""" if is_c: # For the conditional tokens, use the CVQGAN encoder - quant_z, indices, _ = self.cvqgan.encode(x) + quant_z, indices, _, _, _ = self.cvqgan.encode(x) else: - quant_z, indices, _ = self.vqgan.encode(x) + quant_z, indices, _, _, _ = self.vqgan.encode(x) indices = indices.view(quant_z.shape[0], -1) return quant_z, indices @@ -1468,7 +1466,7 @@ def sample(self, x: th.Tensor, c: th.Tensor, steps: int, temperature: float = 1. 
logits = self.top_k_logits(logits, effective_top_k) else: # Fallback if all logits are -inf (shouldn't happen, but just in case) - print("Warning: No finite logits found for sampling") + warnings.warn("Warning: No finite logits found for sampling", stacklevel=2) # Make all logits equal (uniform distribution) logits = th.zeros_like(logits) @@ -1510,3 +1508,249 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu log["full_sample"] = full_sample return log, th.concat((x, x_rec, half_sample, full_sample)) + + +if __name__ == "__main__": + args = tyro.cli(Args) + + # Seeding + th.manual_seed(args.seed) + rng = np.random.default_rng(args.seed) + random.seed(args.seed) + th.backends.cudnn.deterministic = True + + os.makedirs("images", exist_ok=True) + + if th.backends.mps.is_available(): + device = th.device("mps") + elif th.cuda.is_available(): + device = th.device("cuda") + else: + device = th.device("cpu") + + problem = BUILTIN_PROBLEMS[args.problem_id]() + problem.reset(seed=args.seed) + + # Configure data loader (keep on CPU for preprocessing) + training_ds = problem.dataset.with_format("torch")["train"] + + # Add in the upsampled optimal design column and remove the original optimal design column + training_ds = training_ds.map( + lambda batch: { + "optimal_upsampled": upsample_nearest(batch["optimal_design"]).cpu().numpy() + }, + batched=True, + ) + training_ds = training_ds.remove_columns("optimal_design") + + args.image_size = training_ds["optimal_upsampled"].shape[-1] + args.latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) + conditions = problem.conditions_keys + + # Optionally normalize condition columns + if args.normalize_conditions: + training_ds, mean, std = normalize(training_ds, conditions) + + # Optionally drop condition columns that are constant like overhang_constraint in beams2d + if args.drop_constant_conditions: + training_ds, conditions = drop_constant(training_ds, conditions) + + 
args.cond_dim = len(conditions) + + # Move to device only here + th_training_ds = th.utils.data.TensorDataset( + th.as_tensor(training_ds["optimal_upsampled"]).to(device), + *[th.as_tensor(training_ds[key]).to(device) for key in conditions], + ) + dataloader_0 = th.utils.data.DataLoader( + th_training_ds, + batch_size=args.batch_size_0, + shuffle=True, + ) + dataloader_1 = th.utils.data.DataLoader( + th_training_ds, + batch_size=args.batch_size_1, + shuffle=True, + ) + dataloader_2 = th.utils.data.DataLoader( + th_training_ds, + batch_size=args.batch_size_2, + shuffle=True, + ) + # Logging + run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" + if args.track: + wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name) + + vqgan = VQGAN( + device=device, + is_c=False, + encoder_channels=args.encoder_channels, + encoder_start_resolution=args.image_size, + encoder_attn_resolutions=args.encoder_attn_resolutions, + encoder_num_res_blocks=args.encoder_num_res_blocks, + decoder_channels=args.decoder_channels, + decoder_start_resolution=args.latent_size, + decoder_attn_resolutions=args.decoder_attn_resolutions, + decoder_num_res_blocks=args.decoder_num_res_blocks, + image_channels=args.image_channels, + latent_dim=args.latent_dim, + num_codebook_vectors=args.num_codebook_vectors + ).to(device=device) + + discriminator = Discriminator(image_channels=args.image_channels).to(device=device) + + cvqgan = VQGAN( + device=device, + is_c=True, + cond_feature_map_dim=args.cond_feature_map_dim, + cond_dim=args.cond_dim, + cond_hidden_dim=args.cond_hidden_dim, + cond_latent_dim=args.cond_latent_dim, + cond_codebook_vectors=args.cond_codebook_vectors + ).to(device=device) + + transformer = VQGANTransformer( + conditional=args.conditional, + vqgan=vqgan, + cvqgan=cvqgan, + image_size=args.image_size, + decoder_channels=args.decoder_channels, + cond_feature_map_dim=args.cond_feature_map_dim, + 
num_codebook_vectors=args.num_codebook_vectors, + n_layer=args.n_layer, + n_head=args.n_head, + n_embd=args.n_embd, + dropout=args.dropout + ).to(device=device) + + # CVQGAN Stage 0 optimizer + opt_cvq = th.optim.Adam( + list(cvqgan.encoder.parameters()) + + list(cvqgan.decoder.parameters()) + + list(cvqgan.codebook.parameters()) + + list(cvqgan.quant_conv.parameters()) + + list(cvqgan.post_quant_conv.parameters()), + lr=args.cond_lr, eps=1e-08, betas=(args.b1, args.b2) + ) + + # VQGAN Stage 1 optimizer + opt_vq = th.optim.Adam( + list(vqgan.encoder.parameters()) + + list(vqgan.decoder.parameters()) + + list(vqgan.codebook.parameters()) + + list(vqgan.quant_conv.parameters()) + + list(vqgan.post_quant_conv.parameters()), + lr=args.lr_1, eps=1e-08, betas=(args.b1, args.b2) + ) + # VQGAN Stage 1 discriminator optimizer + opt_disc = th.optim.Adam(discriminator.parameters(), + lr=args.lr_1, eps=1e-08, betas=(args.b1, args.b2)) + + # Transformer Stage 2 optimizer + decay, no_decay = set(), set() + whitelist_weight_modules = (nn.Linear, ) + blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) + + for mn, m in transformer.transformer.named_modules(): + for pn, _ in m.named_parameters(): + fpn = f"{mn}.{pn}" if mn else pn + + if pn.endswith("bias"): + no_decay.add(fpn) + + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + decay.add(fpn) + + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + no_decay.add(fpn) + + no_decay.add("pos_emb") + + param_dict = dict(transformer.transformer.named_parameters()) + decay = {pn for pn in decay if pn in param_dict} + no_decay = {pn for pn in no_decay if pn in param_dict} + + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.01}, + {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0}, + ] + + opt_transformer = th.optim.AdamW(optim_groups, lr=args.lr_2, betas=(0.9, 0.95)) + + perceptual_loss_fcn = 
GreyscaleLPIPS().eval().to(device) + + if args.conditional: + print("Stage 0: Training CVQGAN") + cvqgan.train() + for _ in tqdm.trange(args.n_epochs_0): + for _, data in enumerate(dataloader_0): + # THIS IS PROBLEM DEPENDENT + conds = th.stack((data[1:]), dim=1).to(dtype=th.float32, device=device) + decoded_images, codebook_indices, q_loss = cvqgan(conds) + + opt_cvq.zero_grad() + rec_loss = th.abs(conds - decoded_images).mean() + cvq_loss = rec_loss + q_loss + cvq_loss.backward() + opt_cvq.step() + + # Freeze CVQGAN for later use in Stage 2 Transformer + for p in cvqgan.parameters(): + p.requires_grad_(requires_grad=False) + cvqgan.eval() + + print("Stage 1: Training VQGAN") + vqgan.train() + discriminator.train() + for epoch in tqdm.trange(args.n_epochs_1): + for _, data in enumerate(dataloader_1): + # THIS IS PROBLEM DEPENDENT + designs = data[0].to(dtype=th.float32, device=device) + decoded_images, codebook_indices, q_loss = vqgan(designs) + + disc_real = discriminator(designs) + disc_fake = discriminator(decoded_images) + + disc_factor = vqgan.adopt_weight(args.disc_factor, epoch, threshold=args.disc_start) + + perceptual_loss = perceptual_loss_fcn(designs, decoded_images) + rec_loss = th.abs(designs - decoded_images) + perceptual_rec_loss = args.perceptual_loss_factor * perceptual_loss + args.rec_loss_factor * rec_loss + perceptual_rec_loss = perceptual_rec_loss.mean() + g_loss = -th.mean(disc_fake) + + lamb = vqgan.calculate_lambda(perceptual_rec_loss, g_loss) + vq_loss = perceptual_rec_loss + q_loss + disc_factor * lamb * g_loss + + d_loss_real = th.mean(f.relu(1. - disc_real)) + d_loss_fake = th.mean(f.relu(1. 
+ disc_fake)) + gan_loss = disc_factor * 0.5*(d_loss_real + d_loss_fake) + + opt_vq.zero_grad() + vq_loss.backward(retain_graph=True) + + opt_disc.zero_grad() + gan_loss.backward() + + opt_vq.step() + opt_disc.step() + + # Freeze VQGAN for later use in Stage 2 Transformer + for p in vqgan.parameters(): + p.requires_grad_(requires_grad=False) + vqgan.eval() + + print("Stage 2: Training Transformer") + transformer.train() + for _ in tqdm.trange(args.n_epochs_2): + for _, data in enumerate(dataloader_2): + # THIS IS PROBLEM DEPENDENT + designs = data[0].to(dtype=th.float32, device=device) + conds = th.stack((data[1:]), dim=1).to(dtype=th.float32, device=device) + + opt_transformer.zero_grad() + logits, targets = transformer(designs, conds) + loss = f.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) + loss.backward() + opt_transformer.step() diff --git a/pyproject.toml b/pyproject.toml index d5b8118..c6061a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -296,5 +296,6 @@ module = [ "torchvision.*", "requests", "transformers", + "wandb", ] ignore_missing_imports = true From c0e4ae0a3d409fac268149cd7a8c596058d72034 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Sun, 28 Sep 2025 18:35:35 +0200 Subject: [PATCH 10/22] add training tracking --- .gitignore | 1 + engiopt/share/man/man1/isympy.1 | 188 -------------------------------- engiopt/transforms.py | 6 +- engiopt/vqgan/vqgan.py | 155 +++++++++++++++++++++++--- pyproject.toml | 3 +- 5 files changed, 147 insertions(+), 206 deletions(-) delete mode 100644 engiopt/share/man/man1/isympy.1 diff --git a/.gitignore b/.gitignore index 93b33db..052e04a 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,7 @@ cython_debug/ wandb/* images/* +logs/* # Editors .idea/ .vscode/ diff --git a/engiopt/share/man/man1/isympy.1 b/engiopt/share/man/man1/isympy.1 deleted file mode 100644 index 0ff9661..0000000 --- a/engiopt/share/man/man1/isympy.1 +++ /dev/null @@ -1,188 +0,0 @@ -'\" -*- coding: us-ascii -*- 
-.if \n(.g .ds T< \\FC -.if \n(.g .ds T> \\F[\n[.fam]] -.de URL -\\$2 \(la\\$1\(ra\\$3 -.. -.if \n(.g .mso www.tmac -.TH isympy 1 2007-10-8 "" "" -.SH NAME -isympy \- interactive shell for SymPy -.SH SYNOPSIS -'nh -.fi -.ad l -\fBisympy\fR \kx -.if (\nx>(\n(.l/2)) .nr x (\n(.l/5) -'in \n(.iu+\nxu -[\fB-c\fR | \fB--console\fR] [\fB-p\fR ENCODING | \fB--pretty\fR ENCODING] [\fB-t\fR TYPE | \fB--types\fR TYPE] [\fB-o\fR ORDER | \fB--order\fR ORDER] [\fB-q\fR | \fB--quiet\fR] [\fB-d\fR | \fB--doctest\fR] [\fB-C\fR | \fB--no-cache\fR] [\fB-a\fR | \fB--auto\fR] [\fB-D\fR | \fB--debug\fR] [ --- | PYTHONOPTIONS] -'in \n(.iu-\nxu -.ad b -'hy -'nh -.fi -.ad l -\fBisympy\fR \kx -.if (\nx>(\n(.l/2)) .nr x (\n(.l/5) -'in \n(.iu+\nxu -[ -{\fB-h\fR | \fB--help\fR} -| -{\fB-v\fR | \fB--version\fR} -] -'in \n(.iu-\nxu -.ad b -'hy -.SH DESCRIPTION -isympy is a Python shell for SymPy. It is just a normal python shell -(ipython shell if you have the ipython package installed) that executes -the following commands so that you don't have to: -.PP -.nf -\*(T< ->>> from __future__ import division ->>> from sympy import * ->>> x, y, z = symbols("x,y,z") ->>> k, m, n = symbols("k,m,n", integer=True) - \*(T> -.fi -.PP -So starting isympy is equivalent to starting python (or ipython) and -executing the above commands by hand. It is intended for easy and quick -experimentation with SymPy. For more complicated programs, it is recommended -to write a script and import things explicitly (using the "from sympy -import sin, log, Symbol, ..." idiom). -.SH OPTIONS -.TP -\*(T<\fB\-c \fR\*(T>\fISHELL\fR, \*(T<\fB\-\-console=\fR\*(T>\fISHELL\fR -Use the specified shell (python or ipython) as -console backend instead of the default one (ipython -if present or python otherwise). - -Example: isympy -c python - -\fISHELL\fR could be either -\&'ipython' or 'python' -.TP -\*(T<\fB\-p \fR\*(T>\fIENCODING\fR, \*(T<\fB\-\-pretty=\fR\*(T>\fIENCODING\fR -Setup pretty printing in SymPy. 
By default, the most pretty, unicode -printing is enabled (if the terminal supports it). You can use less -pretty ASCII printing instead or no pretty printing at all. - -Example: isympy -p no - -\fIENCODING\fR must be one of 'unicode', -\&'ascii' or 'no'. -.TP -\*(T<\fB\-t \fR\*(T>\fITYPE\fR, \*(T<\fB\-\-types=\fR\*(T>\fITYPE\fR -Setup the ground types for the polys. By default, gmpy ground types -are used if gmpy2 or gmpy is installed, otherwise it falls back to python -ground types, which are a little bit slower. You can manually -choose python ground types even if gmpy is installed (e.g., for testing purposes). - -Note that sympy ground types are not supported, and should be used -only for experimental purposes. - -Note that the gmpy1 ground type is primarily intended for testing; it the -use of gmpy even if gmpy2 is available. - -This is the same as setting the environment variable -SYMPY_GROUND_TYPES to the given ground type (e.g., -SYMPY_GROUND_TYPES='gmpy') - -The ground types can be determined interactively from the variable -sympy.polys.domains.GROUND_TYPES inside the isympy shell itself. - -Example: isympy -t python - -\fITYPE\fR must be one of 'gmpy', -\&'gmpy1' or 'python'. -.TP -\*(T<\fB\-o \fR\*(T>\fIORDER\fR, \*(T<\fB\-\-order=\fR\*(T>\fIORDER\fR -Setup the ordering of terms for printing. The default is lex, which -orders terms lexicographically (e.g., x**2 + x + 1). You can choose -other orderings, such as rev-lex, which will use reverse -lexicographic ordering (e.g., 1 + x + x**2). - -Note that for very large expressions, ORDER='none' may speed up -printing considerably, with the tradeoff that the order of the terms -in the printed expression will have no canonical order - -Example: isympy -o rev-lax - -\fIORDER\fR must be one of 'lex', 'rev-lex', 'grlex', -\&'rev-grlex', 'grevlex', 'rev-grevlex', 'old', or 'none'. -.TP -\*(T<\fB\-q\fR\*(T>, \*(T<\fB\-\-quiet\fR\*(T> -Print only Python's and SymPy's versions to stdout at startup, and nothing else. 
-.TP -\*(T<\fB\-d\fR\*(T>, \*(T<\fB\-\-doctest\fR\*(T> -Use the same format that should be used for doctests. This is -equivalent to '\fIisympy -c python -p no\fR'. -.TP -\*(T<\fB\-C\fR\*(T>, \*(T<\fB\-\-no\-cache\fR\*(T> -Disable the caching mechanism. Disabling the cache may slow certain -operations down considerably. This is useful for testing the cache, -or for benchmarking, as the cache can result in deceptive benchmark timings. - -This is the same as setting the environment variable SYMPY_USE_CACHE -to 'no'. -.TP -\*(T<\fB\-a\fR\*(T>, \*(T<\fB\-\-auto\fR\*(T> -Automatically create missing symbols. Normally, typing a name of a -Symbol that has not been instantiated first would raise NameError, -but with this option enabled, any undefined name will be -automatically created as a Symbol. This only works in IPython 0.11. - -Note that this is intended only for interactive, calculator style -usage. In a script that uses SymPy, Symbols should be instantiated -at the top, so that it's clear what they are. - -This will not override any names that are already defined, which -includes the single character letters represented by the mnemonic -QCOSINE (see the "Gotchas and Pitfalls" document in the -documentation). You can delete existing names by executing "del -name" in the shell itself. You can see if a name is defined by typing -"'name' in globals()". - -The Symbols that are created using this have default assumptions. -If you want to place assumptions on symbols, you should create them -using symbols() or var(). - -Finally, this only works in the top level namespace. So, for -example, if you define a function in isympy with an undefined -Symbol, it will not work. -.TP -\*(T<\fB\-D\fR\*(T>, \*(T<\fB\-\-debug\fR\*(T> -Enable debugging output. This is the same as setting the -environment variable SYMPY_DEBUG to 'True'. The debug status is set -in the variable SYMPY_DEBUG within isympy. 
-.TP --- \fIPYTHONOPTIONS\fR -These options will be passed on to \fIipython (1)\fR shell. -Only supported when ipython is being used (standard python shell not supported). - -Two dashes (--) are required to separate \fIPYTHONOPTIONS\fR -from the other isympy options. - -For example, to run iSymPy without startup banner and colors: - -isympy -q -c ipython -- --colors=NoColor -.TP -\*(T<\fB\-h\fR\*(T>, \*(T<\fB\-\-help\fR\*(T> -Print help output and exit. -.TP -\*(T<\fB\-v\fR\*(T>, \*(T<\fB\-\-version\fR\*(T> -Print isympy version information and exit. -.SH FILES -.TP -\*(T<\fI${HOME}/.sympy\-history\fR\*(T> -Saves the history of commands when using the python -shell as backend. -.SH BUGS -The upstreams BTS can be found at \(lahttps://github.com/sympy/sympy/issues\(ra -Please report all bugs that you find in there, this will help improve -the overall quality of SymPy. -.SH "SEE ALSO" -\fBipython\fR(1), \fBpython\fR(1) diff --git a/engiopt/transforms.py b/engiopt/transforms.py index 205ca72..5f35e9a 100644 --- a/engiopt/transforms.py +++ b/engiopt/transforms.py @@ -79,14 +79,14 @@ def normalize( ) -> tuple[Dataset, th.Tensor, th.Tensor]: """Normalize specified condition columns with global mean/std (torch version, CPU).""" # stack condition columns into a single tensor (N, C) on CPU - conds = th.stack([th.as_tensor(ds[c]).float() for c in condition_names], dim=1) + conds = th.stack([th.as_tensor(ds[c][:]).float() for c in condition_names], dim=1) mean = conds.mean(dim=0) std = conds.std(dim=0).clamp(min=1e-8) # normalize each condition column (HF expects numpy back) ds = ds.map( lambda batch: { - c: ((th.as_tensor(batch[c]).float() - mean[i]) / std[i]).numpy() + c: ((th.as_tensor(batch[c][:]).float() - mean[i]) / std[i]).numpy() for i, c in enumerate(condition_names) }, batched=True, @@ -99,7 +99,7 @@ def drop_constant( ds: Dataset, condition_names: list[str] ) -> tuple[Dataset, list[str]]: """Drop constant condition columns (std=0) from dataset.""" - conds = 
th.stack([th.as_tensor(ds[c]).float() for c in condition_names], dim=1) + conds = th.stack([th.as_tensor(ds[c][:]).float() for c in condition_names], dim=1) std = conds.std(dim=0) kept = [c for i, c in enumerate(condition_names) if std[i] > 0] diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 5cd3e3e..c00aa47 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -67,7 +67,7 @@ class Args: """The name of this algorithm.""" # Tracking - track: bool = False + track: bool = True """Track the experiment with wandb.""" wandb_project: str = "engiopt" """Wandb project name.""" @@ -75,7 +75,7 @@ class Args: """Wandb entity name.""" seed: int = 1 """Random seed.""" - save_model: bool = False + save_model: bool = True """Saves the model to disk.""" # Algorithm-specific: General @@ -1530,6 +1530,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu problem = BUILTIN_PROBLEMS[args.problem_id]() problem.reset(seed=args.seed) + design_shape = problem.design_space.shape # Configure data loader (keep on CPU for preprocessing) training_ds = problem.dataset.with_format("torch")["train"] @@ -1537,13 +1538,14 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Add in the upsampled optimal design column and remove the original optimal design column training_ds = training_ds.map( lambda batch: { - "optimal_upsampled": upsample_nearest(batch["optimal_design"]).cpu().numpy() + "optimal_upsampled": upsample_nearest(batch["optimal_design"][:]).cpu().numpy() }, batched=True, ) training_ds = training_ds.remove_columns("optimal_design") + design_shape = training_ds["optimal_upsampled"][:].shape[-2:] - args.image_size = training_ds["optimal_upsampled"].shape[-1] + args.image_size = training_ds["optimal_upsampled"][:].shape[-1] args.latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) conditions = problem.conditions_keys @@ -1555,12 +1557,13 @@ def log_images(self, x: th.Tensor, c: 
th.Tensor, top_k: int | None = None) -> tu if args.drop_constant_conditions: training_ds, conditions = drop_constant(training_ds, conditions) - args.cond_dim = len(conditions) + n_conds = len(conditions) + args.cond_dim = n_conds # Move to device only here th_training_ds = th.utils.data.TensorDataset( - th.as_tensor(training_ds["optimal_upsampled"]).to(device), - *[th.as_tensor(training_ds[key]).to(device) for key in conditions], + th.as_tensor(training_ds["optimal_upsampled"][:]).to(device), + *[th.as_tensor(training_ds[key][:]).to(device) for key in conditions], ) dataloader_0 = th.utils.data.DataLoader( th_training_ds, @@ -1580,7 +1583,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Logging run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" if args.track: - wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name) + wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name, dir="./logs/wandb") vqgan = VQGAN( device=device, @@ -1680,11 +1683,14 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu perceptual_loss_fcn = GreyscaleLPIPS().eval().to(device) + # --------------------------- + # Stage 0: Training CVQGAN + # --------------------------- if args.conditional: print("Stage 0: Training CVQGAN") cvqgan.train() - for _ in tqdm.trange(args.n_epochs_0): - for _, data in enumerate(dataloader_0): + for epoch in tqdm.trange(args.n_epochs_0): + for i, data in enumerate(dataloader_0): # THIS IS PROBLEM DEPENDENT conds = th.stack((data[1:]), dim=1).to(dtype=th.float32, device=device) decoded_images, codebook_indices, q_loss = cvqgan(conds) @@ -1695,16 +1701,53 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu cvq_loss.backward() opt_cvq.step() + # ---------- + # Logging + # ---------- + if args.track: + batches_done = epoch * 
len(dataloader_0) + i + wandb.log( + { + "cvq_loss": cvq_loss.item(), + "epoch": epoch, + "batch": batches_done, + } + ) + print( + f"[Epoch {epoch}/{args.n_epochs_0}] [Batch {i}/{len(dataloader_0)}] [CVQ loss: {cvq_loss.item()}]" + ) + + # -------------- + # Save model + # -------------- + if args.save_model and epoch == args.n_epochs_0 - 1 and i == len(dataloader_0) - 1: + ckpt_cvq = { + "epoch": epoch, + "batches_done": batches_done, + "cvqgan": cvqgan.state_dict(), + "optimizer_cvqgan": opt_cvq.state_dict(), + "loss": cvq_loss.item(), + } + + th.save(ckpt_cvq, "cvqgan.pth") + if args.track: + artifact_cvq = wandb.Artifact(f"{args.problem_id}_{args.algo}_cvqgan", type="model") + artifact_cvq.add_file("cvqgan.pth") + wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"]) + # Freeze CVQGAN for later use in Stage 2 Transformer for p in cvqgan.parameters(): p.requires_grad_(requires_grad=False) cvqgan.eval() + # -------------------------- + # Stage 1: Training VQGAN + # -------------------------- print("Stage 1: Training VQGAN") vqgan.train() discriminator.train() for epoch in tqdm.trange(args.n_epochs_1): - for _, data in enumerate(dataloader_1): + for i, data in enumerate(dataloader_1): # THIS IS PROBLEM DEPENDENT designs = data[0].to(dtype=th.float32, device=device) decoded_images, codebook_indices, q_loss = vqgan(designs) @@ -1736,15 +1779,65 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu opt_vq.step() opt_disc.step() + # ---------- + # Logging + # ---------- + if args.track: + batches_done = epoch * len(dataloader_1) + i + wandb.log( + { + "vq_loss": vq_loss.item(), + "d_loss": gan_loss.item(), + "epoch": epoch, + "batch": batches_done, + } + ) + print( + f"[Epoch {epoch}/{args.n_epochs_1}] [Batch {i}/{len(dataloader_1)}] [D loss: {gan_loss.item()}] [VQ loss: {vq_loss.item()}]" + ) + + # -------------- + # Save models + # -------------- + if args.save_model and epoch == args.n_epochs_1 - 1 and i == len(dataloader_1) 
- 1: + ckpt_vq = { + "epoch": epoch, + "batches_done": batches_done, + "vqgan": vqgan.state_dict(), + "optimizer_vqgan": opt_vq.state_dict(), + "loss": vq_loss.item(), + } + ckpt_disc = { + "epoch": epoch, + "batches_done": batches_done, + "discriminator": discriminator.state_dict(), + "optimizer_discriminator": opt_disc.state_dict(), + "loss": gan_loss.item(), + } + + th.save(ckpt_vq, "vqgan.pth") + th.save(ckpt_disc, "discriminator.pth") + if args.track: + artifact_vq = wandb.Artifact(f"{args.problem_id}_{args.algo}_vqgan", type="model") + artifact_vq.add_file("vqgan.pth") + artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model") + artifact_disc.add_file("discriminator.pth") + + wandb.log_artifact(artifact_vq, aliases=[f"seed_{args.seed}"]) + wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"]) + # Freeze VQGAN for later use in Stage 2 Transformer for p in vqgan.parameters(): p.requires_grad_(requires_grad=False) vqgan.eval() + # -------------------------------- + # Stage 2: Training Transformer + # -------------------------------- print("Stage 2: Training Transformer") transformer.train() - for _ in tqdm.trange(args.n_epochs_2): - for _, data in enumerate(dataloader_2): + for epoch in tqdm.trange(args.n_epochs_2): + for i, data in enumerate(dataloader_2): # THIS IS PROBLEM DEPENDENT designs = data[0].to(dtype=th.float32, device=device) conds = th.stack((data[1:]), dim=1).to(dtype=th.float32, device=device) @@ -1754,3 +1847,39 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu loss = f.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) loss.backward() opt_transformer.step() + + # ---------- + # Logging + # ---------- + if args.track: + batches_done = epoch * len(dataloader_2) + i + wandb.log( + { + "tr_loss": loss.item(), + "epoch": epoch, + "batch": batches_done, + } + ) + print( + f"[Epoch {epoch}/{args.n_epochs_2}] [Batch {i}/{len(dataloader_2)}] [Transformer 
loss: {loss.item()}]" + ) + + # -------------- + # Save model + # -------------- + if args.save_model and epoch == args.n_epochs_2 - 1 and i == len(dataloader_2) - 1: + ckpt_transformer = { + "epoch": epoch, + "batches_done": batches_done, + "transformer": transformer.state_dict(), + "optimizer_transformer": opt_transformer.state_dict(), + "loss": loss.item(), + } + + th.save(ckpt_transformer, "transformer.pth") + if args.track: + artifact_cvq = wandb.Artifact(f"{args.problem_id}_{args.algo}_transformer", type="model") + artifact_cvq.add_file("transformer.pth") + wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"]) + + wandb.finish() diff --git a/pyproject.toml b/pyproject.toml index 641efce..b22683e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,7 +255,7 @@ exclude = [ ".*/templates/.*", "^engibench_studies/problems/airfoil/study_[^/]*/", "^docs/", - "^wandb/", + "wandb/.*", ] [[tool.mypy.overrides]] module = [ @@ -296,6 +296,5 @@ module = [ "torchvision.*", "requests", "transformers", - "wandb", ] ignore_missing_imports = true From 93cc9cdcbff64abdcf6a77c1d598f715a5c71c78 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Mon, 29 Sep 2025 18:20:13 +0200 Subject: [PATCH 11/22] add image logging to training --- engiopt/vqgan/evaluate_vqgan.py | 65 +++++++++++ engiopt/vqgan/vqgan.py | 184 ++++++++++++++++++++++---------- 2 files changed, 195 insertions(+), 54 deletions(-) diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index 5ca6dc0..f5bb97d 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -1 +1,66 @@ +# ruff: noqa: F401 # REMOVE LATER """Evaluation for the VQGAN.""" + +from __future__ import annotations + +import dataclasses +import os + +from engibench.utils.all_problems import BUILTIN_PROBLEMS +import numpy as np +import pandas as pd +import torch as th +import tyro +import wandb + +from engiopt import metrics +from engiopt.dataset_sample_conditions import sample_conditions 
+from engiopt.vqgan.vqgan import VQGAN +from engiopt.vqgan.vqgan import VQGANTransformer + + +@dataclasses.dataclass +class Args: + """Command-line arguments for a single-seed VQGAN 2D evaluation.""" + + problem_id: str = "beams2d" + """Problem identifier.""" + seed: int = 1 + """Random seed to run.""" + wandb_project: str = "engiopt" + """Wandb project name.""" + wandb_entity: str | None = None + """Wandb entity name.""" + n_samples: int = 50 + """Number of generated samples per seed.""" + sigma: float = 10.0 + """Kernel bandwidth for MMD and DPP metrics.""" + output_csv: str = "vqgan_{problem_id}_metrics.csv" + """Output CSV path template; may include {problem_id}.""" + + +if __name__ == "__main__": + args = tyro.cli(Args) + + seed = args.seed + problem = BUILTIN_PROBLEMS[args.problem_id]() + problem.reset(seed=seed) + + # Reproducibility + th.manual_seed(seed) + rng = np.random.default_rng(seed) + th.backends.cudnn.deterministic = True + + if th.backends.mps.is_available(): + device = th.device("mps") + elif th.cuda.is_available(): + device = th.device("cuda") + else: + device = th.device("cpu") + + ### Set up testing conditions ### + conditions_tensor, sampled_conditions, sampled_designs_np, _ = sample_conditions( + problem=problem, n_samples=args.n_samples, device=device, seed=seed + ) + + # Reshape to match the expected input shape for the model diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index c00aa47..b503823 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -23,11 +23,11 @@ import os import random import time -from typing import TYPE_CHECKING import warnings from einops import rearrange from engibench.utils.all_problems import BUILTIN_PROBLEMS +import matplotlib.pyplot as plt import numpy as np import requests import torch as th @@ -44,9 +44,6 @@ from engiopt.transforms import normalize from engiopt.transforms import upsample_nearest -if TYPE_CHECKING: - import logging - # URL and checkpoint for LPIPS model URL_MAP = { 
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" @@ -108,10 +105,8 @@ class Args: """size of the batches for CVQGAN""" n_epochs_0: int = 1000 # Default: 1000 """number of epochs of CVQGAN training""" - cond_lr: float = 2e-4 + cond_lr: float = 2e-4 # Default: 2e-4 """learning rate for CVQGAN""" - cond_sample_interval: int = 1600 - """interval between CVQGAN image samples""" # Algorithm-specific: Stage 1 (AE) # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically @@ -119,7 +114,7 @@ class Args: """number of epochs of training""" batch_size_1: int = 16 """size of the batches for Stage 1""" - lr_1: float = 2e-4 + lr_1: float = 2e-4 # Default: 2e-4 """learning rate for Stage 1""" beta: float = 0.25 """beta hyperparameter for the codebook commitment loss""" @@ -135,7 +130,7 @@ class Args: """number of vectors in the codebook""" disc_start: int = 0 """epoch to start discriminator training""" - disc_factor: float = 1.0 + disc_factor: float = 0.1 """weighting factor for the adversarial loss from the discriminator""" rec_loss_factor: float = 1.0 """weighting factor for the reconstruction loss""" @@ -153,8 +148,8 @@ class Args: """tuple of resolutions at which to apply attention in the decoder""" decoder_num_res_blocks: int = 3 """number of residual blocks per decoder layer""" - sample_interval: int = 1600 - """interval between image samples""" + sample_interval_1: int = 100 + """interval between Stage 1 image samples""" # Algorithm-specific: Stage 2 (Transformer) # From original implementation: assume pkeep=1.0, sos_token=0, bias=True @@ -162,7 +157,7 @@ class Args: """number of epochs of training""" batch_size_2: int = 16 """size of the batches for Stage 2""" - lr_2: float = 6e-4 + lr_2: float = 6e-4 # Default: 6e-4 """learning rate for Stage 2""" n_layer: int = 12 """number of layers in the transformer""" @@ -172,6 
+167,8 @@ class Args: """transformer embedding dimension""" dropout: float = 0.3 """dropout rate in the transformer""" + sample_interval_2: int = 100 + """interval between Stage 2 image samples""" class Codebook(nn.Module): @@ -783,7 +780,6 @@ class GreyscaleLPIPS(nn.Module): warn_on_clamp: If True, log warnings when inputs fall outside [0, 1]. freeze: If True, disables grads on all params. ckpt_name: Key in URL_MAP/CKPT_MAP for loading LPIPS heads. - logger: Optional logger for non-intrusive messages/warnings. """ def __init__( # noqa: PLR0913 @@ -795,14 +791,12 @@ def __init__( # noqa: PLR0913 warn_on_clamp: bool = False, freeze: bool = True, ckpt_name: str = "vgg_lpips", - logger: logging.Logger | None = None, ): super().__init__() self.use_raw = use_raw self.clamp_output = clamp_output self.robust_clamp = robust_clamp self.warn_on_clamp = warn_on_clamp - self._logger = logger self.scaling_layer = ScalingLayer() self.channels = (64, 128, 256, 512, 512) @@ -816,19 +810,6 @@ def __init__( # noqa: PLR0913 def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: """Compute greyscale-aware LPIPS distance between two batches.""" - if self.warn_on_clamp and self._logger is not None: - with th.no_grad(): - if (fake_x < 0).any() or (fake_x > 1).any(): - self._logger.warning( - "GreyscaleLPIPS: generated input outside [0,1]: [%.4f, %.4f]", - float(fake_x.min().item()), float(fake_x.max().item()), - ) - if (real_x < 0).any() or (real_x > 1).any(): - self._logger.warning( - "GreyscaleLPIPS: reference input outside [0,1]: [%.4f, %.4f]", - float(real_x.min().item()), float(real_x.max().item()), - ) - if self.robust_clamp: real_x = th.clamp(real_x, 0.0, 1.0) fake_x = th.clamp(fake_x, 0.0, 1.0) @@ -888,8 +869,6 @@ def _get_ckpt_path(self, name: str, root: str) -> str: assert name in URL_MAP, f"Unknown LPIPS checkpoint name: {name!r}" path = os.path.join(root, CKPT_MAP[name]) if not os.path.exists(path): - if self._logger is not None: - 
self._logger.info("Downloading LPIPS weights '%s' from %s to %s", name, URL_MAP[name], path) self._download(URL_MAP[name], path) return path @@ -1519,7 +1498,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu random.seed(args.seed) th.backends.cudnn.deterministic = True - os.makedirs("images", exist_ok=True) + os.makedirs("images/vqgan_1", exist_ok=True) + os.makedirs("images/vqgan_2", exist_ok=True) if th.backends.mps.is_available(): device = th.device("mps") @@ -1534,6 +1514,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Configure data loader (keep on CPU for preprocessing) training_ds = problem.dataset.with_format("torch")["train"] + len_dataset = len(training_ds) # Add in the upsampled optimal design column and remove the original optimal design column training_ds = training_ds.map( @@ -1559,6 +1540,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu n_conds = len(conditions) args.cond_dim = n_conds + condition_tensors = [training_ds[key][:] for key in conditions] # Move to device only here th_training_ds = th.utils.data.TensorDataset( @@ -1580,10 +1562,30 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu batch_size=args.batch_size_2, shuffle=True, ) - # Logging + # For logging a fixed set of designs in Stage 1 + n_logged_designs = 25 + fixed_indices = random.sample(range(len_dataset), n_logged_designs) + log_subset = th.utils.data.Subset(th_training_ds, fixed_indices) + log_dataloader = th.utils.data.DataLoader( + log_subset, + batch_size=n_logged_designs, + shuffle=False, + ) + + # Logging run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" if args.track: wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name, dir="./logs/wandb") + wandb.define_metric("0_step", summary="max") + wandb.define_metric("cvq_loss", step_metric="0_step") + 
wandb.define_metric("epoch_0", step_metric="0_step") + wandb.define_metric("1_step", summary="max") + wandb.define_metric("vq_loss", step_metric="1_step") + wandb.define_metric("d_loss", step_metric="1_step") + wandb.define_metric("epoch_1", step_metric="1_step") + wandb.define_metric("2_step", summary="max") + wandb.define_metric("tr_loss", step_metric="2_step") + wandb.define_metric("epoch_2", step_metric="2_step") vqgan = VQGAN( device=device, @@ -1683,6 +1685,46 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu perceptual_loss_fcn = GreyscaleLPIPS().eval().to(device) + @th.no_grad() + def sample_designs_1(n_designs: int) -> list[th.Tensor]: + """Sample reconstructions from trained VQGAN Stage 1.""" + vqgan.eval() + + designs, *_ = next(iter(log_dataloader)) + designs = designs[:n_designs].to(device) + reconstructions, _, _ = vqgan(designs) + + vqgan.train() + return reconstructions + + @th.no_grad() + def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: + """Sample generated designs from trained VQGAN Stage 2.""" + transformer.eval() + + # Create condition grid + all_conditions = th.stack(condition_tensors, dim=1) + linspaces = [ + th.linspace(all_conditions[:, i].min(), all_conditions[:, i].max(), n_designs, device=device) + for i in range(all_conditions.shape[1]) + ] + desired_conds = th.stack(linspaces, dim=1) + + if args.conditional: + c = transformer.encode_to_z(x=desired_conds, is_c=True)[1] + else: + c = th.ones(n_designs, 1, dtype=th.int64, device=device) * transformer.sos_token + + latent_imgs = transformer.sample( + x=th.empty(n_designs, 0, dtype=th.int64, device=device), + c=c, + steps=(args.latent_size ** 2) + ) + gen_imgs = transformer.z_to_image(latent_imgs) + + transformer.train() + return desired_conds, gen_imgs + # --------------------------- # Stage 0: Training CVQGAN # --------------------------- @@ -1706,13 +1748,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) 
-> tu # ---------- if args.track: batches_done = epoch * len(dataloader_0) + i - wandb.log( - { - "cvq_loss": cvq_loss.item(), - "epoch": epoch, - "batch": batches_done, - } - ) + wandb.log({"cvq_loss": cvq_loss.item(), "0_step": batches_done}) + wandb.log({"epoch_0": epoch, "0_step": batches_done}) print( f"[Epoch {epoch}/{args.n_epochs_0}] [Batch {i}/{len(dataloader_0)}] [CVQ loss: {cvq_loss.item()}]" ) @@ -1784,18 +1821,36 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # ---------- if args.track: batches_done = epoch * len(dataloader_1) + i - wandb.log( - { - "vq_loss": vq_loss.item(), - "d_loss": gan_loss.item(), - "epoch": epoch, - "batch": batches_done, - } - ) + wandb.log({"vq_loss": vq_loss.item(), "1_step": batches_done}) + wandb.log({"d_loss": gan_loss.item(), "1_step": batches_done}) + wandb.log({"epoch_1": epoch, "1_step": batches_done}) print( f"[Epoch {epoch}/{args.n_epochs_1}] [Batch {i}/{len(dataloader_1)}] [D loss: {gan_loss.item()}] [VQ loss: {vq_loss.item()}]" ) + # This saves a grid image of 25 generated designs every sample_interval + if batches_done % args.sample_interval_1 == 0: + # Extract 25 designs + designs = sample_designs_1(n_designs=n_logged_designs) + fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + + # Flatten axes for easy indexing + axes = axes.flatten() + + # Plot each tensor as a scatter plot + for j, tensor in enumerate(designs): + img = tensor.cpu().numpy().reshape(design_shape[-2], design_shape[-1]) # Extract x and y coordinates + axes[j].imshow(img) # Scatter plot + axes[j].title.set_text(f"Design {j + 1}") # Set title + axes[j].set_xticks([]) # Hide x ticks + axes[j].set_yticks([]) # Hide y ticks + + plt.tight_layout() + img_fname = f"images/vqgan_1/{batches_done}.png" + plt.savefig(img_fname) + plt.close() + wandb.log({"designs_1": wandb.Image(img_fname)}) + # -------------- # Save models # -------------- @@ -1853,17 +1908,38 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: 
int | None = None) -> tu # ---------- if args.track: batches_done = epoch * len(dataloader_2) + i - wandb.log( - { - "tr_loss": loss.item(), - "epoch": epoch, - "batch": batches_done, - } - ) + wandb.log({"tr_loss": loss.item(), "2_step": batches_done}) + wandb.log({"epoch_2": epoch, "2_step": batches_done}) print( f"[Epoch {epoch}/{args.n_epochs_2}] [Batch {i}/{len(dataloader_2)}] [Transformer loss: {loss.item()}]" ) + # This saves a grid image of 25 generated designs every sample_interval + if batches_done % args.sample_interval_2 == 0: + # Extract 25 designs + desired_conds, designs = sample_designs_2(n_designs=n_logged_designs) + fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + + # Flatten axes for easy indexing + axes = axes.flatten() + + # Plot each tensor as a scatter plot + for j, tensor in enumerate(designs): + img = tensor.cpu().numpy().reshape(design_shape[-2], design_shape[-1]) # Extract x and y coordinates + dc = desired_conds[j].cpu() + axes[j].imshow(img) # Scatter plot + title = [(conditions[i][0], f"{dc[i]:.2f}") for i in range(n_conds)] + title_string = "\n ".join(f"{condition}: {value}" for condition, value in title) + axes[j].title.set_text(title_string) # Set title + axes[j].set_xticks([]) # Hide x ticks + axes[j].set_yticks([]) # Hide y ticks + + plt.tight_layout() + img_fname = f"images/vqgan_2/{batches_done}.png" + plt.savefig(img_fname) + plt.close() + wandb.log({"designs_2": wandb.Image(img_fname)}) + # -------------- # Save model # -------------- From ff937fa5436106f1fe520c436be8ac7431e8dd8f Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 30 Sep 2025 16:13:59 +0200 Subject: [PATCH 12/22] simplify args and transforms; remove unneeded transformer code --- engiopt/transforms.py | 47 +---------- engiopt/vqgan/vqgan.py | 178 +++++++---------------------------------- pyproject.toml | 1 - 3 files changed, 32 insertions(+), 194 deletions(-) diff --git a/engiopt/transforms.py b/engiopt/transforms.py index 5f35e9a..a507398 100644 --- 
a/engiopt/transforms.py +++ b/engiopt/transforms.py @@ -1,7 +1,6 @@ """Transformations for the data.""" from collections.abc import Callable -import math from datasets import Dataset from engibench.core import Problem @@ -26,58 +25,18 @@ def flatten_dict(x): return flatten_dict -def _nearest_power_of_two(x: int) -> int: - """Round x to the nearest power of 2.""" - lower = 2 ** math.floor(math.log2(x)) - upper = 2 ** math.ceil(math.log2(x)) - return upper if abs(x - upper) < abs(x - lower) else lower - - -def upsample_nearest(data: th.Tensor, mode: str = "bicubic") -> th.Tensor: - """Upsample 2D data to the nearest square 2^n based on the maximum dimension. - - Accepts input of shape (B, H, W) or (B, C, H, W). - """ - low_dim = 3 - if data.ndim == low_dim: - data = data.unsqueeze(1) # (B, 1, H, W) - _, _, h, w = data.shape - - max_dim = max(h, w) - target = _nearest_power_of_two(max_dim) - if target < max_dim: - target *= 2 - - return f.interpolate(data, size=(target, target), mode=mode) - - -def downsample_nearest(data: th.Tensor, mode: str = "bicubic") -> th.Tensor: - """Downsample 2D data to the nearest square 2^n based on the maximum dimension. - - Accepts input of shape (B, H, W) or (B, C, H, W). - """ +def resize_to(data: th.Tensor, h: int, w: int, mode: str = "bicubic") -> th.Tensor: + """Resize 2D data back to any desired (h, w). Data should be a Tensor in the format (B, C, H, W).""" low_dim = 3 if data.ndim == low_dim: data = data.unsqueeze(1) # (B, 1, H, W) - _, _, h, w = data.shape - - max_dim = max(h, w) - target = _nearest_power_of_two(max_dim) - if target > max_dim: - target //= 2 - - return f.interpolate(data, size=(target, target), mode=mode) - - -def resize_to(data: th.Tensor, h: int, w: int, mode: str = "bicubic") -> th.Tensor: - """Resize 2D data back to any desired (h, w). 
Data should be a Tensor in the format (B, C, H, W).""" return f.interpolate(data, size=(h, w), mode=mode) def normalize( ds: Dataset, condition_names: list[str] ) -> tuple[Dataset, th.Tensor, th.Tensor]: - """Normalize specified condition columns with global mean/std (torch version, CPU).""" + """Normalize specified condition columns with global mean/std.""" # stack condition columns into a single tensor (N, C) on CPU conds = th.stack([th.as_tensor(ds[c][:]).float() for c in condition_names], dim=1) mean = conds.mean(dim=0) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index b503823..e373fd5 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -18,7 +18,6 @@ from __future__ import annotations from dataclasses import dataclass -import inspect import math import os import random @@ -36,13 +35,12 @@ from torchvision.models import vgg16 from torchvision.models import VGG16_Weights import tqdm -from transformers import GPT2LMHeadModel import tyro import wandb from engiopt.transforms import drop_constant from engiopt.transforms import normalize -from engiopt.transforms import upsample_nearest +from engiopt.transforms import resize_to # URL and checkpoint for LPIPS model URL_MAP = { @@ -83,11 +81,7 @@ class Args: drop_constant_conditions: bool = True """whether to drop constant condition columns (i.e., overhang_constraint in beams2d)""" image_size: int = 128 - """size of each image dimension (determined automatically later)""" - image_channels: int = 1 - """number of channels in the input image (determined automatically later)""" - latent_size: int = 16 - """size of each latent feature map dimension (determined automatically later)""" + """desired size of the square image input to the model""" # Algorithm-specific: Stage 1 Conditional AE or "CVQGAN" if the model is specified as conditional # Note that a Discriminator is not used for CVQGAN, as it is generally a much simpler model. 
@@ -114,7 +108,7 @@ class Args: """number of epochs of training""" batch_size_1: int = 16 """size of the batches for Stage 1""" - lr_1: float = 2e-4 # Default: 2e-4 + lr_1: float = 5e-5 # Default: 2e-4 """learning rate for Stage 1""" beta: float = 0.25 """beta hyperparameter for the codebook commitment loss""" @@ -134,7 +128,7 @@ class Args: """weighting factor for the adversarial loss from the discriminator""" rec_loss_factor: float = 1.0 """weighting factor for the reconstruction loss""" - perceptual_loss_factor: float = 1.0 + perceptual_loss_factor: float = 0.1 """weighting factor for the perceptual loss""" encoder_channels: tuple[int, ...] = (64, 64, 128, 128, 256) """tuple of channel sizes for each encoder layer""" @@ -1120,7 +1114,7 @@ class GPTConfig: class GPT(nn.Module): - """Minimal GPT-2 style Transformer with HF weight import.""" + """Minimal GPT-2 style Transformer.""" def __init__(self, config: GPTConfig): super().__init__() @@ -1143,13 +1137,6 @@ def __init__(self, config: GPTConfig): if pn.endswith("c_proj.weight"): th.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) - def get_num_params(self, *, non_embedding: bool = True) -> int: - """Return total parameter count (optionally excluding position embeddings).""" - n_params = sum(p.numel() for p in self.parameters()) - if non_embedding: - n_params -= self.transformer["wpe"].weight.numel() - return n_params - def _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): th.nn.init.normal_(module.weight, mean=0.0, std=0.02) @@ -1185,126 +1172,6 @@ def forward( else: loss = None return logits, loss - - def crop_block_size(self, block_size: int) -> None: - """Reduce maximum context length and trim position embeddings.""" - assert block_size <= self.config.block_size - self.config.block_size = block_size - self.transformer["wpe"].weight = nn.Parameter(self.transformer["wpe"].weight[:block_size]) - for block in self.transformer["h"]: - attn = block.attn - if 
hasattr(attn, "bias"): - attn.bias = attn.bias[:, :, :block_size, :block_size] - - @classmethod - def from_pretrained( - cls, - model_type: str, - override_args: dict[str, float] | None = None, - ) -> GPT: - """Load HF GPT-2 weights into this minimal GPT implementation.""" - assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} - override_args = override_args or {} - assert all(k == "dropout" for k in override_args), "Only 'dropout' can be overridden" - - cfg_map: dict[str, dict[str, int]] = { - "gpt2": {"n_layer": 12, "n_head": 12, "n_embd": 768}, - "gpt2-medium": {"n_layer": 24, "n_head": 16, "n_embd": 1024}, - "gpt2-large": {"n_layer": 36, "n_head": 20, "n_embd": 1280}, - "gpt2-xl": {"n_layer": 48, "n_head": 25, "n_embd": 1600}, - } - - # Use object so we can mix int, float, and bool - config_args: dict[str, object] = dict(cfg_map[model_type]) - config_args.update({"vocab_size": 50257, "block_size": 1024, "bias": True}) - - if "dropout" in override_args: - config_args["dropout"] = float(override_args["dropout"]) - - config = GPTConfig(**config_args) # type: ignore[arg-type] - model = GPT(config) - - sd = model.state_dict() - sd_keys = [k for k in sd if not k.endswith(".attn.bias")] - - hf: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained(model_type) - sd_hf = hf.state_dict() - sd_keys_hf = [ - k - for k in sd_hf - if not (k.endswith((".attn.masked_bias", ".attn.bias"))) - ] - - transposed = {"attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"} - assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" - - for k in sd_keys_hf: - if any(k.endswith(w) for w in transposed): - assert sd_hf[k].shape[::-1] == sd[k].shape - with th.no_grad(): - sd[k].copy_(sd_hf[k].t()) - else: - assert sd_hf[k].shape == sd[k].shape - with th.no_grad(): - sd[k].copy_(sd_hf[k]) - - return model - - def configure_optimizers( - self, - weight_decay: float, - learning_rate: float, - betas: tuple[float, 
float], - device_type: str, - ) -> th.optim.Optimizer: - """Create AdamW with decoupled weight decay for matrix weights only.""" - param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} - dim_threshold = 2 - decay_params = [p for p in param_dict.values() if p.dim() >= dim_threshold] - nodecay_params = [p for p in param_dict.values() if p.dim() < dim_threshold] - optim_groups = [ - {"params": decay_params, "weight_decay": weight_decay}, - {"params": nodecay_params, "weight_decay": 0.0}, - ] - - fused_available = "fused" in inspect.signature(th.optim.AdamW).parameters - use_fused = bool(fused_available and device_type == "cuda") - extra_args: dict[str, object] = {"fused": True} if use_fused else {} - return th.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) - - def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float: - """Estimate model FLOPS utilization relative to A100 bf16 peak (312 TFLOPS).""" - n = self.get_num_params() - cfg = self.config - l, h, q, t = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size - flops_per_token = 6 * n + 12 * l * h * q * t - flops_per_fwdbwd = flops_per_token * t - flops_per_iter = flops_per_fwdbwd * float(fwdbwd_per_iter) - flops_achieved = flops_per_iter * (1.0 / dt) - flops_peak = 312e12 - return float(flops_achieved / flops_peak) - - @th.no_grad() - def generate( - self, - idx: th.Tensor, - max_new_tokens: int, - *, - temperature: float = 1.0, - top_k: int | None = None, - ) -> th.Tensor: - """Autoregressively sample tokens conditioned on idx.""" - for _ in range(max_new_tokens): - idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] - logits, _ = self(idx_cond) - logits = logits[:, -1, :] / temperature - if top_k is not None: - v, _ = th.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float("inf") - probs = f.softmax(logits, dim=-1) - idx_next = th.multinomial(probs, num_samples=1) - idx = 
th.cat((idx, idx_next), dim=1) - return idx ########################################### ########## GPT-2 BASE CODE ABOVE ########## ########################################### @@ -1519,15 +1386,19 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Add in the upsampled optimal design column and remove the original optimal design column training_ds = training_ds.map( lambda batch: { - "optimal_upsampled": upsample_nearest(batch["optimal_design"][:]).cpu().numpy() + "optimal_upsampled": resize_to( + data=batch["optimal_design"][:], + h=args.image_size, + w=args.image_size + ).cpu().numpy() }, batched=True, ) training_ds = training_ds.remove_columns("optimal_design") - design_shape = training_ds["optimal_upsampled"][:].shape[-2:] - args.image_size = training_ds["optimal_upsampled"][:].shape[-1] - args.latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) + # Now we assume the dataset is of shape (N, C, H, W) and work from there + image_channels = training_ds["optimal_upsampled"][:].shape[1] + latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) conditions = problem.conditions_keys # Optionally normalize condition columns @@ -1595,15 +1466,15 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu encoder_attn_resolutions=args.encoder_attn_resolutions, encoder_num_res_blocks=args.encoder_num_res_blocks, decoder_channels=args.decoder_channels, - decoder_start_resolution=args.latent_size, + decoder_start_resolution=latent_size, decoder_attn_resolutions=args.decoder_attn_resolutions, decoder_num_res_blocks=args.decoder_num_res_blocks, - image_channels=args.image_channels, + image_channels=image_channels, latent_dim=args.latent_dim, num_codebook_vectors=args.num_codebook_vectors ).to(device=device) - discriminator = Discriminator(image_channels=args.image_channels).to(device=device) + discriminator = Discriminator(image_channels=image_channels).to(device=device) cvqgan = 
VQGAN( device=device, @@ -1718,7 +1589,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: latent_imgs = transformer.sample( x=th.empty(n_designs, 0, dtype=th.int64, device=device), c=c, - steps=(args.latent_size ** 2) + steps=(latent_size ** 2) ) gen_imgs = transformer.z_to_image(latent_imgs) @@ -1831,7 +1702,11 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # This saves a grid image of 25 generated designs every sample_interval if batches_done % args.sample_interval_1 == 0: # Extract 25 designs - designs = sample_designs_1(n_designs=n_logged_designs) + designs = resize_to( + data=sample_designs_1(n_designs=n_logged_designs), + h=design_shape[0], + w=design_shape[1] + ) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing @@ -1839,7 +1714,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # Plot each tensor as a scatter plot for j, tensor in enumerate(designs): - img = tensor.cpu().numpy().reshape(design_shape[-2], design_shape[-1]) # Extract x and y coordinates + img = tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates axes[j].imshow(img) # Scatter plot axes[j].title.set_text(f"Design {j + 1}") # Set title axes[j].set_xticks([]) # Hide x ticks @@ -1918,6 +1793,11 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: if batches_done % args.sample_interval_2 == 0: # Extract 25 designs desired_conds, designs = sample_designs_2(n_designs=n_logged_designs) + designs = resize_to( + data=designs, + h=design_shape[0], + w=design_shape[1] + ) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing @@ -1925,7 +1805,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # Plot each tensor as a scatter plot for j, tensor in enumerate(designs): - img = tensor.cpu().numpy().reshape(design_shape[-2], design_shape[-1]) # Extract x and y coordinates + img = 
tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates dc = desired_conds[j].cpu() axes[j].imshow(img) # Scatter plot title = [(conditions[i][0], f"{dc[i]:.2f}") for i in range(n_conds)] diff --git a/pyproject.toml b/pyproject.toml index b22683e..72221a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -295,6 +295,5 @@ module = [ "einops.*", "torchvision.*", "requests", - "transformers", ] ignore_missing_imports = true From e7bafcef663ca393e096c522fa2177e0685e3d78 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 30 Sep 2025 16:45:58 +0200 Subject: [PATCH 13/22] minor plotting fixes --- README.md | 1 + engiopt/vqgan/vqgan.py | 12 +++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 45a8c47..39d4c9a 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ As much as we can, we follow the [CleanRL](https://github.com/vwxyzjn/cleanrl) p [gan_bezier](engiopt/gan_bezier/) | Inverse Design | 1D | ❌ | GAN + Bezier layer [gan_cnn_2d](engiopt/gan_cnn_2d/) | Inverse Design | 2D | ❌ | GAN + CNN [surrogate_model](engiopt/surrogate_model/) | Surrogate Model | 1D | ❌ | MLP +[vqgan](engiopt/vqgan) | Inverse Design | 2D | ✅ | VQVAE + Transformer ## Dashboards The integration with WandB allows us to access live dashboards of our runs (on the cluster or not). We also upload the trained models there. You can access some of our runs at https://wandb.ai/engibench/engiopt. 
diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index e373fd5..358fc38 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -1401,14 +1401,14 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) conditions = problem.conditions_keys - # Optionally normalize condition columns - if args.normalize_conditions: - training_ds, mean, std = normalize(training_ds, conditions) - # Optionally drop condition columns that are constant like overhang_constraint in beams2d if args.drop_constant_conditions: training_ds, conditions = drop_constant(training_ds, conditions) + # Optionally normalize condition columns + if args.normalize_conditions: + training_ds, mean, std = normalize(training_ds, conditions) + n_conds = len(conditions) args.cond_dim = n_conds condition_tensors = [training_ds[key][:] for key in conditions] @@ -1716,7 +1716,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: for j, tensor in enumerate(designs): img = tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates axes[j].imshow(img) # Scatter plot - axes[j].title.set_text(f"Design {j + 1}") # Set title + axes[j].title.set_text(f"Reconstruction {j + 1}") # Set title axes[j].set_xticks([]) # Hide x ticks axes[j].set_yticks([]) # Hide y ticks @@ -1793,6 +1793,8 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: if batches_done % args.sample_interval_2 == 0: # Extract 25 designs desired_conds, designs = sample_designs_2(n_designs=n_logged_designs) + if args.normalize_conditions: + desired_conds = (desired_conds.cpu() * std) + mean designs = resize_to( data=designs, h=design_shape[0], From b845834444888ec2bb2f73e64c6299694f5e2252 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 30 Sep 2025 18:25:43 +0200 Subject: [PATCH 14/22] add initial vqgan eval script --- engiopt/vqgan/evaluate_vqgan.py | 153 
+++++++++++++++++++++++++++++++- engiopt/vqgan/vqgan.py | 16 ++-- 2 files changed, 160 insertions(+), 9 deletions(-) diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index f5bb97d..0985227 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -1,4 +1,3 @@ -# ruff: noqa: F401 # REMOVE LATER """Evaluation for the VQGAN.""" from __future__ import annotations @@ -15,6 +14,9 @@ from engiopt import metrics from engiopt.dataset_sample_conditions import sample_conditions +from engiopt.transforms import drop_constant +from engiopt.transforms import normalize +from engiopt.transforms import resize_to from engiopt.vqgan.vqgan import VQGAN from engiopt.vqgan.vqgan import VQGANTransformer @@ -58,9 +60,154 @@ class Args: else: device = th.device("cpu") + ### Set Up Transformer ### + + # Restores the pytorch model from wandb + if args.wandb_entity is not None: + artifact_path_0 = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" + artifact_path_1 = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" + artifact_path_2 = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" + else: + artifact_path_0 = f"{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" + artifact_path_1 = f"{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" + artifact_path_2 = f"{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" + + api = wandb.Api() + artifact_0 = api.artifact(artifact_path_0, type="model") + artifact_1 = api.artifact(artifact_path_1, type="model") + artifact_2 = api.artifact(artifact_path_2, type="model") + + class RunRetrievalError(ValueError): + def __init__(self): + super().__init__("Failed to retrieve the run") + + run = artifact_2.logged_by() + if run is None or not hasattr(run, "config"): + raise RunRetrievalError + artifact_dir_0 = artifact_0.download() + artifact_dir_1 = 
artifact_1.download() + artifact_dir_2 = artifact_2.download() + + ckpt_path_0 = os.path.join(artifact_dir_0, "cvqgan.pth") + ckpt_path_1 = os.path.join(artifact_dir_1, "vqgan.pth") + ckpt_path_2 = os.path.join(artifact_dir_2, "transformer.pth") + ckpt_0 = th.load(ckpt_path_0, map_location=th.device(device), weights_only=False) + ckpt_1 = th.load(ckpt_path_1, map_location=th.device(device), weights_only=False) + ckpt_2 = th.load(ckpt_path_2, map_location=th.device(device), weights_only=False) + + vqgan = VQGAN( + device=device, + is_c=False, + encoder_channels=run.config["encoder_channels"], + encoder_start_resolution=run.config["image_size"], + encoder_attn_resolutions=run.config["encoder_attn_resolutions"], + encoder_num_res_blocks=run.config["encoder_num_res_blocks"], + decoder_channels=run.config["decoder_channels"], + decoder_start_resolution=run.config["latent_size"], + decoder_attn_resolutions=run.config["decoder_attn_resolutions"], + decoder_num_res_blocks=run.config["decoder_num_res_blocks"], + image_channels=run.config["image_channels"], + latent_dim=run.config["latent_dim"], + num_codebook_vectors=run.config["num_codebook_vectors"] + ) + vqgan.load_state_dict(ckpt_1["vqgan"]) + vqgan.eval() # Set to evaluation mode + vqgan.to(device) + + cvqgan = VQGAN( + device=device, + is_c=True, + cond_feature_map_dim=run.config["cond_feature_map_dim"], + cond_dim=run.config["cond_dim"], + cond_hidden_dim=run.config["cond_hidden_dim"], + cond_latent_dim=run.config["cond_latent_dim"], + cond_codebook_vectors=run.config["cond_codebook_vectors"] + ) + cvqgan.load_state_dict(ckpt_0["cvqgan"]) + cvqgan.eval() # Set to evaluation mode + cvqgan.to(device) + + model = VQGANTransformer( + conditional=run.config["conditional"], + vqgan=vqgan, + cvqgan=cvqgan, + image_size=run.config["image_size"], + decoder_channels=run.config["decoder_channels"], + cond_feature_map_dim=run.config["cond_feature_map_dim"], + num_codebook_vectors=run.config["num_codebook_vectors"], + 
n_layer=run.config["n_layer"], + n_head=run.config["n_head"], + n_embd=run.config["n_embd"], + dropout=run.config["dropout"] + ) + model.load_state_dict(ckpt_2["transformer"]) + model.eval() # Set to evaluation mode + model.to(device) + ### Set up testing conditions ### - conditions_tensor, sampled_conditions, sampled_designs_np, _ = sample_conditions( + _, sampled_conditions, sampled_designs_np, _ = sample_conditions( problem=problem, n_samples=args.n_samples, device=device, seed=seed ) - # Reshape to match the expected input shape for the model + # Clean up conditions based on model training settings and convert back to tensor + sampled_conditions_new = sampled_conditions + if run.config["drop_constant_conditions"]: + sampled_conditions_new, conditions = drop_constant(sampled_conditions_new, sampled_conditions_new.column_names) + + if run.config["normalize_conditions"]: + sampled_conditions_new, mean, std = normalize(sampled_conditions_new, conditions) + + conditions_tensor = th.stack( + [th.as_tensor(sampled_conditions_new[key][:]).float() for key in conditions], + dim=1 + ).to(device) + + # Set the start-of-sequence tokens for the transformer using the CVQGAN to discretize the conditions if enabled + if run.config["conditional"]: + c = model.encode_to_z(x=conditions_tensor, is_c=True)[1] + else: + c = th.ones(args.n_samples, 1, dtype=th.int64, device=device) * model.sos_token + + # Generate a batch of designs + latent_designs = model.sample( + x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), + c=c, + steps=(run.config["latent_size"] ** 2) + ) + gen_designs = resize_to( + data=model.z_to_image(latent_designs), + h=problem.design_space.shape[0], + w=problem.design_space.shape[1] + ) + gen_designs_np = gen_designs.detach().cpu().numpy() + gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) + + # Clip to boundaries for running THIS IS PROBLEM DEPENDENT + gen_designs_np = np.clip(gen_designs_np, 1e-3, 1) + + # Compute 
metrics + metrics_dict = metrics.metrics( + problem, + gen_designs_np, + sampled_designs_np, + sampled_conditions, + sigma=args.sigma, + ) + + metrics_dict.update( + { + "seed": seed, + "problem_id": args.problem_id, + "model_id": "vqgan", + "n_samples": args.n_samples, + "sigma": args.sigma, + } + ) + + # Append result row to CSV + metrics_df = pd.DataFrame([metrics_dict]) + out_path = args.output_csv.format(problem_id=args.problem_id) + write_header = not os.path.exists(out_path) + metrics_df.to_csv(out_path, mode="a", header=write_header, index=False) + + print(f"Seed {seed} done; appended to {out_path}") diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 358fc38..3218433 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -101,6 +101,10 @@ class Args: """number of epochs of CVQGAN training""" cond_lr: float = 2e-4 # Default: 2e-4 """learning rate for CVQGAN""" + latent_size: int = 16 + """size of the latent feature map (automatically determined later)""" + image_channels: int = 1 + """number of channels in the input image (automatically determined later)""" # Algorithm-specific: Stage 1 (AE) # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically @@ -1397,8 +1401,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu training_ds = training_ds.remove_columns("optimal_design") # Now we assume the dataset is of shape (N, C, H, W) and work from there - image_channels = training_ds["optimal_upsampled"][:].shape[1] - latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) + args.image_channels = training_ds["optimal_upsampled"][:].shape[1] + args.latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) conditions = problem.conditions_keys # Optionally drop condition columns that are constant like overhang_constraint in beams2d @@ -1466,15 
+1470,15 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu encoder_attn_resolutions=args.encoder_attn_resolutions, encoder_num_res_blocks=args.encoder_num_res_blocks, decoder_channels=args.decoder_channels, - decoder_start_resolution=latent_size, + decoder_start_resolution=args.latent_size, decoder_attn_resolutions=args.decoder_attn_resolutions, decoder_num_res_blocks=args.decoder_num_res_blocks, - image_channels=image_channels, + image_channels=args.image_channels, latent_dim=args.latent_dim, num_codebook_vectors=args.num_codebook_vectors ).to(device=device) - discriminator = Discriminator(image_channels=image_channels).to(device=device) + discriminator = Discriminator(image_channels=args.image_channels).to(device=device) cvqgan = VQGAN( device=device, @@ -1589,7 +1593,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: latent_imgs = transformer.sample( x=th.empty(n_designs, 0, dtype=th.int64, device=device), c=c, - steps=(latent_size ** 2) + steps=(args.latent_size ** 2) ) gen_imgs = transformer.z_to_image(latent_imgs) From c5433451548d113bc398e8dc7b966f479bea1c77 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 30 Sep 2025 18:31:19 +0200 Subject: [PATCH 15/22] minor eval conditions fixes --- engiopt/vqgan/evaluate_vqgan.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index 0985227..03160fb 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -150,15 +150,22 @@ def __init__(self): ) # Clean up conditions based on model training settings and convert back to tensor - sampled_conditions_new = sampled_conditions + sampled_conditions_new = sampled_conditions.select(range(len(sampled_conditions))) + conditions = sampled_conditions_new.column_names + + # Drop constant condition columns if enabled if run.config["drop_constant_conditions"]: - sampled_conditions_new, conditions = 
drop_constant(sampled_conditions_new, sampled_conditions_new.column_names) + sampled_conditions_new, conditions = drop_constant( + sampled_conditions_new, sampled_conditions_new.column_names + ) + # Normalize condition columns if enabled if run.config["normalize_conditions"]: sampled_conditions_new, mean, std = normalize(sampled_conditions_new, conditions) + # Convert to tensor conditions_tensor = th.stack( - [th.as_tensor(sampled_conditions_new[key][:]).float() for key in conditions], + [th.as_tensor(sampled_conditions_new[c][:]).float() for c in conditions], dim=1 ).to(device) From ef8aa725388ed346e73590832e98672532639573 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Tue, 30 Sep 2025 19:16:56 +0200 Subject: [PATCH 16/22] ruff fixes and add dependencies --- .gitattributes | 2 +- .gitignore | 1 + engiopt/transforms.py | 11 +- engiopt/vqgan/evaluate_vqgan.py | 23 +-- engiopt/vqgan/vqgan.py | 328 ++++++++++++++------------------ pyproject.toml | 3 +- setup.py | 2 +- 7 files changed, 161 insertions(+), 209 deletions(-) diff --git a/.gitattributes b/.gitattributes index 9137b2f..dac068e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,4 +13,4 @@ *.gif binary *.pdf binary *.pkl binary -*.npy binary \ No newline at end of file +*.npy binary diff --git a/.gitignore b/.gitignore index 052e04a..128df19 100644 --- a/.gitignore +++ b/.gitignore @@ -161,6 +161,7 @@ cython_debug/ wandb/* images/* logs/* +*.csv # Editors .idea/ .vscode/ diff --git a/engiopt/transforms.py b/engiopt/transforms.py index a507398..c26d02f 100644 --- a/engiopt/transforms.py +++ b/engiopt/transforms.py @@ -33,9 +33,7 @@ def resize_to(data: th.Tensor, h: int, w: int, mode: str = "bicubic") -> th.Tens return f.interpolate(data, size=(h, w), mode=mode) -def normalize( - ds: Dataset, condition_names: list[str] -) -> tuple[Dataset, th.Tensor, th.Tensor]: +def normalize(ds: Dataset, condition_names: list[str]) -> tuple[Dataset, th.Tensor, th.Tensor]: """Normalize specified condition columns with 
global mean/std.""" # stack condition columns into a single tensor (N, C) on CPU conds = th.stack([th.as_tensor(ds[c][:]).float() for c in condition_names], dim=1) @@ -45,8 +43,7 @@ def normalize( # normalize each condition column (HF expects numpy back) ds = ds.map( lambda batch: { - c: ((th.as_tensor(batch[c][:]).float() - mean[i]) / std[i]).numpy() - for i, c in enumerate(condition_names) + c: ((th.as_tensor(batch[c][:]).float() - mean[i]) / std[i]).numpy() for i, c in enumerate(condition_names) }, batched=True, ) @@ -54,9 +51,7 @@ def normalize( return ds, mean, std -def drop_constant( - ds: Dataset, condition_names: list[str] -) -> tuple[Dataset, list[str]]: +def drop_constant(ds: Dataset, condition_names: list[str]) -> tuple[Dataset, list[str]]: """Drop constant condition columns (std=0) from dataset.""" conds = th.stack([th.as_tensor(ds[c][:]).float() for c in condition_names], dim=1) std = conds.std(dim=0) diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index 03160fb..6f00087 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -108,7 +108,7 @@ def __init__(self): decoder_num_res_blocks=run.config["decoder_num_res_blocks"], image_channels=run.config["image_channels"], latent_dim=run.config["latent_dim"], - num_codebook_vectors=run.config["num_codebook_vectors"] + num_codebook_vectors=run.config["num_codebook_vectors"], ) vqgan.load_state_dict(ckpt_1["vqgan"]) vqgan.eval() # Set to evaluation mode @@ -121,7 +121,7 @@ def __init__(self): cond_dim=run.config["cond_dim"], cond_hidden_dim=run.config["cond_hidden_dim"], cond_latent_dim=run.config["cond_latent_dim"], - cond_codebook_vectors=run.config["cond_codebook_vectors"] + cond_codebook_vectors=run.config["cond_codebook_vectors"], ) cvqgan.load_state_dict(ckpt_0["cvqgan"]) cvqgan.eval() # Set to evaluation mode @@ -138,7 +138,7 @@ def __init__(self): n_layer=run.config["n_layer"], n_head=run.config["n_head"], n_embd=run.config["n_embd"], - 
dropout=run.config["dropout"] + dropout=run.config["dropout"], ) model.load_state_dict(ckpt_2["transformer"]) model.eval() # Set to evaluation mode @@ -155,19 +155,14 @@ def __init__(self): # Drop constant condition columns if enabled if run.config["drop_constant_conditions"]: - sampled_conditions_new, conditions = drop_constant( - sampled_conditions_new, sampled_conditions_new.column_names - ) + sampled_conditions_new, conditions = drop_constant(sampled_conditions_new, sampled_conditions_new.column_names) # Normalize condition columns if enabled if run.config["normalize_conditions"]: sampled_conditions_new, mean, std = normalize(sampled_conditions_new, conditions) # Convert to tensor - conditions_tensor = th.stack( - [th.as_tensor(sampled_conditions_new[c][:]).float() for c in conditions], - dim=1 - ).to(device) + conditions_tensor = th.stack([th.as_tensor(sampled_conditions_new[c][:]).float() for c in conditions], dim=1).to(device) # Set the start-of-sequence tokens for the transformer using the CVQGAN to discretize the conditions if enabled if run.config["conditional"]: @@ -177,14 +172,10 @@ def __init__(self): # Generate a batch of designs latent_designs = model.sample( - x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), - c=c, - steps=(run.config["latent_size"] ** 2) + x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), c=c, steps=(run.config["latent_size"] ** 2) ) gen_designs = resize_to( - data=model.z_to_image(latent_designs), - h=problem.design_space.shape[0], - w=problem.design_space.shape[1] + data=model.z_to_image(latent_designs), h=problem.design_space.shape[0], w=problem.design_space.shape[1] ) gen_designs_np = gen_designs.detach().cpu().numpy() gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 3218433..8656836 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -43,13 +43,9 @@ from engiopt.transforms import 
resize_to # URL and checkpoint for LPIPS model -URL_MAP = { - "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" -} +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} -CKPT_MAP = { - "vgg_lpips": "vgg.pth" -} +CKPT_MAP = {"vgg_lpips": "vgg.pth"} @dataclass @@ -183,8 +179,10 @@ class Codebook(nn.Module): contras_loss (bool): if true, use the contras_loss to further improve the performance init (bool): if true, the codebook has been initialized """ + def __init__( # noqa: PLR0913 - self, *, + self, + *, num_codebook_vectors: int, latent_dim: int, beta: float = 0.25, @@ -220,9 +218,11 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten # clculate the distance if self.distance == "l2": # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = - th.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \ - th.sum(self.embedding.weight ** 2, dim=1) + \ - 2 * th.einsum("bd, dn-> bn", z_flattened.detach(), rearrange(self.embedding.weight, "n d-> d n")) + d = ( + -th.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) + - th.sum(self.embedding.weight**2, dim=1) + + 2 * th.einsum("bd, dn-> bn", z_flattened.detach(), rearrange(self.embedding.weight, "n d-> d n")) + ) elif self.distance == "cos": # cosine distances from z to embeddings e_j normed_z_flattened = f.normalize(z_flattened, dim=1).detach() @@ -232,14 +232,14 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten # encoding sort_distance, indices = d.sort(dim=1) # look up the closest point for the indices - encoding_indices = indices[:,-1] + encoding_indices = indices[:, -1] encodings = th.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device) encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) # quantize and unflatten z_q = th.matmul(encodings, self.embedding.weight).view(z.shape) # compute loss for embedding - loss = self.beta * 
th.mean((z_q.detach()-z)**2) + th.mean((z_q - z.detach()) ** 2) + loss = self.beta * th.mean((z_q.detach() - z) ** 2) + th.mean((z_q - z.detach()) ** 2) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape @@ -252,13 +252,13 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten # online clustered reinitialization for unoptimized points if self.training: # calculate the average usage of code entries - self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay) + self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay) # running average updates if self.anchor in ["closest", "random", "probrandom"] and (not self.init): # closest sampling if self.anchor == "closest": sort_distance, indices = d.sort(dim=0) - random_feat = z_flattened.detach()[indices[-1,:]] + random_feat = z_flattened.detach()[indices[-1, :]] # feature pool based random sampling elif self.anchor == "random": random_feat = self.pool.query(z_flattened.detach()) @@ -268,18 +268,22 @@ def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Ten prob = th.multinomial(norm_distance, num_samples=1).view(-1) random_feat = z_flattened.detach()[prob] # decay parameter based on the average usage - decay = th.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim) + decay = ( + th.exp(-(self.embed_prob * self.num_embed * 10) / (1 - self.decay) - 1e-3) + .unsqueeze(1) + .repeat(1, self.embed_dim) + ) self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay if self.first_batch: self.init = True # contrastive loss if self.contras_loss: sort_distance, indices = d.sort(dim=0) - dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True) - dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:] + dis_pos = sort_distance[-max(1, int(sort_distance.size(0) / self.num_embed)) :, 
:].mean(dim=0, keepdim=True) + dis_neg = sort_distance[: int(sort_distance.size(0) * 1 / 2), :] dis = th.cat([dis_pos, dis_neg], dim=0).t() / 0.07 contra_loss = f.cross_entropy(dis, th.zeros((dis.size(0),), dtype=th.long, device=dis.device)) - loss += contra_loss + loss += contra_loss return z_q, encoding_indices, loss, min_encodings, perplexity @@ -293,11 +297,8 @@ class FeaturePool: pool_size (int): the size of feature buffer dim (int): the dimension of each feature """ - def __init__( - self, - pool_size: int, - dim: int = 64 - ): + + def __init__(self, pool_size: int, dim: int = 64): self.pool_size = pool_size if self.pool_size > 0: self.nums_features = 0 @@ -307,21 +308,21 @@ def query(self, features: th.Tensor) -> th.Tensor: """Return features from the pool.""" self.features = self.features.to(features.device) if self.nums_features < self.pool_size: - if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook + if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) self.features = features[random_feat_id] self.nums_features = self.pool_size else: # if the mini-batch is not large nuough, just store it for the next update num = self.nums_features + features.size(0) - self.features[self.nums_features:num] = features + self.features[self.nums_features : num] = features self.nums_features = num elif features.size(0) > int(self.pool_size): random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) self.features = features[random_feat_id] else: random_id = th.randperm(self.pool_size) - self.features[random_id[:features.size(0)]] = features + self.features[random_id[: features.size(0)]] = features return self.features @@ -332,10 +333,8 @@ class GroupNorm(nn.Module): Parameters: channels (int): number of channels in the input feature map """ - def __init__( - self, - channels: 
int - ): + + def __init__(self, channels: int): super().__init__() self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) @@ -345,6 +344,7 @@ def forward(self, x: th.Tensor) -> th.Tensor: class Swish(nn.Module): """Swish activation function to be used in VQGAN Encoder and Decoder.""" + def forward(self, x: th.Tensor) -> th.Tensor: return x * th.sigmoid(x) @@ -356,11 +356,8 @@ class ResidualBlock(nn.Module): in_channels (int): number of input channels out_channels (int): number of output channels """ - def __init__( - self, - in_channels: int, - out_channels: int - ): + + def __init__(self, in_channels: int, out_channels: int): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -370,7 +367,7 @@ def __init__( nn.Conv2d(in_channels, out_channels, 3, 1, 1), GroupNorm(out_channels), Swish(), - nn.Conv2d(out_channels, out_channels, 3, 1, 1) + nn.Conv2d(out_channels, out_channels, 3, 1, 1), ) if in_channels != out_channels: @@ -388,10 +385,8 @@ class UpSampleBlock(nn.Module): Parameters: channels (int): number of channels in the input feature map """ - def __init__( - self, - channels: int - ): + + def __init__(self, channels: int): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, 1, 1) @@ -406,10 +401,8 @@ class DownSampleBlock(nn.Module): Parameters: channels (int): number of channels in the input feature map """ - def __init__( - self, - channels: int - ): + + def __init__(self, channels: int): super().__init__() self.conv = nn.Conv2d(channels, channels, 3, 2, 0) @@ -425,10 +418,8 @@ class NonLocalBlock(nn.Module): Parameters: channels (int): number of channels in the input feature map """ - def __init__( - self, - channels: int - ): + + def __init__(self, channels: int): super().__init__() self.in_channels = channels @@ -446,13 +437,13 @@ def forward(self, x: th.Tensor) -> th.Tensor: b, c, h, w = q.shape - q = q.reshape(b, c, h*w) + q = q.reshape(b, c, h * w) q = q.permute(0, 2, 1) - k = 
k.reshape(b, c, h*w) - v = v.reshape(b, c, h*w) + k = k.reshape(b, c, h * w) + v = v.reshape(b, c, h * w) attn = th.bmm(q, k) - attn = attn * (int(c)**(-0.5)) + attn = attn * (int(c) ** (-0.5)) attn = f.softmax(attn, dim=2) attn = attn.permute(0, 2, 1) @@ -470,17 +461,10 @@ class LinearCombo(nn.Module): out_features (int): number of output features alpha (float): negative slope for LeakyReLU """ - def __init__( - self, - in_features: int, - out_features: int, - alpha: float = 0.2 - ): + + def __init__(self, in_features: int, out_features: int, alpha: float = 0.2): super().__init__() - self.model = nn.Sequential( - nn.Linear(in_features, out_features), - nn.LeakyReLU(alpha) - ) + self.model = nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(alpha)) def forward(self, x: th.Tensor) -> th.Tensor: return self.model(x) @@ -501,6 +485,7 @@ class Encoder(nn.Module): image_channels (int): number of channels in the input image latent_dim (int): dimensionality of the latent space """ + def __init__( # noqa: PLR0913 self, encoder_channels: tuple[int, ...], @@ -514,7 +499,7 @@ def __init__( # noqa: PLR0913 channels = encoder_channels resolution = encoder_start_resolution layers = [nn.Conv2d(image_channels, channels[0], 3, 1, 1)] - for i in range(len(channels)-1): + for i in range(len(channels) - 1): in_channels = channels[i] out_channels = channels[i + 1] for _ in range(encoder_num_res_blocks): @@ -522,8 +507,8 @@ def __init__( # noqa: PLR0913 in_channels = out_channels if resolution in encoder_attn_resolutions: layers.append(NonLocalBlock(in_channels)) - if i != len(channels)-2: - layers.append(DownSampleBlock(channels[i+1])) + if i != len(channels) - 2: + layers.append(DownSampleBlock(channels[i + 1])) resolution //= 2 layers.append(ResidualBlock(channels[-1], channels[-1])) layers.append(NonLocalBlock(channels[-1])) @@ -546,25 +531,20 @@ class CondEncoder(nn.Module): cond_hidden_dim (int): hidden dimension of the CVQGAN MLP cond_latent_dim (int): individual 
code dimension for CVQGAN """ - def __init__( - self, - cond_feature_map_dim: int, - cond_dim: int, - cond_hidden_dim: int, - cond_latent_dim: int - ): + + def __init__(self, cond_feature_map_dim: int, cond_dim: int, cond_hidden_dim: int, cond_latent_dim: int): super().__init__() self.c_feature_map_dim = cond_feature_map_dim self.model = nn.Sequential( LinearCombo(cond_dim, cond_hidden_dim), LinearCombo(cond_hidden_dim, cond_hidden_dim), - nn.Linear(cond_hidden_dim, cond_latent_dim*cond_feature_map_dim**2) + nn.Linear(cond_hidden_dim, cond_latent_dim * cond_feature_map_dim**2), ) def forward(self, x: th.Tensor) -> th.Tensor: encoded = self.model(x) s = encoded.shape - return encoded.view(s[0], s[1]//self.c_feature_map_dim**2, self.c_feature_map_dim, self.c_feature_map_dim) + return encoded.view(s[0], s[1] // self.c_feature_map_dim**2, self.c_feature_map_dim, self.c_feature_map_dim) class Decoder(nn.Module): @@ -582,6 +562,7 @@ class Decoder(nn.Module): image_channels (int): number of channels in the output image latent_dim (int): dimensionality of the latent space """ + def __init__( # noqa: PLR0913 self, decoder_channels: tuple[int, ...], @@ -589,15 +570,17 @@ def __init__( # noqa: PLR0913 decoder_attn_resolutions: tuple[int, ...], decoder_num_res_blocks: int, image_channels: int, - latent_dim: int + latent_dim: int, ): super().__init__() in_channels = decoder_channels[0] resolution = decoder_start_resolution - layers = [nn.Conv2d(latent_dim, in_channels, 3, 1, 1), - ResidualBlock(in_channels, in_channels), - NonLocalBlock(in_channels), - ResidualBlock(in_channels, in_channels)] + layers = [ + nn.Conv2d(latent_dim, in_channels, 3, 1, 1), + ResidualBlock(in_channels, in_channels), + NonLocalBlock(in_channels), + ResidualBlock(in_channels, in_channels), + ] for i in range(len(decoder_channels)): out_channels = decoder_channels[i] @@ -629,19 +612,14 @@ class CondDecoder(nn.Module): cond_hidden_dim (int): hidden dimension of the CVQGAN MLP cond_latent_dim (int): 
individual code dimension for CVQGAN """ - def __init__( - self, - cond_latent_dim: int, - cond_dim: int, - cond_hidden_dim: int, - cond_feature_map_dim: int - ): + + def __init__(self, cond_latent_dim: int, cond_dim: int, cond_hidden_dim: int, cond_feature_map_dim: int): super().__init__() self.model = nn.Sequential( - LinearCombo(cond_latent_dim*cond_feature_map_dim**2, cond_hidden_dim), + LinearCombo(cond_latent_dim * cond_feature_map_dim**2, cond_hidden_dim), LinearCombo(cond_hidden_dim, cond_hidden_dim), - nn.Linear(cond_hidden_dim, cond_dim) + nn.Linear(cond_hidden_dim, cond_dim), ) def forward(self, x: th.Tensor) -> th.Tensor: @@ -660,12 +638,7 @@ class Discriminator(nn.Module): image_channels: Number of channels in the input image. """ - def __init__( - self, - num_filters_last: int = 64, - n_layers: int = 3, - image_channels: int = 1 - ): + def __init__(self, num_filters_last: int = 64, n_layers: int = 3, image_channels: int = 1): super().__init__() # Convolutional backbone (PatchGAN) @@ -691,15 +664,12 @@ def __init__( nn.LeakyReLU(0.2, inplace=True), ] - layers.append( - nn.Conv2d(num_filters_last * num_filters_mult, 1, kernel_size=4, stride=1, padding=1) - ) + layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, kernel_size=4, stride=1, padding=1)) self.model = nn.Sequential(*layers) # Initialize weights self.apply(self._weights_init) - @staticmethod def _weights_init(m: nn.Module) -> None: """Custom weight initialization (DCGAN-style).""" @@ -710,7 +680,6 @@ def _weights_init(m: nn.Module) -> None: nn.init.normal_(m.weight.data, mean=1.0, std=0.02) nn.init.constant_(m.bias.data, 0.0) - def forward(self, x: th.Tensor) -> th.Tensor: """Forward pass with optional CVQGAN adapter.""" return self.model(x) @@ -721,8 +690,8 @@ class ScalingLayer(nn.Module): def __init__(self): super().__init__() - self.register_buffer("shift", th.tensor([-.030, -.088, -.188])[None, :, None, None]) - self.register_buffer("scale", th.tensor([.458, .448, .450])[None, 
:, None, None]) + self.register_buffer("shift", th.tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", th.tensor([0.458, 0.448, 0.450])[None, :, None, None]) def forward(self, x: th.Tensor) -> th.Tensor: return (x - self.shift) / self.scale @@ -749,9 +718,9 @@ def __init__(self): super().__init__() vgg_feats = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features blocks = [vgg_feats[i] for i in range(30)] - self.slice1 = nn.Sequential(*blocks[0:4]) # relu1_2 - self.slice2 = nn.Sequential(*blocks[4:9]) # relu2_2 - self.slice3 = nn.Sequential(*blocks[9:16]) # relu3_3 + self.slice1 = nn.Sequential(*blocks[0:4]) # relu1_2 + self.slice2 = nn.Sequential(*blocks[4:9]) # relu2_2 + self.slice3 = nn.Sequential(*blocks[9:16]) # relu3_3 self.slice4 = nn.Sequential(*blocks[16:23]) # relu4_3 self.slice5 = nn.Sequential(*blocks[23:30]) # relu5_3 self.requires_grad_(requires_grad=False) @@ -805,7 +774,6 @@ def __init__( # noqa: PLR0913 if freeze: self.requires_grad_(requires_grad=False) - def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: """Compute greyscale-aware LPIPS distance between two batches.""" if self.robust_clamp: @@ -900,10 +868,11 @@ class VQGAN(nn.Module): latent_dim (int): Dimensionality of the latent space. num_codebook_vectors (int): Number of codebook vectors. """ + def __init__( # noqa: PLR0913 - self, *, + self, + *, device: th.device, - # CVQGAN parameters is_c: bool = False, cond_feature_map_dim: int = 4, @@ -911,7 +880,6 @@ def __init__( # noqa: PLR0913 cond_hidden_dim: int = 256, cond_latent_dim: int = 4, cond_codebook_vectors: int = 64, - # VQGAN + Codebook parameters encoder_channels: tuple[int, ...] 
= (64, 64, 128, 128, 256), encoder_start_resolution: int = 128, @@ -923,23 +891,13 @@ def __init__( # noqa: PLR0913 decoder_num_res_blocks: int = 3, image_channels: int = 1, latent_dim: int = 16, - num_codebook_vectors: int = 256 + num_codebook_vectors: int = 256, ): super().__init__() if is_c: - self.encoder = CondEncoder( - cond_feature_map_dim, - cond_dim, - cond_hidden_dim, - cond_latent_dim - ).to(device=device) + self.encoder = CondEncoder(cond_feature_map_dim, cond_dim, cond_hidden_dim, cond_latent_dim).to(device=device) - self.decoder = CondDecoder( - cond_latent_dim, - cond_dim, - cond_hidden_dim, - cond_feature_map_dim - ).to(device=device) + self.decoder = CondDecoder(cond_latent_dim, cond_dim, cond_hidden_dim, cond_feature_map_dim).to(device=device) self.quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, 1).to(device=device) self.post_quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, 1).to(device=device) @@ -950,7 +908,7 @@ def __init__( # noqa: PLR0913 encoder_attn_resolutions, encoder_num_res_blocks, image_channels, - latent_dim + latent_dim, ).to(device=device) self.decoder = Decoder( @@ -959,15 +917,15 @@ def __init__( # noqa: PLR0913 decoder_attn_resolutions, decoder_num_res_blocks, image_channels, - latent_dim + latent_dim, ).to(device=device) self.quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device) self.post_quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device) self.codebook = Codebook( - num_codebook_vectors = cond_codebook_vectors if is_c else num_codebook_vectors, - latent_dim = cond_latent_dim if is_c else latent_dim + num_codebook_vectors=cond_codebook_vectors if is_c else num_codebook_vectors, + latent_dim=cond_latent_dim if is_c else latent_dim, ).to(device=device) def forward(self, designs: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor]: @@ -1043,9 +1001,7 @@ def __init__(self, config: GPTConfig): ) self.register_buffer( "bias", - th.tril(th.ones(config.block_size, config.block_size)).view( 
- 1, 1, config.block_size, config.block_size - ), + th.tril(th.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), ) def forward(self, x: th.Tensor) -> th.Tensor: @@ -1058,7 +1014,9 @@ def forward(self, x: th.Tensor) -> th.Tensor: if self.flash: y = f.scaled_dot_product_attention( - q, k, v, + q, + k, + v, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True, @@ -1157,9 +1115,7 @@ def forward( """Forward pass returning logits and optional cross-entropy loss.""" device = idx.device _, t = idx.size() - assert t <= self.config.block_size, ( - f"Cannot forward sequence of length {t}; block size is {self.config.block_size}" - ) + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}; block size is {self.config.block_size}" pos = th.arange(0, t, dtype=th.long, device=device) tok_emb = self.transformer["wte"](idx) @@ -1176,6 +1132,8 @@ def forward( else: loss = None return logits, loss + + ########################################### ########## GPT-2 BASE CODE ABOVE ########## ########################################### @@ -1200,8 +1158,10 @@ class VQGANTransformer(nn.Module): dropout (float): Dropout rate in the Transformer. bias (bool): If True, use bias terms in the Transformer layers. 
""" + def __init__( # noqa: PLR0913 - self, *, + self, + *, conditional: bool = True, vqgan: VQGAN, cvqgan: VQGAN, @@ -1213,7 +1173,7 @@ def __init__( # noqa: PLR0913 n_head: int, n_embd: int, dropout: int, - bias: bool = True + bias: bool = True, ): super().__init__() self.sos_token = 0 @@ -1223,7 +1183,7 @@ def __init__( # noqa: PLR0913 # block_size is automatically set to the combined sequence length of the VQGAN and CVQGAN block_size = (image_size // (2 ** (len(decoder_channels) - 1))) ** 2 if conditional: - block_size += cond_feature_map_dim ** 2 + block_size += cond_feature_map_dim**2 # Create config object for NanoGPT transformer_config = GPTConfig( @@ -1232,8 +1192,8 @@ def __init__( # noqa: PLR0913 n_layer=n_layer, n_head=n_head, n_embd=n_embd, - dropout=dropout, # Add dropout parameter (default in nanoGPT) - bias=bias # Add bias parameter (default in nanoGPT) + dropout=dropout, # Add dropout parameter (default in nanoGPT) + bias=bias, # Add bias parameter (default in nanoGPT) ) self.transformer = GPT(transformer_config) self.conditional = conditional @@ -1282,7 +1242,7 @@ def forward(self, x: th.Tensor, c: th.Tensor, pkeep: float = 1.0) -> tuple[th.Te # NanoGPT forward doesn't use embeddings parameter, but takes targets # We're ignoring the loss returned by NanoGPT logits, _ = self.transformer(new_indices[:, :-1], None) - logits = logits[:, -indices.shape[1]:] # Always predict the last 256 tokens + logits = logits[:, -indices.shape[1] :] # Always predict the last 256 tokens return logits, target @@ -1294,7 +1254,9 @@ def top_k_logits(self, logits: th.Tensor, k: int) -> th.Tensor: return out @th.no_grad() - def sample(self, x: th.Tensor, c: th.Tensor, steps: int, temperature: float = 1.0, top_k: int | None = None) -> th.Tensor: + def sample( + self, x: th.Tensor, c: th.Tensor, steps: int, temperature: float = 1.0, top_k: int | None = None + ) -> th.Tensor: """Autoregressively sample from the model given initial context x and conditional c.""" x = 
th.cat((c, x), dim=1) @@ -1327,7 +1289,7 @@ def sample(self, x: th.Tensor, c: th.Tensor, steps: int, temperature: float = 1. x = th.cat((x, ix), dim=1) - return x[:, c.shape[1]:] + return x[:, c.shape[1] :] @th.no_grad() def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tuple[dict[str, th.Tensor], th.Tensor]: @@ -1342,8 +1304,10 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu sos_tokens = th.ones(x.shape[0], 1) * self.sos_token sos_tokens = sos_tokens.long().to(x.device) - start_indices = indices[:, :indices.shape[1] // 2] - sample_indices = self.sample(start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1], top_k=top_k) + start_indices = indices[:, : indices.shape[1] // 2] + sample_indices = self.sample( + start_indices, sos_tokens, steps=indices.shape[1] - start_indices.shape[1], top_k=top_k + ) half_sample = self.z_to_image(sample_indices) start_indices = indices[:, :0] @@ -1390,11 +1354,9 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Add in the upsampled optimal design column and remove the original optimal design column training_ds = training_ds.map( lambda batch: { - "optimal_upsampled": resize_to( - data=batch["optimal_design"][:], - h=args.image_size, - w=args.image_size - ).cpu().numpy() + "optimal_upsampled": resize_to(data=batch["optimal_design"][:], h=args.image_size, w=args.image_size) + .cpu() + .numpy() }, batched=True, ) @@ -1450,7 +1412,14 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Logging run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" if args.track: - wandb.init(project=args.wandb_project, entity=args.wandb_entity, config=vars(args), save_code=True, name=run_name, dir="./logs/wandb") + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + config=vars(args), + save_code=True, + name=run_name, + dir="./logs/wandb", + ) 
wandb.define_metric("0_step", summary="max") wandb.define_metric("cvq_loss", step_metric="0_step") wandb.define_metric("epoch_0", step_metric="0_step") @@ -1475,7 +1444,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu decoder_num_res_blocks=args.decoder_num_res_blocks, image_channels=args.image_channels, latent_dim=args.latent_dim, - num_codebook_vectors=args.num_codebook_vectors + num_codebook_vectors=args.num_codebook_vectors, ).to(device=device) discriminator = Discriminator(image_channels=args.image_channels).to(device=device) @@ -1487,7 +1456,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu cond_dim=args.cond_dim, cond_hidden_dim=args.cond_hidden_dim, cond_latent_dim=args.cond_latent_dim, - cond_codebook_vectors=args.cond_codebook_vectors + cond_codebook_vectors=args.cond_codebook_vectors, ).to(device=device) transformer = VQGANTransformer( @@ -1501,35 +1470,38 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, - dropout=args.dropout + dropout=args.dropout, ).to(device=device) # CVQGAN Stage 0 optimizer opt_cvq = th.optim.Adam( - list(cvqgan.encoder.parameters()) + - list(cvqgan.decoder.parameters()) + - list(cvqgan.codebook.parameters()) + - list(cvqgan.quant_conv.parameters()) + - list(cvqgan.post_quant_conv.parameters()), - lr=args.cond_lr, eps=1e-08, betas=(args.b1, args.b2) - ) + list(cvqgan.encoder.parameters()) + + list(cvqgan.decoder.parameters()) + + list(cvqgan.codebook.parameters()) + + list(cvqgan.quant_conv.parameters()) + + list(cvqgan.post_quant_conv.parameters()), + lr=args.cond_lr, + eps=1e-08, + betas=(args.b1, args.b2), + ) # VQGAN Stage 1 optimizer opt_vq = th.optim.Adam( - list(vqgan.encoder.parameters()) + - list(vqgan.decoder.parameters()) + - list(vqgan.codebook.parameters()) + - list(vqgan.quant_conv.parameters()) + - list(vqgan.post_quant_conv.parameters()), - 
lr=args.lr_1, eps=1e-08, betas=(args.b1, args.b2) - ) + list(vqgan.encoder.parameters()) + + list(vqgan.decoder.parameters()) + + list(vqgan.codebook.parameters()) + + list(vqgan.quant_conv.parameters()) + + list(vqgan.post_quant_conv.parameters()), + lr=args.lr_1, + eps=1e-08, + betas=(args.b1, args.b2), + ) # VQGAN Stage 1 discriminator optimizer - opt_disc = th.optim.Adam(discriminator.parameters(), - lr=args.lr_1, eps=1e-08, betas=(args.b1, args.b2)) + opt_disc = th.optim.Adam(discriminator.parameters(), lr=args.lr_1, eps=1e-08, betas=(args.b1, args.b2)) # Transformer Stage 2 optimizer decay, no_decay = set(), set() - whitelist_weight_modules = (nn.Linear, ) + whitelist_weight_modules = (nn.Linear,) blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) for mn, m in transformer.transformer.named_modules(): @@ -1591,9 +1563,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: c = th.ones(n_designs, 1, dtype=th.int64, device=device) * transformer.sos_token latent_imgs = transformer.sample( - x=th.empty(n_designs, 0, dtype=th.int64, device=device), - c=c, - steps=(args.latent_size ** 2) + x=th.empty(n_designs, 0, dtype=th.int64, device=device), c=c, steps=(args.latent_size**2) ) gen_imgs = transformer.z_to_image(latent_imgs) @@ -1678,9 +1648,9 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: lamb = vqgan.calculate_lambda(perceptual_rec_loss, g_loss) vq_loss = perceptual_rec_loss + q_loss + disc_factor * lamb * g_loss - d_loss_real = th.mean(f.relu(1. - disc_real)) - d_loss_fake = th.mean(f.relu(1. 
+ disc_fake)) - gan_loss = disc_factor * 0.5*(d_loss_real + d_loss_fake) + d_loss_real = th.mean(f.relu(1.0 - disc_real)) + d_loss_fake = th.mean(f.relu(1.0 + disc_fake)) + gan_loss = disc_factor * 0.5 * (d_loss_real + d_loss_fake) opt_vq.zero_grad() vq_loss.backward(retain_graph=True) @@ -1707,9 +1677,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: if batches_done % args.sample_interval_1 == 0: # Extract 25 designs designs = resize_to( - data=sample_designs_1(n_designs=n_logged_designs), - h=design_shape[0], - w=design_shape[1] + data=sample_designs_1(n_designs=n_logged_designs), h=design_shape[0], w=design_shape[1] ) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) @@ -1799,11 +1767,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: desired_conds, designs = sample_designs_2(n_designs=n_logged_designs) if args.normalize_conditions: desired_conds = (desired_conds.cpu() * std) + mean - designs = resize_to( - data=designs, - h=design_shape[0], - w=design_shape[1] - ) + designs = resize_to(data=designs, h=design_shape[0], w=design_shape[1]) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) # Flatten axes for easy indexing diff --git a/pyproject.toml b/pyproject.toml index 72221a1..5ff340a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ dependencies = [ "hyppo >= 0.5.0", "kaleido >= 0.2.1", "datasets >=4.0.0", + "einops >=0.8.0", + "requests >=2.31.0", ] dynamic = ["version"] @@ -132,7 +134,6 @@ ignore = [ "S607", # start-process-with-partial-path "T201", # print "TRY003", # print - "PD901", # avoid using df to name dataframe ] extend-select = ["I"] diff --git a/setup.py b/setup.py index 875a6b2..dc8f9f5 100644 --- a/setup.py +++ b/setup.py @@ -19,5 +19,5 @@ def get_version(): name="engiopt", version=get_version(), long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown" + long_description_content_type="text/markdown", ) From 
992a3fd38bfd4ad97c0d6092129189b536d66475 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Thu, 2 Oct 2025 11:20:21 +0200 Subject: [PATCH 17/22] add early stopping arg for transformer --- engiopt/vqgan/vqgan.py | 80 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 8656836..2d47aa9 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -149,6 +149,12 @@ class Args: # From original implementation: assume pkeep=1.0, sos_token=0, bias=True n_epochs_2: int = 100 # Default: 100 """number of epochs of training""" + early_stopping: bool = True + """whether to use early stopping for the transformer based on the held-out validation loss""" + early_stopping_patience: int = 3 + """number of epochs with no improvement after which training will be stopped""" + early_stopping_delta: float = 1e-3 + """minimum change in the monitored quantity to qualify as an improvement""" batch_size_2: int = 16 """size of the batches for Stage 2""" lr_2: float = 6e-4 # Default: 6e-4 @@ -1399,6 +1405,46 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu batch_size=args.batch_size_2, shuffle=True, ) + + # If early stopping enabled, create a validation dataloader + if args.early_stopping: + val_ds = problem.dataset.with_format("torch")["val"] + val_ds = val_ds.map( + lambda batch: { + "optimal_upsampled": resize_to(data=batch["optimal_design"][:], h=args.image_size, w=args.image_size) + .cpu() + .numpy() + }, + batched=True, + ) + val_ds = val_ds.remove_columns("optimal_design") + + # Optionally drop condition columns that are constant like overhang_constraint in beams2d + if args.drop_constant_conditions: + to_drop = [c for c in problem.conditions_keys if c not in conditions] + if to_drop: + val_ds = val_ds.remove_columns(to_drop) + + # If enabled, normalize using training mean/std (computed above) + if args.normalize_conditions: + val_ds = val_ds.map( + lambda 
batch: { + c: ((th.as_tensor(batch[c][:]).float() - mean[i]) / std[i]).numpy() for i, c in enumerate(conditions) + }, + batched=True, + ) + + # Move to device only here + th_val_ds = th.utils.data.TensorDataset( + th.as_tensor(val_ds["optimal_upsampled"][:]).to(device), + *[th.as_tensor(val_ds[key][:]).to(device) for key in conditions], + ) + dataloader_val = th.utils.data.DataLoader( + th_val_ds, + batch_size=args.batch_size_2, + shuffle=False, + ) + # For logging a fixed set of designs in Stage 1 n_logged_designs = 25 fixed_indices = random.sample(range(len_dataset), n_logged_designs) @@ -1430,6 +1476,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu wandb.define_metric("2_step", summary="max") wandb.define_metric("tr_loss", step_metric="2_step") wandb.define_metric("epoch_2", step_metric="2_step") + if args.early_stopping: + wandb.define_metric("tr_val_loss", step_metric="2_step") vqgan = VQGAN( device=device, @@ -1738,6 +1786,13 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # -------------------------------- print("Stage 2: Training Transformer") transformer.train() + + # If early stopping enabled, initialize necessary variables + if args.early_stopping: + best_val = float("inf") + patience_counter = 0 + patience = args.early_stopping_patience + for epoch in tqdm.trange(args.n_epochs_2): for i, data in enumerate(dataloader_2): # THIS IS PROBLEM DEPENDENT @@ -1808,4 +1863,29 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: artifact_cvq.add_file("transformer.pth") wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"]) + # Early stopping based on held-out validation loss + if args.early_stopping: + transformer.eval() + val_losses = [] + with th.no_grad(): + for val_data in dataloader_val: + val_designs = val_data[0].to(dtype=th.float32, device=device) + val_conds = th.stack((val_data[1:]), dim=1).to(dtype=th.float32, device=device) + val_logits, val_targets = 
transformer(val_designs, val_conds) + val_loss = f.cross_entropy(val_logits.reshape(-1, val_logits.size(-1)), val_targets.reshape(-1)) + val_losses.append(val_loss.item()) + val_loss = sum(val_losses) / len(val_losses) + if args.track: + wandb.log({"tr_val_loss": val_loss, "2_step": batches_done}) + + if val_loss < best_val - args.early_stopping_delta: + best_val = val_loss + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + print(f"Early stopping at epoch {epoch} | best val loss: {best_val:.6f}") + break + transformer.train() + wandb.finish() From f0b36d84f318f81a1fde48030311ac5fd83e2d0e Mon Sep 17 00:00:00 2001 From: arthurdrake1 Date: Fri, 3 Oct 2025 16:40:58 +0200 Subject: [PATCH 18/22] fix early stopping model saving --- .gitignore | 1 + engiopt/vqgan/vqgan.py | 75 +++++++++++++++++++++++------------------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 128df19..463236b 100644 --- a/.gitignore +++ b/.gitignore @@ -128,6 +128,7 @@ celerybeat.pid # Environments .env .venv +engiopt_env/ env/ venv/ ENV/ diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 2d47aa9..b447f0c 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -150,7 +150,7 @@ class Args: n_epochs_2: int = 100 # Default: 100 """number of epochs of training""" early_stopping: bool = True - """whether to use early stopping for the transformer based on the held-out validation loss""" + """whether to use early stopping for the transformer; if True requires args.track to be True""" early_stopping_patience: int = 3 """number of epochs with no improvement after which training will be stopped""" early_stopping_delta: float = 1e-3 @@ -1660,10 +1660,9 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: } th.save(ckpt_cvq, "cvqgan.pth") - if args.track: - artifact_cvq = wandb.Artifact(f"{args.problem_id}_{args.algo}_cvqgan", type="model") - artifact_cvq.add_file("cvqgan.pth") - 
wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"]) + artifact_cvq = wandb.Artifact(f"{args.problem_id}_{args.algo}_cvqgan", type="model") + artifact_cvq.add_file("cvqgan.pth") + wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"]) # Freeze CVQGAN for later use in Stage 2 Transformer for p in cvqgan.parameters(): @@ -1767,14 +1766,13 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: th.save(ckpt_vq, "vqgan.pth") th.save(ckpt_disc, "discriminator.pth") - if args.track: - artifact_vq = wandb.Artifact(f"{args.problem_id}_{args.algo}_vqgan", type="model") - artifact_vq.add_file("vqgan.pth") - artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model") - artifact_disc.add_file("discriminator.pth") + artifact_vq = wandb.Artifact(f"{args.problem_id}_{args.algo}_vqgan", type="model") + artifact_vq.add_file("vqgan.pth") + artifact_disc = wandb.Artifact(f"{args.problem_id}_{args.algo}_discriminator", type="model") + artifact_disc.add_file("discriminator.pth") - wandb.log_artifact(artifact_vq, aliases=[f"seed_{args.seed}"]) - wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"]) + wandb.log_artifact(artifact_vq, aliases=[f"seed_{args.seed}"]) + wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"]) # Freeze VQGAN for later use in Stage 2 Transformer for p in vqgan.parameters(): @@ -1845,26 +1843,8 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: plt.close() wandb.log({"designs_2": wandb.Image(img_fname)}) - # -------------- - # Save model - # -------------- - if args.save_model and epoch == args.n_epochs_2 - 1 and i == len(dataloader_2) - 1: - ckpt_transformer = { - "epoch": epoch, - "batches_done": batches_done, - "transformer": transformer.state_dict(), - "optimizer_transformer": opt_transformer.state_dict(), - "loss": loss.item(), - } - - th.save(ckpt_transformer, "transformer.pth") - if args.track: - artifact_cvq = 
wandb.Artifact(f"{args.problem_id}_{args.algo}_transformer", type="model") - artifact_cvq.add_file("transformer.pth") - wandb.log_artifact(artifact_cvq, aliases=[f"seed_{args.seed}"]) - # Early stopping based on held-out validation loss - if args.early_stopping: + if args.track and args.early_stopping: transformer.eval() val_losses = [] with th.no_grad(): @@ -1875,12 +1855,23 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: val_loss = f.cross_entropy(val_logits.reshape(-1, val_logits.size(-1)), val_targets.reshape(-1)) val_losses.append(val_loss.item()) val_loss = sum(val_losses) / len(val_losses) - if args.track: - wandb.log({"tr_val_loss": val_loss, "2_step": batches_done}) + wandb.log({"tr_val_loss": val_loss, "2_step": batches_done}) if val_loss < best_val - args.early_stopping_delta: best_val = val_loss patience_counter = 0 + + # Save best model (overwrite locally) + if args.save_model: + ckpt_tr = { + "epoch": epoch, + "batches_done": batches_done, + "transformer": transformer.state_dict(), + "optimizer_transformer": opt_transformer.state_dict(), + "loss": loss.item(), + "val_loss": val_loss, + } + th.save(ckpt_tr, "transformer.pth") else: patience_counter += 1 if patience_counter >= patience: @@ -1888,4 +1879,22 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: break transformer.train() + # -------------- + # Save model + # -------------- + if args.track and args.save_model: + if not args.early_stopping: + ckpt_tr = { + "epoch": epoch, + "batches_done": batches_done, + "transformer": transformer.state_dict(), + "optimizer_transformer": opt_transformer.state_dict(), + "loss": loss.item(), + } + th.save(ckpt_tr, "transformer.pth") + + artifact_tr = wandb.Artifact(f"{args.problem_id}_{args.algo}_transformer", type="model") + artifact_tr.add_file("transformer.pth") + wandb.log_artifact(artifact_tr, aliases=[f"seed_{args.seed}"]) + wandb.finish() From a0437b4dde34f2014d0e39bc41625052a3eebf0a Mon Sep 17 00:00:00 2001 
From: Arthur Drake Date: Sat, 18 Oct 2025 21:59:14 +0200 Subject: [PATCH 19/22] clarify var names, add comments, other minor fixes --- .pre-commit-config.yaml | 2 +- engiopt/vqgan/evaluate_vqgan.py | 46 ++++----- engiopt/vqgan/vqgan.py | 170 +++++++++++++++++--------------- pyproject.toml | 2 +- 4 files changed, 114 insertions(+), 106 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 23a0869..fa1f04f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: name: pyright entry: pyright language: node - pass_filenames: true + pass_filenames: false types: [python] additional_dependencies: ["pyright@1.1.347"] args: diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index 6f00087..40c19ff 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -64,36 +64,38 @@ class Args: # Restores the pytorch model from wandb if args.wandb_entity is not None: - artifact_path_0 = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" - artifact_path_1 = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" - artifact_path_2 = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" + artifact_path_cvqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" + artifact_path_vqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" + artifact_path_transformer = ( + f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" + ) else: - artifact_path_0 = f"{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" - artifact_path_1 = f"{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" - artifact_path_2 = f"{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" + artifact_path_cvqgan = 
f"{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" + artifact_path_vqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" + artifact_path_transformer = f"{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" api = wandb.Api() - artifact_0 = api.artifact(artifact_path_0, type="model") - artifact_1 = api.artifact(artifact_path_1, type="model") - artifact_2 = api.artifact(artifact_path_2, type="model") + artifact_cvqgan = api.artifact(artifact_path_cvqgan, type="model") + artifact_vqgan = api.artifact(artifact_path_vqgan, type="model") + artifact_transformer = api.artifact(artifact_path_transformer, type="model") class RunRetrievalError(ValueError): def __init__(self): super().__init__("Failed to retrieve the run") - run = artifact_2.logged_by() + run = artifact_transformer.logged_by() if run is None or not hasattr(run, "config"): raise RunRetrievalError - artifact_dir_0 = artifact_0.download() - artifact_dir_1 = artifact_1.download() - artifact_dir_2 = artifact_2.download() + artifact_dir_cvqgan = artifact_cvqgan.download() + artifact_dir_vqgan = artifact_vqgan.download() + artifact_dir_transformer = artifact_transformer.download() - ckpt_path_0 = os.path.join(artifact_dir_0, "cvqgan.pth") - ckpt_path_1 = os.path.join(artifact_dir_1, "vqgan.pth") - ckpt_path_2 = os.path.join(artifact_dir_2, "transformer.pth") - ckpt_0 = th.load(ckpt_path_0, map_location=th.device(device), weights_only=False) - ckpt_1 = th.load(ckpt_path_1, map_location=th.device(device), weights_only=False) - ckpt_2 = th.load(ckpt_path_2, map_location=th.device(device), weights_only=False) + ckpt_path_cvqgan = os.path.join(artifact_dir_cvqgan, "cvqgan.pth") + ckpt_path_vqgan = os.path.join(artifact_dir_vqgan, "vqgan.pth") + ckpt_path_transformer = os.path.join(artifact_dir_transformer, "transformer.pth") + ckpt_cvqgan = th.load(ckpt_path_cvqgan, map_location=th.device(device), weights_only=False) + ckpt_vqgan = th.load(ckpt_path_vqgan, 
map_location=th.device(device), weights_only=False) + ckpt_transformer = th.load(ckpt_path_transformer, map_location=th.device(device), weights_only=False) vqgan = VQGAN( device=device, @@ -110,7 +112,7 @@ def __init__(self): latent_dim=run.config["latent_dim"], num_codebook_vectors=run.config["num_codebook_vectors"], ) - vqgan.load_state_dict(ckpt_1["vqgan"]) + vqgan.load_state_dict(ckpt_vqgan["vqgan"]) vqgan.eval() # Set to evaluation mode vqgan.to(device) @@ -123,7 +125,7 @@ def __init__(self): cond_latent_dim=run.config["cond_latent_dim"], cond_codebook_vectors=run.config["cond_codebook_vectors"], ) - cvqgan.load_state_dict(ckpt_0["cvqgan"]) + cvqgan.load_state_dict(ckpt_cvqgan["cvqgan"]) cvqgan.eval() # Set to evaluation mode cvqgan.to(device) @@ -140,7 +142,7 @@ def __init__(self): n_embd=run.config["n_embd"], dropout=run.config["dropout"], ) - model.load_state_dict(ckpt_2["transformer"]) + model.load_state_dict(ckpt_transformer["transformer"]) model.eval() # Set to evaluation mode model.to(device) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index b447f0c..497b6ae 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -91,9 +91,9 @@ class Args: """number of vectors in the CVQGAN codebook""" cond_feature_map_dim: int = 4 """feature map dimension for the CVQGAN encoder output""" - batch_size_0: int = 16 + batch_size_cvqgan: int = 16 """size of the batches for CVQGAN""" - n_epochs_0: int = 1000 # Default: 1000 + n_epochs_cvqgan: int = 1000 # Default: 1000 """number of epochs of CVQGAN training""" cond_lr: float = 2e-4 # Default: 2e-4 """learning rate for CVQGAN""" @@ -104,11 +104,11 @@ class Args: # Algorithm-specific: Stage 1 (AE) # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically - n_epochs_1: int = 100 # Default: 100 + n_epochs_vqgan: int = 100 # Default: 100 """number of epochs of training""" - 
batch_size_1: int = 16 + batch_size_vqgan: int = 16 """size of the batches for Stage 1""" - lr_1: float = 5e-5 # Default: 2e-4 + lr_vqgan: float = 5e-5 # Default: 2e-4 """learning rate for Stage 1""" beta: float = 0.25 """beta hyperparameter for the codebook commitment loss""" @@ -142,12 +142,12 @@ class Args: """tuple of resolutions at which to apply attention in the decoder""" decoder_num_res_blocks: int = 3 """number of residual blocks per decoder layer""" - sample_interval_1: int = 100 + sample_interval_vqgan: int = 100 """interval between Stage 1 image samples""" # Algorithm-specific: Stage 2 (Transformer) # From original implementation: assume pkeep=1.0, sos_token=0, bias=True - n_epochs_2: int = 100 # Default: 100 + n_epochs_transformer: int = 100 # Default: 100 """number of epochs of training""" early_stopping: bool = True """whether to use early stopping for the transformer; if True requires args.track to be True""" @@ -155,9 +155,9 @@ class Args: """number of epochs with no improvement after which training will be stopped""" early_stopping_delta: float = 1e-3 """minimum change in the monitored quantity to qualify as an improvement""" - batch_size_2: int = 16 + batch_size_transformer: int = 16 """size of the batches for Stage 2""" - lr_2: float = 6e-4 # Default: 6e-4 + lr_transformer: float = 6e-4 # Default: 6e-4 """learning rate for Stage 2""" n_layer: int = 12 """number of layers in the transformer""" @@ -167,7 +167,7 @@ class Args: """transformer embedding dimension""" dropout: float = 0.3 """dropout rate in the transformer""" - sample_interval_2: int = 100 + sample_interval_transformer: int = 100 """interval between Stage 2 image samples""" @@ -370,14 +370,14 @@ def __init__(self, in_channels: int, out_channels: int): self.block = nn.Sequential( GroupNorm(in_channels), Swish(), - nn.Conv2d(in_channels, out_channels, 3, 1, 1), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), GroupNorm(out_channels), Swish(), - 
nn.Conv2d(out_channels, out_channels, 3, 1, 1), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), ) if in_channels != out_channels: - self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0) + self.channel_up = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x: th.Tensor) -> th.Tensor: if self.in_channels != self.out_channels: @@ -394,7 +394,7 @@ class UpSampleBlock(nn.Module): def __init__(self, channels: int): super().__init__() - self.conv = nn.Conv2d(channels, channels, 3, 1, 1) + self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) def forward(self, x: th.Tensor) -> th.Tensor: x = f.interpolate(x, scale_factor=2.0) @@ -410,7 +410,7 @@ class DownSampleBlock(nn.Module): def __init__(self, channels: int): super().__init__() - self.conv = nn.Conv2d(channels, channels, 3, 2, 0) + self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0) def forward(self, x: th.Tensor) -> th.Tensor: pad = (0, 1, 0, 1) @@ -430,10 +430,10 @@ def __init__(self, channels: int): self.in_channels = channels self.gn = GroupNorm(channels) - self.q = nn.Conv2d(channels, channels, 1, 1, 0) - self.k = nn.Conv2d(channels, channels, 1, 1, 0) - self.v = nn.Conv2d(channels, channels, 1, 1, 0) - self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0) + self.q = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) def forward(self, x: th.Tensor) -> th.Tensor: h_ = self.gn(x) @@ -479,6 +479,9 @@ def forward(self, x: th.Tensor) -> th.Tensor: class Encoder(nn.Module): """Encoder module for VQGAN Stage 1. + # Simplified architecture: image -> conv -> [resblock -> attn? 
-> downsample]* -> norm -> swish -> final conv -> latent image + Where `?` indicates a block that is only included at certain resolutions and `*` indicates a block that is repeated. + Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. The number of downsample blocks is determined by the length of the encoder channels tuple minus two. For example, if encoder_channels=(128, 128, 128, 128) and the starting resolution is 128, the encoder will downsample the input image twice, from 128x128 to 32x32. @@ -504,7 +507,7 @@ def __init__( # noqa: PLR0913 super().__init__() channels = encoder_channels resolution = encoder_start_resolution - layers = [nn.Conv2d(image_channels, channels[0], 3, 1, 1)] + layers = [nn.Conv2d(image_channels, channels[0], kernel_size=3, stride=1, padding=1)] for i in range(len(channels) - 1): in_channels = channels[i] out_channels = channels[i + 1] @@ -521,7 +524,7 @@ def __init__( # noqa: PLR0913 layers.append(ResidualBlock(channels[-1], channels[-1])) layers.append(GroupNorm(channels[-1])) layers.append(Swish()) - layers.append(nn.Conv2d(channels[-1], latent_dim, 3, 1, 1)) + layers.append(nn.Conv2d(channels[-1], latent_dim, kernel_size=3, stride=1, padding=1)) self.model = nn.Sequential(*layers) def forward(self, x: th.Tensor) -> th.Tensor: @@ -556,6 +559,9 @@ def forward(self, x: th.Tensor) -> th.Tensor: class Decoder(nn.Module): """Decoder module for VQGAN Stage 1. + Simplified architecture: latent image -> conv -> [resblock -> attn? -> upsample]* -> norm -> swish -> final conv -> image + Where `?` indicates a block that is only included at certain resolutions and `*` indicates a block that is repeated. + Consists of a series of convolutional, residual, and attention blocks arranged using the provided arguments. The number of upsample blocks is determined by the length of the decoder channels tuple minus one. 
For example, if decoder_channels=(128, 128, 128) and the starting resolution is 32, the decoder will upsample the input image twice, from 32x32 to 128x128. @@ -582,7 +588,7 @@ def __init__( # noqa: PLR0913 in_channels = decoder_channels[0] resolution = decoder_start_resolution layers = [ - nn.Conv2d(latent_dim, in_channels, 3, 1, 1), + nn.Conv2d(latent_dim, in_channels, kernel_size=3, stride=1, padding=1), ResidualBlock(in_channels, in_channels), NonLocalBlock(in_channels), ResidualBlock(in_channels, in_channels), @@ -602,7 +608,7 @@ def __init__( # noqa: PLR0913 layers.append(GroupNorm(in_channels)) layers.append(Swish()) - layers.append(nn.Conv2d(in_channels, image_channels, 3, 1, 1)) + layers.append(nn.Conv2d(in_channels, image_channels, kernel_size=3, stride=1, padding=1)) self.model = nn.Sequential(*layers) def forward(self, x: th.Tensor) -> th.Tensor: @@ -905,8 +911,8 @@ def __init__( # noqa: PLR0913 self.decoder = CondDecoder(cond_latent_dim, cond_dim, cond_hidden_dim, cond_feature_map_dim).to(device=device) - self.quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, 1).to(device=device) - self.post_quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, 1).to(device=device) + self.quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, kernel_size=1).to(device=device) + self.post_quant_conv = nn.Conv2d(cond_latent_dim, cond_latent_dim, kernel_size=1).to(device=device) else: self.encoder = Encoder( encoder_channels, @@ -926,8 +932,8 @@ def __init__( # noqa: PLR0913 latent_dim, ).to(device=device) - self.quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device) - self.post_quant_conv = nn.Conv2d(latent_dim, latent_dim, 1).to(device=device) + self.quant_conv = nn.Conv2d(latent_dim, latent_dim, kernel_size=1).to(device=device) + self.post_quant_conv = nn.Conv2d(latent_dim, latent_dim, kernel_size=1).to(device=device) self.codebook = Codebook( num_codebook_vectors=cond_codebook_vectors if is_c else num_codebook_vectors, @@ -1339,8 +1345,8 @@ def 
log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu random.seed(args.seed) th.backends.cudnn.deterministic = True - os.makedirs("images/vqgan_1", exist_ok=True) - os.makedirs("images/vqgan_2", exist_ok=True) + os.makedirs("images/vqgan", exist_ok=True) + os.makedirs("images/transformer", exist_ok=True) if th.backends.mps.is_available(): device = th.device("mps") @@ -1390,19 +1396,19 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu th.as_tensor(training_ds["optimal_upsampled"][:]).to(device), *[th.as_tensor(training_ds[key][:]).to(device) for key in conditions], ) - dataloader_0 = th.utils.data.DataLoader( + dataloader_cvqgan = th.utils.data.DataLoader( th_training_ds, - batch_size=args.batch_size_0, + batch_size=args.batch_size_cvqgan, shuffle=True, ) - dataloader_1 = th.utils.data.DataLoader( + dataloader_vqgan = th.utils.data.DataLoader( th_training_ds, - batch_size=args.batch_size_1, + batch_size=args.batch_size_vqgan, shuffle=True, ) - dataloader_2 = th.utils.data.DataLoader( + dataloader_transformer = th.utils.data.DataLoader( th_training_ds, - batch_size=args.batch_size_2, + batch_size=args.batch_size_transformer, shuffle=True, ) @@ -1441,7 +1447,7 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu ) dataloader_val = th.utils.data.DataLoader( th_val_ds, - batch_size=args.batch_size_2, + batch_size=args.batch_size_transformer, shuffle=False, ) @@ -1466,18 +1472,18 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu name=run_name, dir="./logs/wandb", ) - wandb.define_metric("0_step", summary="max") - wandb.define_metric("cvq_loss", step_metric="0_step") - wandb.define_metric("epoch_0", step_metric="0_step") - wandb.define_metric("1_step", summary="max") - wandb.define_metric("vq_loss", step_metric="1_step") - wandb.define_metric("d_loss", step_metric="1_step") - wandb.define_metric("epoch_1", step_metric="1_step") - 
wandb.define_metric("2_step", summary="max") - wandb.define_metric("tr_loss", step_metric="2_step") - wandb.define_metric("epoch_2", step_metric="2_step") + wandb.define_metric("cvqgan_step", summary="max") + wandb.define_metric("cvqgan_loss", step_metric="cvqgan_step") + wandb.define_metric("epoch_cvqgan", step_metric="cvqgan_step") + wandb.define_metric("vqgan_step", summary="max") + wandb.define_metric("vqgan_loss", step_metric="vqgan_step") + wandb.define_metric("discriminator_loss", step_metric="vqgan_step") + wandb.define_metric("epoch_vqgan", step_metric="vqgan_step") + wandb.define_metric("transformer_step", summary="max") + wandb.define_metric("transformer_loss", step_metric="transformer_step") + wandb.define_metric("epoch_transformer", step_metric="transformer_step") if args.early_stopping: - wandb.define_metric("tr_val_loss", step_metric="2_step") + wandb.define_metric("transformer_val_loss", step_metric="transformer_step") vqgan = VQGAN( device=device, @@ -1540,12 +1546,12 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu + list(vqgan.codebook.parameters()) + list(vqgan.quant_conv.parameters()) + list(vqgan.post_quant_conv.parameters()), - lr=args.lr_1, + lr=args.lr_vqgan, eps=1e-08, betas=(args.b1, args.b2), ) # VQGAN Stage 1 discriminator optimizer - opt_disc = th.optim.Adam(discriminator.parameters(), lr=args.lr_1, eps=1e-08, betas=(args.b1, args.b2)) + opt_disc = th.optim.Adam(discriminator.parameters(), lr=args.lr_vqgan, eps=1e-08, betas=(args.b1, args.b2)) # Transformer Stage 2 optimizer decay, no_decay = set(), set() @@ -1576,12 +1582,12 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0}, ] - opt_transformer = th.optim.AdamW(optim_groups, lr=args.lr_2, betas=(0.9, 0.95)) + opt_transformer = th.optim.AdamW(optim_groups, lr=args.lr_transformer, betas=(0.9, 0.95)) perceptual_loss_fcn = 
GreyscaleLPIPS().eval().to(device) @th.no_grad() - def sample_designs_1(n_designs: int) -> list[th.Tensor]: + def sample_designs_vqgan(n_designs: int) -> list[th.Tensor]: """Sample reconstructions from trained VQGAN Stage 1.""" vqgan.eval() @@ -1593,7 +1599,7 @@ def sample_designs_1(n_designs: int) -> list[th.Tensor]: return reconstructions @th.no_grad() - def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: + def sample_designs_transformer(n_designs: int) -> tuple[th.Tensor, th.Tensor]: """Sample generated designs from trained VQGAN Stage 2.""" transformer.eval() @@ -1624,8 +1630,8 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: if args.conditional: print("Stage 0: Training CVQGAN") cvqgan.train() - for epoch in tqdm.trange(args.n_epochs_0): - for i, data in enumerate(dataloader_0): + for epoch in tqdm.trange(args.n_epochs_cvqgan): + for i, data in enumerate(dataloader_cvqgan): # THIS IS PROBLEM DEPENDENT conds = th.stack((data[1:]), dim=1).to(dtype=th.float32, device=device) decoded_images, codebook_indices, q_loss = cvqgan(conds) @@ -1640,17 +1646,17 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # Logging # ---------- if args.track: - batches_done = epoch * len(dataloader_0) + i - wandb.log({"cvq_loss": cvq_loss.item(), "0_step": batches_done}) - wandb.log({"epoch_0": epoch, "0_step": batches_done}) + batches_done = epoch * len(dataloader_cvqgan) + i + wandb.log({"cvqgan_loss": cvq_loss.item(), "cvqgan_step": batches_done}) + wandb.log({"epoch_cvqgan": epoch, "cvqgan_step": batches_done}) print( - f"[Epoch {epoch}/{args.n_epochs_0}] [Batch {i}/{len(dataloader_0)}] [CVQ loss: {cvq_loss.item()}]" + f"[Epoch {epoch}/{args.n_epochs_cvqgan}] [Batch {i}/{len(dataloader_cvqgan)}] [CVQ loss: {cvq_loss.item()}]" ) # -------------- # Save model # -------------- - if args.save_model and epoch == args.n_epochs_0 - 1 and i == len(dataloader_0) - 1: + if args.save_model and epoch == args.n_epochs_cvqgan - 1 and 
i == len(dataloader_cvqgan) - 1: ckpt_cvq = { "epoch": epoch, "batches_done": batches_done, @@ -1675,8 +1681,8 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: print("Stage 1: Training VQGAN") vqgan.train() discriminator.train() - for epoch in tqdm.trange(args.n_epochs_1): - for i, data in enumerate(dataloader_1): + for epoch in tqdm.trange(args.n_epochs_vqgan): + for i, data in enumerate(dataloader_vqgan): # THIS IS PROBLEM DEPENDENT designs = data[0].to(dtype=th.float32, device=device) decoded_images, codebook_indices, q_loss = vqgan(designs) @@ -1712,19 +1718,19 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # Logging # ---------- if args.track: - batches_done = epoch * len(dataloader_1) + i - wandb.log({"vq_loss": vq_loss.item(), "1_step": batches_done}) - wandb.log({"d_loss": gan_loss.item(), "1_step": batches_done}) - wandb.log({"epoch_1": epoch, "1_step": batches_done}) + batches_done = epoch * len(dataloader_vqgan) + i + wandb.log({"vqgan_loss": vq_loss.item(), "vqgan_step": batches_done}) + wandb.log({"discriminator_loss": gan_loss.item(), "vqgan_step": batches_done}) + wandb.log({"epoch_vqgan": epoch, "vqgan_step": batches_done}) print( - f"[Epoch {epoch}/{args.n_epochs_1}] [Batch {i}/{len(dataloader_1)}] [D loss: {gan_loss.item()}] [VQ loss: {vq_loss.item()}]" + f"[Epoch {epoch}/{args.n_epochs_vqgan}] [Batch {i}/{len(dataloader_vqgan)}] [D loss: {gan_loss.item()}] [VQ loss: {vq_loss.item()}]" ) # This saves a grid image of 25 generated designs every sample_interval - if batches_done % args.sample_interval_1 == 0: + if batches_done % args.sample_interval_vqgan == 0: # Extract 25 designs designs = resize_to( - data=sample_designs_1(n_designs=n_logged_designs), h=design_shape[0], w=design_shape[1] + data=sample_designs_vqgan(n_designs=n_logged_designs), h=design_shape[0], w=design_shape[1] ) fig, axes = plt.subplots(5, 5, figsize=(12, 12)) @@ -1740,15 +1746,15 @@ def sample_designs_2(n_designs: int) -> 
tuple[th.Tensor, th.Tensor]: axes[j].set_yticks([]) # Hide y ticks plt.tight_layout() - img_fname = f"images/vqgan_1/{batches_done}.png" + img_fname = f"images/vqgan/{batches_done}.png" plt.savefig(img_fname) plt.close() - wandb.log({"designs_1": wandb.Image(img_fname)}) + wandb.log({"designs_vqgan": wandb.Image(img_fname)}) # -------------- # Save models # -------------- - if args.save_model and epoch == args.n_epochs_1 - 1 and i == len(dataloader_1) - 1: + if args.save_model and epoch == args.n_epochs_vqgan - 1 and i == len(dataloader_vqgan) - 1: ckpt_vq = { "epoch": epoch, "batches_done": batches_done, @@ -1791,8 +1797,8 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: patience_counter = 0 patience = args.early_stopping_patience - for epoch in tqdm.trange(args.n_epochs_2): - for i, data in enumerate(dataloader_2): + for epoch in tqdm.trange(args.n_epochs_transformer): + for i, data in enumerate(dataloader_transformer): # THIS IS PROBLEM DEPENDENT designs = data[0].to(dtype=th.float32, device=device) conds = th.stack((data[1:]), dim=1).to(dtype=th.float32, device=device) @@ -1807,17 +1813,17 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: # Logging # ---------- if args.track: - batches_done = epoch * len(dataloader_2) + i - wandb.log({"tr_loss": loss.item(), "2_step": batches_done}) - wandb.log({"epoch_2": epoch, "2_step": batches_done}) + batches_done = epoch * len(dataloader_transformer) + i + wandb.log({"transformer_loss": loss.item(), "transformer_step": batches_done}) + wandb.log({"epoch_transformer": epoch, "transformer_step": batches_done}) print( - f"[Epoch {epoch}/{args.n_epochs_2}] [Batch {i}/{len(dataloader_2)}] [Transformer loss: {loss.item()}]" + f"[Epoch {epoch}/{args.n_epochs_transformer}] [Batch {i}/{len(dataloader_transformer)}] [Transformer loss: {loss.item()}]" ) # This saves a grid image of 25 generated designs every sample_interval - if batches_done % args.sample_interval_2 == 0: + if 
batches_done % args.sample_interval_transformer == 0: # Extract 25 designs - desired_conds, designs = sample_designs_2(n_designs=n_logged_designs) + desired_conds, designs = sample_designs_transformer(n_designs=n_logged_designs) if args.normalize_conditions: desired_conds = (desired_conds.cpu() * std) + mean designs = resize_to(data=designs, h=design_shape[0], w=design_shape[1]) @@ -1838,10 +1844,10 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: axes[j].set_yticks([]) # Hide y ticks plt.tight_layout() - img_fname = f"images/vqgan_2/{batches_done}.png" + img_fname = f"images/transformer/{batches_done}.png" plt.savefig(img_fname) plt.close() - wandb.log({"designs_2": wandb.Image(img_fname)}) + wandb.log({"designs_transformer": wandb.Image(img_fname)}) # Early stopping based on held-out validation loss if args.track and args.early_stopping: @@ -1855,7 +1861,7 @@ def sample_designs_2(n_designs: int) -> tuple[th.Tensor, th.Tensor]: val_loss = f.cross_entropy(val_logits.reshape(-1, val_logits.size(-1)), val_targets.reshape(-1)) val_losses.append(val_loss.item()) val_loss = sum(val_losses) / len(val_losses) - wandb.log({"tr_val_loss": val_loss, "2_step": batches_done}) + wandb.log({"transformer_val_loss": val_loss, "transformer_step": batches_done}) if val_loss < best_val - args.early_stopping_delta: best_val = val_loss diff --git a/pyproject.toml b/pyproject.toml index 5ff340a..0f637e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,7 +224,7 @@ strict = [] typeCheckingMode = "basic" pythonVersion = "3.9" pythonPlatform = "All" -# typeshedPath = "typeshed" -> commented out may lead to precommit out of memory error +typeshedPath = "typeshed" enableTypeIgnoreComments = true # This is required as the CI pre-commit does not download the module (i.e. 
numpy, pygame) From a29fd33b18d07d9004c38817fc2d27a491a9715d Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Sat, 18 Oct 2025 22:44:37 +0200 Subject: [PATCH 20/22] move helper blocks and utils to their own file --- engiopt/vqgan/utils.py | 717 +++++++++++++++++++++++++++++++++++++++++ engiopt/vqgan/vqgan.py | 717 +---------------------------------------- 2 files changed, 729 insertions(+), 705 deletions(-) create mode 100644 engiopt/vqgan/utils.py diff --git a/engiopt/vqgan/utils.py b/engiopt/vqgan/utils.py new file mode 100644 index 0000000..ba03c7d --- /dev/null +++ b/engiopt/vqgan/utils.py @@ -0,0 +1,717 @@ +"""Architectural blocks and other utils for the Vector Quantized Generative Adversarial Network (VQGAN).""" + +from __future__ import annotations + +from dataclasses import dataclass +import math +import os +import warnings + +from einops import rearrange +import requests +import torch as th +from torch import nn +from torch.nn import functional as f +from torchvision.models import vgg16 +from torchvision.models import VGG16_Weights +import tqdm + +# URL and checkpoint for LPIPS model +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + + +class Codebook(nn.Module): + """Improved version over vector quantizer, with the dynamic initialization for the unoptimized "dead" vectors. 
+ + Parameters: + num_codebook_vectors (int): number of codebook entries + latent_dim (int): dimensionality of codebook entries + beta (float): weight for the commitment loss + decay (float): decay for the moving average of code usage + distance (str): distance type for looking up the closest code + anchor (str): anchor sampling methods + first_batch (bool): if true, the offline version of the model + contras_loss (bool): if true, use the contras_loss to further improve the performance + init (bool): if true, the codebook has been initialized + """ + + def __init__( # noqa: PLR0913 + self, + *, + num_codebook_vectors: int, + latent_dim: int, + beta: float = 0.25, + decay: float = 0.99, + distance: str = "cos", + anchor: str = "probrandom", + first_batch: bool = False, + contras_loss: bool = False, + init: bool = False, + ): + super().__init__() + + self.num_embed = num_codebook_vectors + self.embed_dim = latent_dim + self.beta = beta + self.decay = decay + self.distance = distance + self.anchor = anchor + self.first_batch = first_batch + self.contras_loss = contras_loss + self.init = init + + self.pool = FeaturePool(self.num_embed, self.embed_dim) + self.embedding = nn.Embedding(self.num_embed, self.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed) + self.register_buffer("embed_prob", th.zeros(self.num_embed)) + + def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.embed_dim) + + # clculate the distance + if self.distance == "l2": + # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + -th.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) + - th.sum(self.embedding.weight**2, dim=1) + + 2 * th.einsum("bd, dn-> bn", z_flattened.detach(), rearrange(self.embedding.weight, "n d-> d n")) + ) + elif 
self.distance == "cos": + # cosine distances from z to embeddings e_j + normed_z_flattened = f.normalize(z_flattened, dim=1).detach() + normed_codebook = f.normalize(self.embedding.weight, dim=1) + d = th.einsum("bd,dn->bn", normed_z_flattened, rearrange(normed_codebook, "n d -> d n")) + + # encoding + sort_distance, indices = d.sort(dim=1) + # look up the closest point for the indices + encoding_indices = indices[:, -1] + encodings = th.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) + + # quantize and unflatten + z_q = th.matmul(encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = self.beta * th.mean((z_q.detach() - z) ** 2) + th.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + # count + avg_probs = th.mean(encodings, dim=0) + perplexity = th.exp(-th.sum(avg_probs * th.log(avg_probs + 1e-10))) + min_encodings = encodings + + # online clustered reinitialization for unoptimized points + if self.training: + # calculate the average usage of code entries + self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay) + # running average updates + if self.anchor in ["closest", "random", "probrandom"] and (not self.init): + # closest sampling + if self.anchor == "closest": + sort_distance, indices = d.sort(dim=0) + random_feat = z_flattened.detach()[indices[-1, :]] + # feature pool based random sampling + elif self.anchor == "random": + random_feat = self.pool.query(z_flattened.detach()) + # probabilitical based random sampling + elif self.anchor == "probrandom": + norm_distance = f.softmax(d.t(), dim=1) + prob = th.multinomial(norm_distance, num_samples=1).view(-1) + random_feat = z_flattened.detach()[prob] + # decay parameter based on the average usage + decay = ( + th.exp(-(self.embed_prob * 
self.num_embed * 10) / (1 - self.decay) - 1e-3) + .unsqueeze(1) + .repeat(1, self.embed_dim) + ) + self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay + if self.first_batch: + self.init = True + # contrastive loss + if self.contras_loss: + sort_distance, indices = d.sort(dim=0) + dis_pos = sort_distance[-max(1, int(sort_distance.size(0) / self.num_embed)) :, :].mean(dim=0, keepdim=True) + dis_neg = sort_distance[: int(sort_distance.size(0) * 1 / 2), :] + dis = th.cat([dis_pos, dis_neg], dim=0).t() / 0.07 + contra_loss = f.cross_entropy(dis, th.zeros((dis.size(0),), dtype=th.long, device=dis.device)) + loss += contra_loss + + return z_q, encoding_indices, loss, min_encodings, perplexity + + +class FeaturePool: + """Implements a feature buffer that stores previously encoded features. + + This buffer enables us to initialize the codebook using a history of generated features rather than the ones produced by the latest encoders. + + Parameters: + pool_size (int): the size of feature buffer + dim (int): the dimension of each feature + """ + + def __init__(self, pool_size: int, dim: int = 64): + self.pool_size = pool_size + if self.pool_size > 0: + self.nums_features = 0 + self.features = (th.rand((pool_size, dim)) * 2 - 1) / pool_size + + def query(self, features: th.Tensor) -> th.Tensor: + """Return features from the pool.""" + self.features = self.features.to(features.device) + if self.nums_features < self.pool_size: + if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook + random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) + self.features = features[random_feat_id] + self.nums_features = self.pool_size + else: + # if the mini-batch is not large nuough, just store it for the next update + num = self.nums_features + features.size(0) + self.features[self.nums_features : num] = features + self.nums_features = num + elif features.size(0) > 
int(self.pool_size): + random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) + self.features = features[random_feat_id] + else: + random_id = th.randperm(self.pool_size) + self.features[random_id[: features.size(0)]] = features + + return self.features + + +class Discriminator(nn.Module): + """PatchGAN-style discriminator. + + Adapted from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538 + This assumes we never use a discriminator for the CVQGAN, since it is generally a much simpler model. + + Parameters: + num_filters_last: Number of filters in the last conv layer. + n_layers: Number of convolutional layers. + image_channels: Number of channels in the input image. + """ + + def __init__(self, num_filters_last: int = 64, n_layers: int = 3, image_channels: int = 1): + super().__init__() + + # Convolutional backbone (PatchGAN) + layers: list[nn.Module] = [ + nn.Conv2d(image_channels, num_filters_last, kernel_size=4, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ] + num_filters_mult = 1 + + for i in range(1, n_layers + 1): + num_filters_mult_last = num_filters_mult + num_filters_mult = min(2**i, 8) + layers += [ + nn.Conv2d( + num_filters_last * num_filters_mult_last, + num_filters_last * num_filters_mult, + kernel_size=4, + stride=2 if i < n_layers else 1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(num_filters_last * num_filters_mult), + nn.LeakyReLU(0.2, inplace=True), + ] + + layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, kernel_size=4, stride=1, padding=1)) + self.model = nn.Sequential(*layers) + + # Initialize weights + self.apply(self._weights_init) + + @staticmethod + def _weights_init(m: nn.Module) -> None: + """Custom weight initialization (DCGAN-style).""" + classname = m.__class__.__name__ + if "Conv" in classname: + nn.init.normal_(m.weight.data, mean=0.0, std=0.02) + elif "BatchNorm" in classname: + nn.init.normal_(m.weight.data, mean=1.0, std=0.02) + 
nn.init.constant_(m.bias.data, 0.0) + + def forward(self, x: th.Tensor) -> th.Tensor: + """Forward pass with optional CVQGAN adapter.""" + return self.model(x) + + +class GroupNorm(nn.Module): + """Group Normalization block to be used in VQGAN Encoder and Decoder. + + Parameters: + channels (int): number of channels in the input feature map + """ + + def __init__(self, channels: int): + super().__init__() + self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.gn(x) + + +class Swish(nn.Module): + """Swish activation function to be used in VQGAN Encoder and Decoder.""" + + def forward(self, x: th.Tensor) -> th.Tensor: + return x * th.sigmoid(x) + + +class ResidualBlock(nn.Module): + """Residual block to be used in VQGAN Encoder and Decoder. + + Parameters: + in_channels (int): number of input channels + out_channels (int): number of output channels + """ + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.block = nn.Sequential( + GroupNorm(in_channels), + Swish(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), + GroupNorm(out_channels), + Swish(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + if in_channels != out_channels: + self.channel_up = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: th.Tensor) -> th.Tensor: + if self.in_channels != self.out_channels: + return self.channel_up(x) + self.block(x) + return x + self.block(x) + + +class UpSampleBlock(nn.Module): + """Up-sampling block to be used in VQGAN Decoder. 
+ + Parameters: + channels (int): number of channels in the input feature map + """ + + def __init__(self, channels: int): + super().__init__() + self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = f.interpolate(x, scale_factor=2.0) + return self.conv(x) + + +class DownSampleBlock(nn.Module): + """Down-sampling block to be used in VQGAN Encoder. + + Parameters: + channels (int): number of channels in the input feature map + """ + + def __init__(self, channels: int): + super().__init__() + self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: th.Tensor) -> th.Tensor: + pad = (0, 1, 0, 1) + x = f.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class NonLocalBlock(nn.Module): + """Non-local attention block to be used in VQGAN Encoder and Decoder. + + Parameters: + channels (int): number of channels in the input feature map + """ + + def __init__(self, channels: int): + super().__init__() + self.in_channels = channels + + self.gn = GroupNorm(channels) + self.q = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: th.Tensor) -> th.Tensor: + h_ = self.gn(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + v = v.reshape(b, c, h * w) + + attn = th.bmm(q, k) + attn = attn * (int(c) ** (-0.5)) + attn = f.softmax(attn, dim=2) + attn = attn.permute(0, 2, 1) + + a = th.bmm(v, attn) + a = a.reshape(b, c, h, w) + + return x + a + + +class LinearCombo(nn.Module): + """Regular fully connected layer combo for the CVQGAN if enabled. 
+ + Parameters: + in_features (int): number of input features + out_features (int): number of output features + alpha (float): negative slope for LeakyReLU + """ + + def __init__(self, in_features: int, out_features: int, alpha: float = 0.2): + super().__init__() + self.model = nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(alpha)) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x) + + +class ScalingLayer(nn.Module): + """Channel-wise affine normalization used by LPIPS.""" + + def __init__(self): + super().__init__() + self.register_buffer("shift", th.tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", th.tensor([0.458, 0.448, 0.450])[None, :, None, None]) + + def forward(self, x: th.Tensor) -> th.Tensor: + return (x - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """1x1 conv with dropout (per-layer LPIPS linear head).""" + + def __init__(self, in_channels: int, out_channels: int = 1): + super().__init__() + self.model = nn.Sequential( + nn.Dropout(), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + return self.model(x) + + +class VGG16(nn.Module): + """Torchvision VGG16 feature extractor sliced at LPIPS tap points.""" + + def __init__(self): + super().__init__() + vgg_feats = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features + blocks = [vgg_feats[i] for i in range(30)] + self.slice1 = nn.Sequential(*blocks[0:4]) # relu1_2 + self.slice2 = nn.Sequential(*blocks[4:9]) # relu2_2 + self.slice3 = nn.Sequential(*blocks[9:16]) # relu3_3 + self.slice4 = nn.Sequential(*blocks[16:23]) # relu4_3 + self.slice5 = nn.Sequential(*blocks[23:30]) # relu5_3 + self.requires_grad_(requires_grad=False) + + def forward(self, x: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: + h1 = self.slice1(x) + h2 = self.slice2(h1) + h3 = self.slice3(h2) + h4 = self.slice4(h3) + h5 = 
self.slice5(h4) + return (h1, h2, h3, h4, h5) + + +class GreyscaleLPIPS(nn.Module): + """LPIPS for greyscale/topological data with optional 'raw' aggregation. + + ``use_raw=True`` is often preferable for non-natural images since learned + linear heads are tuned on natural RGB photos. + + Parameters: + use_raw: If True, average raw per-layer squared diffs (no linear heads). + clamp_output: Clamp the final loss to ``>= 0``. + robust_clamp: Clamp inputs to [0, 1] before feature extraction. + warn_on_clamp: If True, log warnings when inputs fall outside [0, 1]. + freeze: If True, disables grads on all params. + ckpt_name: Key in URL_MAP/CKPT_MAP for loading LPIPS heads. + """ + + def __init__( # noqa: PLR0913 + self, + *, + use_raw: bool = True, + clamp_output: bool = False, + robust_clamp: bool = True, + warn_on_clamp: bool = False, + freeze: bool = True, + ckpt_name: str = "vgg_lpips", + ): + super().__init__() + self.use_raw = use_raw + self.clamp_output = clamp_output + self.robust_clamp = robust_clamp + self.warn_on_clamp = warn_on_clamp + + self.scaling_layer = ScalingLayer() + self.channels = (64, 128, 256, 512, 512) + self.vgg = VGG16() + self.linears = nn.ModuleList([NetLinLayer(c) for c in self.channels]) + + self._load_from_pretrained(name=ckpt_name) + if freeze: + self.requires_grad_(requires_grad=False) + + def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: + """Compute greyscale-aware LPIPS distance between two batches.""" + if self.robust_clamp: + real_x = th.clamp(real_x, 0.0, 1.0) + fake_x = th.clamp(fake_x, 0.0, 1.0) + + # Promote greyscale -> RGB for VGG features + if real_x.shape[1] == 1: + real_x = real_x.repeat(1, 3, 1, 1) + if fake_x.shape[1] == 1: + fake_x = fake_x.repeat(1, 3, 1, 1) + + fr = self.vgg(self.scaling_layer(real_x)) + ff = self.vgg(self.scaling_layer(fake_x)) + diffs = [(self._norm_tensor(a) - self._norm_tensor(b)) ** 2 for a, b in zip(fr, ff)] + + if self.use_raw: + parts = [self._spatial_average(d).mean(dim=1, 
keepdim=True) for d in diffs] + else: + parts = [self._spatial_average(self.linears[i](d)) for i, d in enumerate(diffs)] + + loss = th.stack(parts, dim=0).sum() + if self.clamp_output: + loss = th.clamp(loss, min=0.0) + return loss + + # Helpers + @staticmethod + def _norm_tensor(x: th.Tensor) -> th.Tensor: + """L2-normalize channels per spatial location: BxCxHxW -> BxCxHxW.""" + norm = th.sqrt(th.sum(x**2, dim=1, keepdim=True)) + return x / (norm + 1e-10) + + @staticmethod + def _spatial_average(x: th.Tensor) -> th.Tensor: + """Average over spatial dimensions with dims kept: BxCxHxW -> BxCx1x1.""" + return x.mean(dim=(2, 3), keepdim=True) + + def _load_from_pretrained(self, *, name: str) -> None: + """Load LPIPS linear heads (and any required buffers) from a checkpoint.""" + ckpt = self._get_ckpt_path(name, "vgg_lpips") + state_dict = th.load(ckpt, map_location=th.device("cpu"), weights_only=True) + self.load_state_dict(state_dict, strict=False) + + @staticmethod + def _download(url: str, local_path: str, *, chunk_size: int = 1024) -> None: + """Stream a file to disk with a progress bar.""" + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True, timeout=10) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as pbar, open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(len(data)) + + def _get_ckpt_path(self, name: str, root: str) -> str: + """Return local path to a pretrained LPIPS checkpoint; download if missing.""" + assert name in URL_MAP, f"Unknown LPIPS checkpoint name: {name!r}" + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path): + self._download(URL_MAP[name], path) + return path + + +########################################### +########## GPT-2 BASE CODE BELOW ########## +########################################### +class LayerNorm(nn.Module): + 
"""LayerNorm with optional bias (PyTorch lacks bias=False support).""" + + def __init__(self, ndim: int, *, bias: bool): + super().__init__() + self.weight = nn.Parameter(th.ones(ndim)) + self.bias = nn.Parameter(th.zeros(ndim)) if bias else None + + def forward(self, x: th.Tensor) -> th.Tensor: + return f.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + """Causal self-attention with FlashAttention fallback when unavailable.""" + + def __init__(self, config: GPTConfig): + super().__init__() + assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head" + + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + + self.flash = hasattr(f, "scaled_dot_product_attention") + if not self.flash: + warnings.warn( + "Falling back to non-flash attention; PyTorch >= 2.0 enables FlashAttention.", + stacklevel=2, + ) + self.register_buffer( + "bias", + th.tril(th.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + b, t, c = x.size() + + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) + q = q.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) + v = v.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) + + if self.flash: + y = f.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) + else: + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float("-inf")) + att = f.softmax(att, dim=-1) + att = 
self.attn_dropout(att) + y = att @ v + + y = y.transpose(1, 2).contiguous().view(b, t, c) + return self.resid_dropout(self.c_proj(y)) + + +class MLP(nn.Module): + """Feed-forward block used inside Transformer blocks.""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + return self.dropout(x) + + +class Block(nn.Module): + """Transformer block: LayerNorm -> Self-Attn -> residual; LayerNorm -> MLP -> residual.""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = x + self.attn(self.ln_1(x)) + return x + self.mlp(self.ln_2(x)) + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 uses 50257; padded to multiple of 64 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # GPT-2 uses biases in Linear/LayerNorm + + +class GPT(nn.Module): + """Minimal GPT-2 style Transformer.""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.config = config + + self.transformer = nn.ModuleDict( + { + "wte": nn.Embedding(config.vocab_size, config.n_embd), + "wpe": nn.Embedding(config.block_size, config.n_embd), + "drop": nn.Dropout(config.dropout), + "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + "ln_f": LayerNorm(config.n_embd, bias=config.bias), + } + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.transformer["wte"].weight = 
self.lm_head.weight # weight tying + + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + th.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + th.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + th.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + th.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward( + self, + idx: th.Tensor, + targets: th.Tensor | None = None, + ) -> tuple[th.Tensor, th.Tensor | None]: + """Forward pass returning logits and optional cross-entropy loss.""" + device = idx.device + _, t = idx.size() + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}; block size is {self.config.block_size}" + pos = th.arange(0, t, dtype=th.long, device=device) + + tok_emb = self.transformer["wte"](idx) + pos_emb = self.transformer["wpe"](pos) + x = self.transformer["drop"](tok_emb + pos_emb) + for block in self.transformer["h"]: + x = block(x) + x = self.transformer["ln_f"](x) + + logits = self.lm_head(x) + loss: th.Tensor | None + if targets is not None: + loss = f.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + else: + loss = None + return logits, loss + + +########################################### +########## GPT-2 BASE CODE ABOVE ########## +########################################### diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 497b6ae..92cb976 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -18,22 +18,17 @@ from __future__ import annotations from dataclasses import dataclass -import math import os import random import time import warnings -from einops import rearrange from engibench.utils.all_problems import BUILTIN_PROBLEMS import matplotlib.pyplot as plt import numpy as np -import requests import torch as th from 
torch import nn from torch.nn import functional as f -from torchvision.models import vgg16 -from torchvision.models import VGG16_Weights import tqdm import tyro import wandb @@ -41,11 +36,18 @@ from engiopt.transforms import drop_constant from engiopt.transforms import normalize from engiopt.transforms import resize_to - -# URL and checkpoint for LPIPS model -URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} - -CKPT_MAP = {"vgg_lpips": "vgg.pth"} +from engiopt.vqgan.utils import Codebook +from engiopt.vqgan.utils import Discriminator +from engiopt.vqgan.utils import DownSampleBlock +from engiopt.vqgan.utils import GPT +from engiopt.vqgan.utils import GPTConfig +from engiopt.vqgan.utils import GreyscaleLPIPS +from engiopt.vqgan.utils import GroupNorm +from engiopt.vqgan.utils import LinearCombo +from engiopt.vqgan.utils import NonLocalBlock +from engiopt.vqgan.utils import ResidualBlock +from engiopt.vqgan.utils import Swish +from engiopt.vqgan.utils import UpSampleBlock @dataclass @@ -171,311 +173,6 @@ class Args: """interval between Stage 2 image samples""" -class Codebook(nn.Module): - """Improved version over vector quantizer, with the dynamic initialization for the unoptimized "dead" vectors. 
- - Parameters: - num_codebook_vectors (int): number of codebook entries - latent_dim (int): dimensionality of codebook entries - beta (float): weight for the commitment loss - decay (float): decay for the moving average of code usage - distance (str): distance type for looking up the closest code - anchor (str): anchor sampling methods - first_batch (bool): if true, the offline version of the model - contras_loss (bool): if true, use the contras_loss to further improve the performance - init (bool): if true, the codebook has been initialized - """ - - def __init__( # noqa: PLR0913 - self, - *, - num_codebook_vectors: int, - latent_dim: int, - beta: float = 0.25, - decay: float = 0.99, - distance: str = "cos", - anchor: str = "probrandom", - first_batch: bool = False, - contras_loss: bool = False, - init: bool = False, - ): - super().__init__() - - self.num_embed = num_codebook_vectors - self.embed_dim = latent_dim - self.beta = beta - self.decay = decay - self.distance = distance - self.anchor = anchor - self.first_batch = first_batch - self.contras_loss = contras_loss - self.init = init - - self.pool = FeaturePool(self.num_embed, self.embed_dim) - self.embedding = nn.Embedding(self.num_embed, self.embed_dim) - self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed) - self.register_buffer("embed_prob", th.zeros(self.num_embed)) - - def forward(self, z: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - z_flattened = z.view(-1, self.embed_dim) - - # clculate the distance - if self.distance == "l2": - # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = ( - -th.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - - th.sum(self.embedding.weight**2, dim=1) - + 2 * th.einsum("bd, dn-> bn", z_flattened.detach(), rearrange(self.embedding.weight, "n d-> d n")) - ) - elif 
self.distance == "cos": - # cosine distances from z to embeddings e_j - normed_z_flattened = f.normalize(z_flattened, dim=1).detach() - normed_codebook = f.normalize(self.embedding.weight, dim=1) - d = th.einsum("bd,dn->bn", normed_z_flattened, rearrange(normed_codebook, "n d -> d n")) - - # encoding - sort_distance, indices = d.sort(dim=1) - # look up the closest point for the indices - encoding_indices = indices[:, -1] - encodings = th.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device) - encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) - - # quantize and unflatten - z_q = th.matmul(encodings, self.embedding.weight).view(z.shape) - # compute loss for embedding - loss = self.beta * th.mean((z_q.detach() - z) ** 2) + th.mean((z_q - z.detach()) ** 2) - # preserve gradients - z_q = z + (z_q - z).detach() - # reshape back to match original input shape - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - # count - avg_probs = th.mean(encodings, dim=0) - perplexity = th.exp(-th.sum(avg_probs * th.log(avg_probs + 1e-10))) - min_encodings = encodings - - # online clustered reinitialization for unoptimized points - if self.training: - # calculate the average usage of code entries - self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay) - # running average updates - if self.anchor in ["closest", "random", "probrandom"] and (not self.init): - # closest sampling - if self.anchor == "closest": - sort_distance, indices = d.sort(dim=0) - random_feat = z_flattened.detach()[indices[-1, :]] - # feature pool based random sampling - elif self.anchor == "random": - random_feat = self.pool.query(z_flattened.detach()) - # probabilitical based random sampling - elif self.anchor == "probrandom": - norm_distance = f.softmax(d.t(), dim=1) - prob = th.multinomial(norm_distance, num_samples=1).view(-1) - random_feat = z_flattened.detach()[prob] - # decay parameter based on the average usage - decay = ( - th.exp(-(self.embed_prob * 
self.num_embed * 10) / (1 - self.decay) - 1e-3) - .unsqueeze(1) - .repeat(1, self.embed_dim) - ) - self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay - if self.first_batch: - self.init = True - # contrastive loss - if self.contras_loss: - sort_distance, indices = d.sort(dim=0) - dis_pos = sort_distance[-max(1, int(sort_distance.size(0) / self.num_embed)) :, :].mean(dim=0, keepdim=True) - dis_neg = sort_distance[: int(sort_distance.size(0) * 1 / 2), :] - dis = th.cat([dis_pos, dis_neg], dim=0).t() / 0.07 - contra_loss = f.cross_entropy(dis, th.zeros((dis.size(0),), dtype=th.long, device=dis.device)) - loss += contra_loss - - return z_q, encoding_indices, loss, min_encodings, perplexity - - -class FeaturePool: - """Implements a feature buffer that stores previously encoded features. - - This buffer enables us to initialize the codebook using a history of generated features rather than the ones produced by the latest encoders. - - Parameters: - pool_size (int): the size of feature buffer - dim (int): the dimension of each feature - """ - - def __init__(self, pool_size: int, dim: int = 64): - self.pool_size = pool_size - if self.pool_size > 0: - self.nums_features = 0 - self.features = (th.rand((pool_size, dim)) * 2 - 1) / pool_size - - def query(self, features: th.Tensor) -> th.Tensor: - """Return features from the pool.""" - self.features = self.features.to(features.device) - if self.nums_features < self.pool_size: - if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook - random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) - self.features = features[random_feat_id] - self.nums_features = self.pool_size - else: - # if the mini-batch is not large nuough, just store it for the next update - num = self.nums_features + features.size(0) - self.features[self.nums_features : num] = features - self.nums_features = num - elif features.size(0) > 
int(self.pool_size): - random_feat_id = th.randint(0, features.size(0), (int(self.pool_size),)) - self.features = features[random_feat_id] - else: - random_id = th.randperm(self.pool_size) - self.features[random_id[: features.size(0)]] = features - - return self.features - - -class GroupNorm(nn.Module): - """Group Normalization block to be used in VQGAN Encoder and Decoder. - - Parameters: - channels (int): number of channels in the input feature map - """ - - def __init__(self, channels: int): - super().__init__() - self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) - - def forward(self, x: th.Tensor) -> th.Tensor: - return self.gn(x) - - -class Swish(nn.Module): - """Swish activation function to be used in VQGAN Encoder and Decoder.""" - - def forward(self, x: th.Tensor) -> th.Tensor: - return x * th.sigmoid(x) - - -class ResidualBlock(nn.Module): - """Residual block to be used in VQGAN Encoder and Decoder. - - Parameters: - in_channels (int): number of input channels - out_channels (int): number of output channels - """ - - def __init__(self, in_channels: int, out_channels: int): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.block = nn.Sequential( - GroupNorm(in_channels), - Swish(), - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), - GroupNorm(out_channels), - Swish(), - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), - ) - - if in_channels != out_channels: - self.channel_up = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x: th.Tensor) -> th.Tensor: - if self.in_channels != self.out_channels: - return self.channel_up(x) + self.block(x) - return x + self.block(x) - - -class UpSampleBlock(nn.Module): - """Up-sampling block to be used in VQGAN Decoder. 
- - Parameters: - channels (int): number of channels in the input feature map - """ - - def __init__(self, channels: int): - super().__init__() - self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x: th.Tensor) -> th.Tensor: - x = f.interpolate(x, scale_factor=2.0) - return self.conv(x) - - -class DownSampleBlock(nn.Module): - """Down-sampling block to be used in VQGAN Encoder. - - Parameters: - channels (int): number of channels in the input feature map - """ - - def __init__(self, channels: int): - super().__init__() - self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x: th.Tensor) -> th.Tensor: - pad = (0, 1, 0, 1) - x = f.pad(x, pad, mode="constant", value=0) - return self.conv(x) - - -class NonLocalBlock(nn.Module): - """Non-local attention block to be used in VQGAN Encoder and Decoder. - - Parameters: - channels (int): number of channels in the input feature map - """ - - def __init__(self, channels: int): - super().__init__() - self.in_channels = channels - - self.gn = GroupNorm(channels) - self.q = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) - self.k = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) - self.v = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) - self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x: th.Tensor) -> th.Tensor: - h_ = self.gn(x) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - b, c, h, w = q.shape - - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h * w) - v = v.reshape(b, c, h * w) - - attn = th.bmm(q, k) - attn = attn * (int(c) ** (-0.5)) - attn = f.softmax(attn, dim=2) - attn = attn.permute(0, 2, 1) - - a = th.bmm(v, attn) - a = a.reshape(b, c, h, w) - - return x + a - - -class LinearCombo(nn.Module): - """Regular fully connected layer combo for the CVQGAN if enabled. 
- - Parameters: - in_features (int): number of input features - out_features (int): number of output features - alpha (float): negative slope for LeakyReLU - """ - - def __init__(self, in_features: int, out_features: int, alpha: float = 0.2): - super().__init__() - self.model = nn.Sequential(nn.Linear(in_features, out_features), nn.LeakyReLU(alpha)) - - def forward(self, x: th.Tensor) -> th.Tensor: - return self.model(x) - - class Encoder(nn.Module): """Encoder module for VQGAN Stage 1. @@ -638,219 +335,6 @@ def forward(self, x: th.Tensor) -> th.Tensor: return self.model(x.contiguous().view(len(x), -1)) -class Discriminator(nn.Module): - """PatchGAN-style discriminator. - - Adapted from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538 - This assumes we never use a discriminator for the CVQGAN, since it is generally a much simpler model. - - Parameters: - num_filters_last: Number of filters in the last conv layer. - n_layers: Number of convolutional layers. - image_channels: Number of channels in the input image. 
- """ - - def __init__(self, num_filters_last: int = 64, n_layers: int = 3, image_channels: int = 1): - super().__init__() - - # Convolutional backbone (PatchGAN) - layers: list[nn.Module] = [ - nn.Conv2d(image_channels, num_filters_last, kernel_size=4, stride=2, padding=1), - nn.LeakyReLU(0.2, inplace=True), - ] - num_filters_mult = 1 - - for i in range(1, n_layers + 1): - num_filters_mult_last = num_filters_mult - num_filters_mult = min(2**i, 8) - layers += [ - nn.Conv2d( - num_filters_last * num_filters_mult_last, - num_filters_last * num_filters_mult, - kernel_size=4, - stride=2 if i < n_layers else 1, - padding=1, - bias=False, - ), - nn.BatchNorm2d(num_filters_last * num_filters_mult), - nn.LeakyReLU(0.2, inplace=True), - ] - - layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, kernel_size=4, stride=1, padding=1)) - self.model = nn.Sequential(*layers) - - # Initialize weights - self.apply(self._weights_init) - - @staticmethod - def _weights_init(m: nn.Module) -> None: - """Custom weight initialization (DCGAN-style).""" - classname = m.__class__.__name__ - if "Conv" in classname: - nn.init.normal_(m.weight.data, mean=0.0, std=0.02) - elif "BatchNorm" in classname: - nn.init.normal_(m.weight.data, mean=1.0, std=0.02) - nn.init.constant_(m.bias.data, 0.0) - - def forward(self, x: th.Tensor) -> th.Tensor: - """Forward pass with optional CVQGAN adapter.""" - return self.model(x) - - -class ScalingLayer(nn.Module): - """Channel-wise affine normalization used by LPIPS.""" - - def __init__(self): - super().__init__() - self.register_buffer("shift", th.tensor([-0.030, -0.088, -0.188])[None, :, None, None]) - self.register_buffer("scale", th.tensor([0.458, 0.448, 0.450])[None, :, None, None]) - - def forward(self, x: th.Tensor) -> th.Tensor: - return (x - self.shift) / self.scale - - -class NetLinLayer(nn.Module): - """1x1 conv with dropout (per-layer LPIPS linear head).""" - - def __init__(self, in_channels: int, out_channels: int = 1): - 
super().__init__() - self.model = nn.Sequential( - nn.Dropout(), - nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), - ) - - def forward(self, x: th.Tensor) -> th.Tensor: - return self.model(x) - - -class VGG16(nn.Module): - """Torchvision VGG16 feature extractor sliced at LPIPS tap points.""" - - def __init__(self): - super().__init__() - vgg_feats = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features - blocks = [vgg_feats[i] for i in range(30)] - self.slice1 = nn.Sequential(*blocks[0:4]) # relu1_2 - self.slice2 = nn.Sequential(*blocks[4:9]) # relu2_2 - self.slice3 = nn.Sequential(*blocks[9:16]) # relu3_3 - self.slice4 = nn.Sequential(*blocks[16:23]) # relu4_3 - self.slice5 = nn.Sequential(*blocks[23:30]) # relu5_3 - self.requires_grad_(requires_grad=False) - - def forward(self, x: th.Tensor) -> tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]: - h1 = self.slice1(x) - h2 = self.slice2(h1) - h3 = self.slice3(h2) - h4 = self.slice4(h3) - h5 = self.slice5(h4) - return (h1, h2, h3, h4, h5) - - -class GreyscaleLPIPS(nn.Module): - """LPIPS for greyscale/topological data with optional 'raw' aggregation. - - ``use_raw=True`` is often preferable for non-natural images since learned - linear heads are tuned on natural RGB photos. - - Parameters: - use_raw: If True, average raw per-layer squared diffs (no linear heads). - clamp_output: Clamp the final loss to ``>= 0``. - robust_clamp: Clamp inputs to [0, 1] before feature extraction. - warn_on_clamp: If True, log warnings when inputs fall outside [0, 1]. - freeze: If True, disables grads on all params. - ckpt_name: Key in URL_MAP/CKPT_MAP for loading LPIPS heads. 
- """ - - def __init__( # noqa: PLR0913 - self, - *, - use_raw: bool = True, - clamp_output: bool = False, - robust_clamp: bool = True, - warn_on_clamp: bool = False, - freeze: bool = True, - ckpt_name: str = "vgg_lpips", - ): - super().__init__() - self.use_raw = use_raw - self.clamp_output = clamp_output - self.robust_clamp = robust_clamp - self.warn_on_clamp = warn_on_clamp - - self.scaling_layer = ScalingLayer() - self.channels = (64, 128, 256, 512, 512) - self.vgg = VGG16() - self.linears = nn.ModuleList([NetLinLayer(c) for c in self.channels]) - - self._load_from_pretrained(name=ckpt_name) - if freeze: - self.requires_grad_(requires_grad=False) - - def forward(self, real_x: th.Tensor, fake_x: th.Tensor) -> th.Tensor: - """Compute greyscale-aware LPIPS distance between two batches.""" - if self.robust_clamp: - real_x = th.clamp(real_x, 0.0, 1.0) - fake_x = th.clamp(fake_x, 0.0, 1.0) - - # Promote greyscale -> RGB for VGG features - if real_x.shape[1] == 1: - real_x = real_x.repeat(1, 3, 1, 1) - if fake_x.shape[1] == 1: - fake_x = fake_x.repeat(1, 3, 1, 1) - - fr = self.vgg(self.scaling_layer(real_x)) - ff = self.vgg(self.scaling_layer(fake_x)) - diffs = [(self._norm_tensor(a) - self._norm_tensor(b)) ** 2 for a, b in zip(fr, ff)] - - if self.use_raw: - parts = [self._spatial_average(d).mean(dim=1, keepdim=True) for d in diffs] - else: - parts = [self._spatial_average(self.linears[i](d)) for i, d in enumerate(diffs)] - - loss = th.stack(parts, dim=0).sum() - if self.clamp_output: - loss = th.clamp(loss, min=0.0) - return loss - - # Helpers - @staticmethod - def _norm_tensor(x: th.Tensor) -> th.Tensor: - """L2-normalize channels per spatial location: BxCxHxW -> BxCxHxW.""" - norm = th.sqrt(th.sum(x**2, dim=1, keepdim=True)) - return x / (norm + 1e-10) - - @staticmethod - def _spatial_average(x: th.Tensor) -> th.Tensor: - """Average over spatial dimensions with dims kept: BxCxHxW -> BxCx1x1.""" - return x.mean(dim=(2, 3), keepdim=True) - - def 
_load_from_pretrained(self, *, name: str) -> None: - """Load LPIPS linear heads (and any required buffers) from a checkpoint.""" - ckpt = self._get_ckpt_path(name, "vgg_lpips") - state_dict = th.load(ckpt, map_location=th.device("cpu"), weights_only=True) - self.load_state_dict(state_dict, strict=False) - - @staticmethod - def _download(url: str, local_path: str, *, chunk_size: int = 1024) -> None: - """Stream a file to disk with a progress bar.""" - os.makedirs(os.path.split(local_path)[0], exist_ok=True) - with requests.get(url, stream=True, timeout=10) as r: - total_size = int(r.headers.get("content-length", 0)) - with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as pbar, open(local_path, "wb") as f: - for data in r.iter_content(chunk_size=chunk_size): - if data: - f.write(data) - pbar.update(len(data)) - - def _get_ckpt_path(self, name: str, root: str) -> str: - """Return local path to a pretrained LPIPS checkpoint; download if missing.""" - assert name in URL_MAP, f"Unknown LPIPS checkpoint name: {name!r}" - path = os.path.join(root, CKPT_MAP[name]) - if not os.path.exists(path): - self._download(URL_MAP[name], path) - return path - - class VQGAN(nn.Module): """VQGAN model for Stage 1. 
@@ -974,183 +458,6 @@ def adopt_weight(disc_factor: float, i: int, threshold: int, value: float = 0.0) return value if i < threshold else disc_factor -########################################### -########## GPT-2 BASE CODE BELOW ########## -########################################### -class LayerNorm(nn.Module): - """LayerNorm with optional bias (PyTorch lacks bias=False support).""" - - def __init__(self, ndim: int, *, bias: bool): - super().__init__() - self.weight = nn.Parameter(th.ones(ndim)) - self.bias = nn.Parameter(th.zeros(ndim)) if bias else None - - def forward(self, x: th.Tensor) -> th.Tensor: - return f.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) - - -class CausalSelfAttention(nn.Module): - """Causal self-attention with FlashAttention fallback when unavailable.""" - - def __init__(self, config: GPTConfig): - super().__init__() - assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head" - - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - self.attn_dropout = nn.Dropout(config.dropout) - self.resid_dropout = nn.Dropout(config.dropout) - - self.n_head = config.n_head - self.n_embd = config.n_embd - self.dropout = config.dropout - - self.flash = hasattr(f, "scaled_dot_product_attention") - if not self.flash: - warnings.warn( - "Falling back to non-flash attention; PyTorch >= 2.0 enables FlashAttention.", - stacklevel=2, - ) - self.register_buffer( - "bias", - th.tril(th.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), - ) - - def forward(self, x: th.Tensor) -> th.Tensor: - b, t, c = x.size() - - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) - q = q.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) - v = v.view(b, t, self.n_head, c // self.n_head).transpose(1, 2) - - if self.flash: - y = 
f.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=self.dropout if self.training else 0.0, - is_causal=True, - ) - else: - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float("-inf")) - att = f.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v - - y = y.transpose(1, 2).contiguous().view(b, t, c) - return self.resid_dropout(self.c_proj(y)) - - -class MLP(nn.Module): - """Feed-forward block used inside Transformer blocks.""" - - def __init__(self, config: GPTConfig): - super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.gelu = nn.GELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x: th.Tensor) -> th.Tensor: - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - return self.dropout(x) - - -class Block(nn.Module): - """Transformer block: LayerNorm -> Self-Attn -> residual; LayerNorm -> MLP -> residual.""" - - def __init__(self, config: GPTConfig): - super().__init__() - self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) - self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) - self.mlp = MLP(config) - - def forward(self, x: th.Tensor) -> th.Tensor: - x = x + self.attn(self.ln_1(x)) - return x + self.mlp(self.ln_2(x)) - - -@dataclass -class GPTConfig: - block_size: int = 1024 - vocab_size: int = 50304 # GPT-2 uses 50257; padded to multiple of 64 - n_layer: int = 12 - n_head: int = 12 - n_embd: int = 768 - dropout: float = 0.0 - bias: bool = True # GPT-2 uses biases in Linear/LayerNorm - - -class GPT(nn.Module): - """Minimal GPT-2 style Transformer.""" - - def __init__(self, config: GPTConfig): - super().__init__() - self.config = config - - self.transformer = nn.ModuleDict( - { - "wte": nn.Embedding(config.vocab_size, config.n_embd), - "wpe": 
nn.Embedding(config.block_size, config.n_embd), - "drop": nn.Dropout(config.dropout), - "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - "ln_f": LayerNorm(config.n_embd, bias=config.bias), - } - ) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - self.transformer["wte"].weight = self.lm_head.weight # weight tying - - self.apply(self._init_weights) - for pn, p in self.named_parameters(): - if pn.endswith("c_proj.weight"): - th.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) - - def _init_weights(self, module: nn.Module) -> None: - if isinstance(module, nn.Linear): - th.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - th.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - th.nn.init.normal_(module.weight, mean=0.0, std=0.02) - - def forward( - self, - idx: th.Tensor, - targets: th.Tensor | None = None, - ) -> tuple[th.Tensor, th.Tensor | None]: - """Forward pass returning logits and optional cross-entropy loss.""" - device = idx.device - _, t = idx.size() - assert t <= self.config.block_size, f"Cannot forward sequence of length {t}; block size is {self.config.block_size}" - pos = th.arange(0, t, dtype=th.long, device=device) - - tok_emb = self.transformer["wte"](idx) - pos_emb = self.transformer["wpe"](pos) - x = self.transformer["drop"](tok_emb + pos_emb) - for block in self.transformer["h"]: - x = block(x) - x = self.transformer["ln_f"](x) - - logits = self.lm_head(x) - loss: th.Tensor | None - if targets is not None: - loss = f.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) - else: - loss = None - return logits, loss - - -########################################### -########## GPT-2 BASE CODE ABOVE ########## -########################################### - - class VQGANTransformer(nn.Module): """Wrapper for VQGAN Stage 2: Transformer. 
From d88a48bff249b5535f1dcf4b1aa2f35f5535ace8 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Wed, 22 Oct 2025 18:01:53 +0200 Subject: [PATCH 21/22] remove default comments from args --- engiopt/vqgan/vqgan.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 92cb976..787b451 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -95,9 +95,9 @@ class Args: """feature map dimension for the CVQGAN encoder output""" batch_size_cvqgan: int = 16 """size of the batches for CVQGAN""" - n_epochs_cvqgan: int = 1000 # Default: 1000 + n_epochs_cvqgan: int = 1000 """number of epochs of CVQGAN training""" - cond_lr: float = 2e-4 # Default: 2e-4 + cond_lr: float = 2e-4 """learning rate for CVQGAN""" latent_size: int = 16 """size of the latent feature map (automatically determined later)""" @@ -106,11 +106,11 @@ class Args: # Algorithm-specific: Stage 1 (AE) # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically - n_epochs_vqgan: int = 100 # Default: 100 + n_epochs_vqgan: int = 100 """number of epochs of training""" batch_size_vqgan: int = 16 """size of the batches for Stage 1""" - lr_vqgan: float = 5e-5 # Default: 2e-4 + lr_vqgan: float = 5e-5 """learning rate for Stage 1""" beta: float = 0.25 """beta hyperparameter for the codebook commitment loss""" @@ -149,7 +149,7 @@ class Args: # Algorithm-specific: Stage 2 (Transformer) # From original implementation: assume pkeep=1.0, sos_token=0, bias=True - n_epochs_transformer: int = 100 # Default: 100 + n_epochs_transformer: int = 100 """number of epochs of training""" early_stopping: bool = True """whether to use early stopping for the transformer; if True requires args.track to be True""" @@ -159,7 +159,7 @@ class Args: """minimum change in the monitored quantity to qualify as an improvement""" 
batch_size_transformer: int = 16 """size of the batches for Stage 2""" - lr_transformer: float = 6e-4 # Default: 6e-4 + lr_transformer: float = 6e-4 """learning rate for Stage 2""" n_layer: int = 12 """number of layers in the transformer""" From 5577396789df4b6453972356df82d885a0e15b94 Mon Sep 17 00:00:00 2001 From: Arthur Drake Date: Wed, 22 Oct 2025 23:21:00 +0200 Subject: [PATCH 22/22] remove unused args --- engiopt/vqgan/vqgan.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index 787b451..efdc2fa 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -99,10 +99,6 @@ class Args: """number of epochs of CVQGAN training""" cond_lr: float = 2e-4 """learning rate for CVQGAN""" - latent_size: int = 16 - """size of the latent feature map (automatically determined later)""" - image_channels: int = 1 - """number of channels in the input image (automatically determined later)""" # Algorithm-specific: Stage 1 (AE) # From original implementation: assume image_channels=1, use greyscale LPIPS only, use_Online=True, determine image_size automatically, calculate decoder_start_resolution automatically @@ -682,8 +678,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu training_ds = training_ds.remove_columns("optimal_design") # Now we assume the dataset is of shape (N, C, H, W) and work from there - args.image_channels = training_ds["optimal_upsampled"][:].shape[1] - args.latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) + image_channels = training_ds["optimal_upsampled"][:].shape[1] + latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) conditions = problem.conditions_keys # Optionally drop condition columns that are constant like overhang_constraint in beams2d @@ -800,15 +796,15 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu 
encoder_attn_resolutions=args.encoder_attn_resolutions, encoder_num_res_blocks=args.encoder_num_res_blocks, decoder_channels=args.decoder_channels, - decoder_start_resolution=args.latent_size, + decoder_start_resolution=latent_size, decoder_attn_resolutions=args.decoder_attn_resolutions, decoder_num_res_blocks=args.decoder_num_res_blocks, - image_channels=args.image_channels, + image_channels=image_channels, latent_dim=args.latent_dim, num_codebook_vectors=args.num_codebook_vectors, ).to(device=device) - discriminator = Discriminator(image_channels=args.image_channels).to(device=device) + discriminator = Discriminator(image_channels=image_channels).to(device=device) cvqgan = VQGAN( device=device, @@ -924,7 +920,7 @@ def sample_designs_transformer(n_designs: int) -> tuple[th.Tensor, th.Tensor]: c = th.ones(n_designs, 1, dtype=th.int64, device=device) * transformer.sos_token latent_imgs = transformer.sample( - x=th.empty(n_designs, 0, dtype=th.int64, device=device), c=c, steps=(args.latent_size**2) + x=th.empty(n_designs, 0, dtype=th.int64, device=device), c=c, steps=(latent_size**2) ) gen_imgs = transformer.z_to_image(latent_imgs)