class LeastVolumeAE(nn.Module):
    """Autoencoder trained with a reconstruction loss plus a latent-volume penalty.

    The volume penalty is the geometric mean of the per-dimension standard
    deviations of the latent codes in a batch::

        volume = exp(mean(log(std_i + eta)))

    Minimising it together with the reconstruction error pushes the encoder
    towards a compact latent representation.

    Args:
        encoder: Encoder network mapping input to latent code.
        decoder: Decoder network mapping latent code to reconstruction.
        optimizer: Optimizer instance for training.
        weights: Loss weights ``[reconstruction, volume]``, or a callable
            ``epoch -> weights`` evaluated at the start of every epoch.
            Default: ``[1.0, 0.001]``.
        eta: Constant added to each standard deviation before the logarithm.
            Default: 0.
    """

    w: torch.Tensor  # registered buffer holding the current loss weights

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        optimizer: Optimizer,
        weights: list[float] | Callable[[int], torch.Tensor] | None = None,
        eta: float = 0,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.optim = optimizer
        self.eta = eta

        if weights is None:
            weights = [1.0, 0.001]

        # A callable is kept as a per-epoch schedule; its value at epoch 0 seeds the buffer.
        if callable(weights):
            self._w_schedule = weights
            initial_w = weights(0)
        else:
            self._w_schedule = None
            initial_w = weights

        self.register_buffer("w", torch.as_tensor(initial_w, dtype=torch.float))
        self._init_epoch = 0

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to latent representation."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent representation to reconstruction."""
        return self.decoder(z)

    def loss(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the unweighted reconstruction and volume losses.

        Args:
            x: Input batch tensor.

        Returns:
            Tensor of shape (2,) containing [reconstruction_loss, volume_loss].
        """
        z = self.encode(x)
        x_hat = self.decode(z)
        return torch.stack([self.loss_rec(x, x_hat), self.loss_vol(z)])

    def loss_rec(self, x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
        """Compute reconstruction loss (MSE)."""
        return f.mse_loss(x, x_hat)

    def loss_vol(self, z: torch.Tensor) -> torch.Tensor:
        """Compute volume loss as geometric mean of latent standard deviations.

        Volume loss = exp(mean(log(std_i + eta)))

        Args:
            z: Latent codes of shape (batch_size, latent_dim).

        Returns:
            Scalar volume loss.
        """
        s = z.std(0)
        return torch.exp(torch.log(s + self.eta).mean())

    def epoch_hook(self, epoch: int) -> None:
        """Refresh the loss weights from the schedule, if one was given.

        The schedule result is coerced with ``torch.as_tensor`` exactly like the
        initial weights in ``__init__``, so a schedule may return a tensor or a
        plain sequence of floats.

        Args:
            epoch: Index of the epoch that is about to start.
        """
        if self._w_schedule is None:
            return
        self.w = torch.as_tensor(
            self._w_schedule(epoch),
            dtype=self.w.dtype,
            device=self.w.device,
        )

    def epoch_report(
        self,
        epoch: int,
        callbacks: list[Callable[..., None]],
        **kwargs: object,
    ) -> None:
        """Invoke every callback as ``callback(self, epoch=epoch, **kwargs)``."""
        for callback in callbacks:
            callback(self, epoch=epoch, **kwargs)

    def fit(
        self,
        dataloader: DataLoader[torch.Tensor],
        epochs: int,
        callbacks: list[Callable[..., None]] | None = None,
    ) -> None:
        """Train the autoencoder.

        Each batch yielded by ``dataloader`` is passed unchanged to ``self.loss``;
        the returned loss components are weighted by ``self.w``, summed and
        back-propagated. After every epoch ``epoch_report`` receives the last
        batch and its loss (both ``None`` if the dataloader yielded nothing).

        Args:
            dataloader: Training data loader.
            epochs: Maximum number of epochs.
            callbacks: Optional list of callback functions.
        """
        callbacks = [] if callbacks is None else callbacks

        with tqdm(
            range(self._init_epoch, epochs),
            initial=self._init_epoch,
            total=epochs,
            bar_format="{l_bar}{bar:20}{r_bar}",
            desc="Training",
        ) as pbar:
            for epoch in pbar:
                self.epoch_hook(epoch=epoch)
                # Pre-bind so an empty dataloader cannot leave these names undefined.
                batch = None
                loss = None
                for batch in dataloader:
                    self.optim.zero_grad()
                    loss = self.loss(batch)
                    (loss * self.w).sum().backward()
                    self.optim.step()
                self.epoch_report(epoch=epoch, callbacks=callbacks, batch=batch, loss=loss, pbar=pbar)
def to(self, *args: object, **kwargs: object) -> "LeastVolumeAE_DynamicPruning":
    """Move and/or cast the model; accepts everything ``nn.Module.to`` accepts.

    ``_p``, ``_z`` and ``_frozen_std`` are registered buffers, so the base
    implementation already relocates them. They must not be converted by hand:
    forwarding a dtype to ``self._p.to(...)`` would turn the boolean pruning
    mask into a float tensor and break ``~self._p``.

    The EMA statistics ``_zstd`` / ``_zmean`` are plain attributes that the
    base implementation does not touch, so they are aligned with the ``_z``
    buffer (device and dtype) here.

    Returns:
        ``self``, to allow call chaining.
    """
    super().to(*args, **kwargs)
    if self._zstd is not None:
        self._zstd = self._zstd.to(self._z)
    if self._zmean is not None:
        self._zmean = self._zmean.to(self._z)
    return self
def loss(self, x: torch.Tensor) -> torch.Tensor:
    """Compute reconstruction and volume losses and refresh the EMA statistics.

    Pruned dimensions contribute their frozen standard deviation (stored in
    ``_frozen_std`` when they were pruned); active dimensions contribute the
    standard deviation of the current batch. ``eta`` is added before the
    logarithm, matching ``LeastVolumeAE.loss_vol``.

    The moving statistics that drive pruning are only updated while the module
    is in training mode, so evaluation passes do not influence which
    dimensions get pruned.

    Args:
        x: Input batch tensor.

    Returns:
        Tensor of shape (2,) containing [reconstruction_loss, volume_loss].
    """
    z = self.encode(x)
    x_hat = self.decode(z)
    if self.training:
        self._update_moving_mean(z)

    active = ~self._p
    s = self._frozen_std.clone()
    if active.any():
        s[active] = z[:, active].std(0)
    vol_loss = torch.exp(torch.log(s + self.eta).mean())

    return torch.stack([self.loss_rec(x, x_hat), vol_loss])
@torch.no_grad()
def _prune_step(self, _epoch: int) -> None:
    """Prune the active dimensions selected by the plummet criterion.

    Does nothing until EMA statistics exist, and nothing when fewer than two
    dimensions are still active: the plummet criterion looks at the drop
    between neighbouring entries of the sorted standard deviations, which is
    undefined for a single value.

    For every newly pruned dimension the current EMA standard deviation is
    stored in ``_frozen_std``, the EMA mean in ``_z``, and the mask ``_p`` is
    set.

    Args:
        _epoch: Current epoch (unused).
    """
    if self._zstd is None or self._zmean is None:
        return

    active = ~self._p
    z_std_active = self._zstd[active]
    if z_std_active.numel() < 2:
        return

    # Candidates are computed on the active subset, then scattered back to full width.
    cand = torch.zeros_like(self._p)
    cand[active] = self._plummet_prune(z_std_active)

    prune_idx = torch.where(cand)[0]
    if len(prune_idx) == 0:
        return

    # Capture the statistics before flipping the mask.
    self._frozen_std[prune_idx] = self._zstd[prune_idx]
    self._z[prune_idx] = self._zmean[prune_idx]
    self._p[prune_idx] = True
def loss(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Compute reconstruction, performance, and volume losses.

    The predictor receives the full latent code concatenated with the
    conditions along the last axis. In the volume term, pruned dimensions
    contribute their frozen standard deviation and active dimensions the
    standard deviation of the current batch; ``eta`` is added before the
    logarithm, matching ``LeastVolumeAE.loss_vol``.

    The moving statistics that drive pruning are only updated while the module
    is in training mode.

    Args:
        batch: Tuple of (designs, conditions, performance_targets).

    Returns:
        Tensor of shape (3,) containing [rec_loss, perf_loss, vol_loss].
    """
    x, c, p = batch
    z = self.encode(x)
    x_hat = self.decode(z)

    if self.training:
        self._update_moving_mean(z)

    p_hat = self.predictor(torch.cat([z, c], dim=-1))

    active = ~self._p
    s = self._frozen_std.clone()
    if active.any():
        s[active] = z[:, active].std(0)
    vol_loss = torch.exp(torch.log(s + self.eta).mean())

    return torch.stack(
        [
            self.loss_rec(x, x_hat),
            self.loss_rec(p, p_hat),
            vol_loss,
        ]
    )
def loss(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Compute losses using only the first ``perf_dim`` latents for performance prediction.

    The predictor receives ``z[:, :perf_dim]`` concatenated with the conditions
    along the last axis. In the volume term, pruned dimensions contribute their
    frozen standard deviation and active dimensions the standard deviation of
    the current batch; ``eta`` is added before the logarithm, matching
    ``LeastVolumeAE.loss_vol``.

    The moving statistics that drive pruning are only updated while the module
    is in training mode.

    Args:
        batch: Tuple of (designs, conditions, performance_targets).

    Returns:
        Tensor of shape (3,) containing [rec_loss, perf_loss, vol_loss].
    """
    x, c, p = batch
    z = self.encode(x)
    x_hat = self.decode(z)

    if self.training:
        self._update_moving_mean(z)

    # Only the leading perf_dim latent dimensions feed the predictor.
    pz = z[:, : self.perf_dim]
    p_hat = self.predictor(torch.cat([pz, c], dim=-1))

    active = ~self._p
    s = self._frozen_std.clone()
    if active.any():
        s[active] = z[:, active].std(0)
    vol_loss = torch.exp(torch.log(s + self.eta).mean())

    return torch.stack(
        [
            self.loss_rec(x, x_hat),
            self.loss_rec(p, p_hat),
            vol_loss,
        ]
    )
class Encoder(nn.Module):
    """Convolutional encoder for 2D designs.

    The input is resized to ``resize_dimensions`` and passed through four
    stride-2 convolutions, each followed by BatchNorm and LeakyReLU. A final
    convolution whose kernel spans the whole remaining feature map collapses
    it to ``(B, latent_dim, 1, 1)``, which is flattened to ``(B, latent_dim)``.

    With the default 100x100 resize the spatial sizes are
    100 -> 50 -> 25 -> 13 -> 7 -> 1, i.e. the final kernel is 7x7. For other
    resize dimensions the final kernel is derived from the feature-map size,
    so the output is always ``latent_dim`` wide.

    Args:
        latent_dim: Width of the produced latent vector.
        design_shape: Shape of the original designs (stored, not used by ``forward``).
        resize_dimensions: (height, width) the input is resized to before the convolutions.
    """

    # (in_channels, out_channels, kernel_size) of the feature convolutions.
    _CONV_SPECS = ((1, 64, 4), (64, 128, 4), (128, 256, 3), (256, 512, 3))
    _STRIDE = 2
    _PADDING = 1

    def __init__(
        self,
        latent_dim: int,
        design_shape: tuple[int, int],
        resize_dimensions: tuple[int, int] = (100, 100),
    ) -> None:
        super().__init__()
        self.resize_in = transforms.Resize(resize_dimensions)
        self.design_shape = design_shape

        # Layer order (conv, norm, activation) fixes the Sequential indices and
        # therefore the state_dict keys.
        layers: list[nn.Module] = []
        for in_ch, out_ch, kernel in self._CONV_SPECS:
            layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=self._STRIDE, padding=self._PADDING, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.features = nn.Sequential(*layers)

        # Kernel equal to the feature-map size -> (B, latent_dim, 1, 1).
        self.to_latent = nn.Conv2d(
            self._CONV_SPECS[-1][1],
            latent_dim,
            kernel_size=self._feature_map_size(resize_dimensions),
            stride=1,
            padding=0,
            bias=True,
        )

    @classmethod
    def _feature_map_size(cls, resize_dimensions: tuple[int, int]) -> tuple[int, int]:
        """Return the (height, width) left after the feature convolutions.

        Applies the standard convolution size formula
        ``floor((n + 2 * padding - kernel) / stride) + 1`` once per layer.

        Raises:
            ValueError: If a dimension is too small to survive the convolutions.
        """
        sizes = []
        for extent in resize_dimensions:
            reduced = extent
            for _, _, kernel in cls._CONV_SPECS:
                reduced = (reduced + 2 * cls._PADDING - kernel) // cls._STRIDE + 1
            if reduced < 1:
                msg = f"resize_dimensions {tuple(resize_dimensions)} are too small for the encoder"
                raise ValueError(msg)
            sizes.append(reduced)
        return sizes[0], sizes[1]

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Resize, extract features and project to a flat latent vector."""
        x = self.resize_in(x)
        h = self.features(x)
        return self.to_latent(h).flatten(1)  # (B, latent_dim)
def volume_weight_schedule(epoch: int, w_rec: float, w_vol: float, warmup_epochs: int, degree: float) -> th.Tensor:
    """Return the loss weights for ``epoch`` with a polynomial warmup on the volume weight.

    The reconstruction weight is constant. The volume weight is scaled by
    ``min(epoch / warmup_epochs, 1) ** degree``; when ``warmup_epochs`` is not
    positive the full volume weight is used from the start.

    Args:
        epoch: Current epoch.
        w_rec: Reconstruction weight (constant).
        w_vol: Final volume weight after warmup.
        warmup_epochs: Epochs to ramp volume weight from 0 to w_vol.
        degree: Polynomial degree (1.0=linear, 2.0=quadratic).

    Returns:
        Float tensor ``[w_rec, current_w_vol]``.
    """
    ramp = 1.0
    if warmup_epochs > 0:
        progress = min(epoch / warmup_epochs, 1.0)
        ramp = progress**degree
    return th.tensor([w_rec, w_vol * ramp], dtype=th.float)
print("LVAE Training") + print(f"Problem: {args.problem_id}") + print(f"Latent dim: {args.latent_dim}") + print(f"Pruning epoch: {args.pruning_epoch}") + print(f"Plummet threshold: {args.plummet_threshold}") + print(f"Volume warmup epochs: {args.volume_warmup_epochs}") + print(f"{'=' * 60}\n") + + # ---- DataLoader ---- + hf = problem.dataset.with_format("torch") + train_ds = hf["train"] + val_ds = hf["val"] + + x_train = train_ds["optimal_design"][:].unsqueeze(1) + x_val = val_ds["optimal_design"][:].unsqueeze(1) + + loader = DataLoader(TensorDataset(x_train), batch_size=args.batch_size, shuffle=True, generator=g) + val_loader = DataLoader(TensorDataset(x_val), batch_size=args.batch_size, shuffle=False) + + # ---- Training loop ---- + for epoch in range(args.n_epochs): + lvae.epoch_hook(epoch=epoch) + + bar = tqdm.tqdm(loader, desc=f"Epoch {epoch}") + for i, batch in enumerate(bar): + x_batch = batch[0].to(device) + lvae.optim.zero_grad() + + # Compute loss components (rec, vol) + losses = lvae.loss(x_batch) + + # Weighted sum for backprop + loss = (losses * lvae.w).sum() + loss.backward() + lvae.optim.step() + + bar.set_postfix( + { + "rec": f"{losses[0].item():.4f}", + "vol": f"{losses[1].item():.4f}", + "dim": lvae.dim, + } + ) + + # Log to wandb + if args.track: + batches_done = epoch * len(bar) + i + + log_dict = { + "rec_loss": losses[0].item(), + "vol_loss": losses[1].item(), + "total_loss": loss.item(), + "active_dims": lvae.dim, + "epoch": epoch, + "w_volume": lvae.w[1].item(), + } + wandb.log(log_dict) + + print( + f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(bar)}] " + f"[rec loss: {losses[0].item():.4f}] [vol loss: {losses[1].item():.4f}] " + f"[active dims: {lvae.dim}]" + ) + + # Sample and visualize at regular intervals + if batches_done % args.sample_interval == 0: + with th.no_grad(): + # Encode training designs + xs = x_train.to(device) + z = lvae.encode(xs) + z_std, idx = th.sort(z.std(0), descending=True) + z_mean = z.mean(0) + n_active = 
(z_std > 0).sum().item() + + # Generate interpolated designs + x_ints = [] + for alpha in [0, 0.25, 0.5, 0.75, 1]: + z_ = (1 - alpha) * z[:25] + alpha * th.roll(z, -1, 0)[:25] + x_ints.append(lvae.decode(z_).cpu().numpy()) + + # Generate random designs + z_rand = z_mean.unsqueeze(0).repeat([25, 1]) + z_rand[:, idx[:n_active]] += z_std[:n_active] * th.randn_like(z_rand[:, idx[:n_active]]) + x_rand = lvae.decode(z_rand).cpu().numpy() + + # Move tensors to CPU for plotting + z_std_cpu = z_std.cpu().numpy() + xs_cpu = xs.cpu().numpy() + + # Plot 1: Latent dimension statistics + plt.figure(figsize=(12, 6)) + plt.subplot(211) + plt.bar(np.arange(len(z_std_cpu)), z_std_cpu) + plt.yscale("log") + plt.xlabel("Latent dimension index") + plt.ylabel("Standard deviation") + plt.title(f"Number of principal components = {n_active}") + plt.subplot(212) + plt.bar(np.arange(n_active), z_std_cpu[:n_active]) + plt.yscale("log") + plt.xlabel("Latent dimension index") + plt.ylabel("Standard deviation") + plt.savefig(f"images/dim_{batches_done}.png") + plt.close() + + # Plot 2: Interpolated designs + fig, axs = plt.subplots(25, 6, figsize=(12, 25)) + for i_row, j in product(range(25), range(5)): + axs[i_row, j + 1].imshow(x_ints[j][i_row].reshape(design_shape)) + axs[i_row, j + 1].axis("off") + axs[i_row, j + 1].set_aspect("equal") + for ax, alpha in zip(axs[0, 1:], [0, 0.25, 0.5, 0.75, 1]): + ax.set_title(rf"$\alpha$ = {alpha}") + for i_row in range(25): + axs[i_row, 0].imshow(xs_cpu[i_row].reshape(design_shape)) + axs[i_row, 0].axis("off") + axs[i_row, 0].set_aspect("equal") + axs[0, 0].set_title("groundtruth") + fig.tight_layout() + plt.savefig(f"images/interp_{batches_done}.png") + plt.close() + + # Plot 3: Random designs from latent space + fig, axs = plt.subplots(5, 5, figsize=(15, 7.5)) + for k, (i_row, j) in enumerate(product(range(5), range(5))): + axs[i_row, j].imshow(x_rand[k].reshape(design_shape)) + axs[i_row, j].axis("off") + axs[i_row, j].set_aspect("equal") + 
fig.tight_layout() + plt.suptitle("Gaussian random designs from latent space") + plt.savefig(f"images/norm_{batches_done}.png") + plt.close() + + # Log plots to wandb + wandb.log( + { + "dim_plot": wandb.Image(f"images/dim_{batches_done}.png"), + "interp_plot": wandb.Image(f"images/interp_{batches_done}.png"), + "norm_plot": wandb.Image(f"images/norm_{batches_done}.png"), + } + ) + + # ---- Validation ---- + with th.no_grad(): + lvae.eval() + val_rec = val_vol = 0.0 + n = 0 + for batch_v in val_loader: + x_v = batch_v[0].to(device) + vlosses = lvae.loss(x_v) + bsz = x_v.size(0) + val_rec += vlosses[0].item() * bsz + val_vol += vlosses[1].item() * bsz + n += bsz + val_rec /= n + val_vol /= n + + # Trigger pruning check at end of epoch + lvae.epoch_report(epoch=epoch, callbacks=[], batch=None, loss=losses, pbar=None) + + if args.track: + val_log_dict = { + "epoch": epoch, + "val_rec": val_rec, + "val_vol_loss": val_vol, + } + wandb.log(val_log_dict, commit=True) + + th.cuda.empty_cache() + lvae.train() + + # Save models at end of training + if args.save_model and epoch == args.n_epochs - 1: + ckpt_lvae = { + "epoch": epoch, + "encoder": lvae.encoder.state_dict(), + "decoder": lvae.decoder.state_dict(), + "optimizer": lvae.optim.state_dict(), + "args": vars(args), + } + th.save(ckpt_lvae, "lvae.pth") + if args.track: + artifact = wandb.Artifact(f"{args.problem_id}_{args.algo}", type="model") + artifact.add_file("lvae.pth") + wandb.log_artifact(artifact, aliases=[f"seed_{args.seed}"]) + + if args.track: + wandb.finish() diff --git a/engiopt/lvae_2d/plvae_2d.py b/engiopt/lvae_2d/plvae_2d.py new file mode 100644 index 0000000..1989dda --- /dev/null +++ b/engiopt/lvae_2d/plvae_2d.py @@ -0,0 +1,592 @@ +"""Performance-LVAE for 2D designs with plummet-based dynamic pruning. Adapted from https://github.com/IDEALLab/Least_Volume_ICLR2024. 
+ +Configuration: +- perf_dim: Number of latent dimensions dedicated to performance (default: all latent_dim dimensions) + - Use perf_dim < latent_dim for interpretable mode + - Use perf_dim = latent_dim (default) for regular mode + +For more information, see: https://arxiv.org/abs/2404.17773 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +from itertools import product +import os +import random +import time + +from engibench.utils.all_problems import BUILTIN_PROBLEMS +import matplotlib.pyplot as plt +import numpy as np +from sklearn.preprocessing import RobustScaler +import torch as th +from torch import nn +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data import TensorDataset +from torchvision import transforms +import tqdm +import tyro + +from engiopt.lvae_2d.aes import InterpretablePerfLeastVolumeAE_DP +import wandb + + +@dataclass +class Args: + """Command-line arguments for Performance-LVAE training.""" + + # Problem and tracking + problem_id: str = "beams2d" + """Problem ID to run. Must be one of the built-in problems in engibench.""" + algo: str = os.path.basename(__file__)[: -len(".py")] + """Algorithm name for tracking purposes.""" + track: bool = True + """Whether to track with Weights & Biases.""" + wandb_project: str = "engiopt" + """WandB project name.""" + wandb_entity: str | None = None + """WandB entity name. 
class Encoder(nn.Module):
    """Convolutional encoder for 2D designs.

    The input is resized to ``resize_dimensions`` and passed through four
    stride-2 convolutions, each followed by BatchNorm and LeakyReLU. A final
    convolution whose kernel spans the whole remaining feature map collapses
    it to ``(B, latent_dim, 1, 1)``, which is flattened to ``(B, latent_dim)``.

    With the default 100x100 resize the spatial sizes are
    100 -> 50 -> 25 -> 13 -> 7 -> 1, i.e. the final kernel is 7x7. For other
    resize dimensions the final kernel is derived from the feature-map size,
    so the output is always ``latent_dim`` wide.

    Args:
        latent_dim: Width of the produced latent vector.
        design_shape: Shape of the original designs (stored, not used by ``forward``).
        resize_dimensions: (height, width) the input is resized to before the convolutions.
    """

    # (in_channels, out_channels, kernel_size) of the feature convolutions.
    _CONV_SPECS = ((1, 64, 4), (64, 128, 4), (128, 256, 3), (256, 512, 3))
    _STRIDE = 2
    _PADDING = 1

    def __init__(
        self,
        latent_dim: int,
        design_shape: tuple[int, int],
        resize_dimensions: tuple[int, int] = (100, 100),
    ) -> None:
        super().__init__()
        self.resize_in = transforms.Resize(resize_dimensions)
        self.design_shape = design_shape

        # Layer order (conv, norm, activation) fixes the Sequential indices and
        # therefore the state_dict keys.
        layers: list[nn.Module] = []
        for in_ch, out_ch, kernel in self._CONV_SPECS:
            layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=self._STRIDE, padding=self._PADDING, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.features = nn.Sequential(*layers)

        # Kernel equal to the feature-map size -> (B, latent_dim, 1, 1).
        self.to_latent = nn.Conv2d(
            self._CONV_SPECS[-1][1],
            latent_dim,
            kernel_size=self._feature_map_size(resize_dimensions),
            stride=1,
            padding=0,
            bias=True,
        )

    @classmethod
    def _feature_map_size(cls, resize_dimensions: tuple[int, int]) -> tuple[int, int]:
        """Return the (height, width) left after the feature convolutions.

        Applies the standard convolution size formula
        ``floor((n + 2 * padding - kernel) / stride) + 1`` once per layer.

        Raises:
            ValueError: If a dimension is too small to survive the convolutions.
        """
        sizes = []
        for extent in resize_dimensions:
            reduced = extent
            for _, _, kernel in cls._CONV_SPECS:
                reduced = (reduced + 2 * cls._PADDING - kernel) // cls._STRIDE + 1
            if reduced < 1:
                msg = f"resize_dimensions {tuple(resize_dimensions)} are too small for the encoder"
                raise ValueError(msg)
            sizes.append(reduced)
        return sizes[0], sizes[1]

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Resize, extract features and project to a flat latent vector."""
        x = self.resize_in(x)
        h = self.features(x)
        return self.to_latent(h).flatten(1)  # (B, latent_dim)
+ + Architecture: Latent vector -> Deconv layers -> Output + • Latent [latent_dim] + • Linear [512*7*7] + • Reshape [512x7x7] + • Deconv1 [256x13x13] (k=3, s=2, p=1) + • Deconv2 [128x25x25] (k=3, s=2, p=1) + • Deconv3 [64x50x50] (k=4, s=2, p=1) + • Deconv4 [1x100x100] (k=4, s=2, p=1) + """ + + def __init__( + self, + latent_dim: int, + design_shape: tuple[int, int], + ) -> None: + super().__init__() + self.design_shape = design_shape + self.resize_out = transforms.Resize(self.design_shape) + + # Linear projection to spatial features + self.proj = nn.Sequential( + nn.Linear(latent_dim, 512 * 7 * 7), + nn.ReLU(inplace=True), + ) + + # Deconvolutional layers + self.deconv = nn.Sequential( + # 7->13 + nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=0, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + # 13->25 + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=0, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + # 25->50 + nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + # 50->100 + nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, output_padding=0, bias=False), + nn.Sigmoid(), + ) + + def forward(self, z: th.Tensor) -> th.Tensor: + """Forward pass through decoder.""" + x = self.proj(z).view(z.size(0), 512, 7, 7) # (B,512,7,7) + x = self.deconv(x) # (B,1,100,100) + return self.resize_out(x) # (B,1,H_orig,W_orig) + + +class MLPPredictor(nn.Module): + """MLP that predicts performance from latent codes + conditions. + + Uses LeakyReLU activations in hidden layers and no activation on output. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dims: tuple[int, ...] 
= (256, 128), + ) -> None: + super().__init__() + layers: list[nn.Module] = [] + prev_dim = input_dim + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + prev_dim = hidden_dim + # Final layer: Linear (no activation) + layers.append(nn.Linear(prev_dim, output_dim)) + self.net = nn.Sequential(*layers) + + def forward(self, x: th.Tensor) -> th.Tensor: + """Predict performance from latent codes + conditions.""" + return self.net(x) + + +def volume_weight_schedule( # noqa: PLR0913 + epoch: int, w_rec: float, w_perf: float, w_vol: float, warmup_epochs: int, degree: float +) -> th.Tensor: + """Compute weights with polynomial ramp on volume weight. + + Args: + epoch: Current epoch. + w_rec: Reconstruction weight (constant). + w_perf: Performance weight (constant). + w_vol: Final volume weight after warmup. + warmup_epochs: Epochs to ramp volume weight from 0 to w_vol. + degree: Polynomial degree (1.0=linear, 2.0=quadratic). + + Returns: + Tensor [w_rec, w_perf, current_w_vol] where current_w_vol ramps polynomially. 
+ """ + if warmup_epochs <= 0: + return th.tensor([w_rec, w_perf, w_vol], dtype=th.float) + t = min(epoch / warmup_epochs, 1.0) + return th.tensor([w_rec, w_perf, w_vol * (t**degree)], dtype=th.float) + + +if __name__ == "__main__": + args = tyro.cli(Args) + + problem = BUILTIN_PROBLEMS[args.problem_id]() + problem.reset(seed=args.seed) + + design_shape = problem.design_space.shape + conditions = problem.conditions_keys + n_conds = len(conditions) + + # Logging + run_name = f"{args.problem_id}__{args.algo}__{args.seed}__{int(time.time())}" + if args.track: + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + config=vars(args), + save_code=True, + name=run_name, + ) + + # Seeding for reproducibility + th.manual_seed(args.seed) + th.cuda.manual_seed_all(args.seed) + rng = np.random.default_rng(args.seed) + random.seed(args.seed) + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False + g = th.Generator().manual_seed(args.seed) # For DataLoader shuffling + + os.makedirs("images", exist_ok=True) + + if th.backends.mps.is_available(): + device = th.device("mps") + elif th.cuda.is_available(): + device = th.device("cuda") + else: + device = th.device("cpu") + + # Build encoder and decoder + enc = Encoder(args.latent_dim, design_shape, args.resize_dimensions) + dec = Decoder(args.latent_dim, design_shape) + + # Determine perf_dim: if -1 (default), use all latent dimensions + perf_dim = args.latent_dim if args.perf_dim == -1 else args.perf_dim + n_perf = 1 # Single performance objective + + # Build MLP predictor (input: perf_dim latent dims + conditions if conditional) + predictor_input_dim = perf_dim + (n_conds if args.conditional_predictor else 0) + predictor = MLPPredictor( + input_dim=predictor_input_dim, + output_dim=n_perf, + hidden_dims=args.predictor_hidden_dims, + ) + + print(f"\n{'=' * 60}") + print("Performance-LVAE Training") + print(f"Problem: {args.problem_id}") + print(f"Latent dim: {args.latent_dim}") + 
print(f"Perf dim: {perf_dim} (first {perf_dim} dims predict performance)") + print(f"Predictor mode: {'Conditional' if args.conditional_predictor else 'Unconditional'}") + print( + f"Predictor input: {predictor_input_dim} (perf_dim={perf_dim}, n_conds={n_conds if args.conditional_predictor else 0})" + ) + print(f"Pruning epoch: {args.pruning_epoch}") + print(f"Plummet threshold: {args.plummet_threshold}") + print(f"Volume warmup epochs: {args.volume_warmup_epochs}") + print(f"{'=' * 60}\n") + + # Weight schedule (ramps volume weight if warmup_epochs > 0, otherwise constant) + weights = partial( + volume_weight_schedule, + w_rec=args.w_reconstruction, + w_perf=args.w_performance, + w_vol=args.w_volume, + warmup_epochs=args.volume_warmup_epochs, + degree=args.volume_warmup_degree, + ) + + # Initialize Performance-LVAE with dynamic pruning + plvae = InterpretablePerfLeastVolumeAE_DP( + encoder=enc, + decoder=dec, + predictor=predictor, + optimizer=Adam( + list(enc.parameters()) + list(dec.parameters()) + list(predictor.parameters()), + lr=args.lr, + ), + latent_dim=args.latent_dim, + perf_dim=perf_dim, + weights=weights, + pruning_epoch=args.pruning_epoch, + plummet_threshold=args.plummet_threshold, + ).to(device) + + # ---- DataLoader ---- + hf = problem.dataset.with_format("torch") + train_ds = hf["train"] + val_ds = hf["val"] + + # Extract designs, conditions, and performance + x_train = train_ds["optimal_design"][:].unsqueeze(1) + c_train = th.stack([train_ds[key][:] for key in problem.conditions_keys], dim=-1) + p_train = train_ds[problem.objectives_keys[0]][:].unsqueeze(-1) # (N, 1) + + x_val = val_ds["optimal_design"][:].unsqueeze(1) + c_val = th.stack([val_ds[key][:] for key in problem.conditions_keys], dim=-1) + p_val = val_ds[problem.objectives_keys[0]][:].unsqueeze(-1) + + # Scale performance values using RobustScaler + p_scaler = RobustScaler() + p_train_scaled = th.from_numpy(p_scaler.fit_transform(p_train.numpy())).to(p_train.dtype) + p_val_scaled = 
th.from_numpy(p_scaler.transform(p_val.numpy())).to(p_val.dtype) + + # Scale conditions using RobustScaler (if using conditional predictor) + if args.conditional_predictor: + c_scaler = RobustScaler() + c_train_scaled = th.from_numpy(c_scaler.fit_transform(c_train.numpy())).to(c_train.dtype) + c_val_scaled = th.from_numpy(c_scaler.transform(c_val.numpy())).to(c_val.dtype) + else: + # Dummy tensors when not using conditions (won't be used in predictor) + c_train_scaled = th.zeros(len(x_train), 0) + c_val_scaled = th.zeros(len(x_val), 0) + + loader = DataLoader( + TensorDataset(x_train, c_train_scaled, p_train_scaled), + batch_size=args.batch_size, + shuffle=True, + generator=g, + ) + val_loader = DataLoader( + TensorDataset(x_val, c_val_scaled, p_val_scaled), + batch_size=args.batch_size, + shuffle=False, + ) + + # ---- Training loop ---- + for epoch in range(args.n_epochs): + plvae.epoch_hook(epoch=epoch) + + bar = tqdm.tqdm(loader, desc=f"Epoch {epoch}") + for i, batch in enumerate(bar): + x_batch = batch[0].to(device) + c_batch = batch[1].to(device) + p_batch = batch[2].to(device) + + plvae.optim.zero_grad() + + # Compute loss components (rec, perf, vol) + losses = plvae.loss((x_batch, c_batch, p_batch)) + + # Weighted sum for backprop + loss = (losses * plvae.w).sum() + loss.backward() + plvae.optim.step() + + bar.set_postfix( + { + "rec": f"{losses[0].item():.4f}", + "perf": f"{losses[1].item():.4f}", + "vol": f"{losses[2].item():.4f}", + "dim": plvae.dim, + } + ) + + # Log to wandb + if args.track: + batches_done = epoch * len(bar) + i + + log_dict = { + "rec_loss": losses[0].item(), + "perf_loss": losses[1].item(), + "vol_loss": losses[2].item(), + "total_loss": loss.item(), + "active_dims": plvae.dim, + "epoch": epoch, + "w_volume": plvae.w[2].item(), + } + wandb.log(log_dict) + + print( + f"[Epoch {epoch}/{args.n_epochs}] [Batch {i}/{len(bar)}] " + f"[rec loss: {losses[0].item():.4f}] [perf loss: {losses[1].item():.4f}] " + f"[vol loss: 
{losses[2].item():.4f}] [active dims: {plvae.dim}]" + ) + + # Sample and visualize at regular intervals + if batches_done % args.sample_interval == 0: + with th.no_grad(): + # Encode training designs + xs = x_train.to(device) + z = plvae.encode(xs) + z_std, idx = th.sort(z.std(0), descending=True) + z_mean = z.mean(0) + n_active = (z_std > 0).sum().item() + + # Generate interpolated designs + x_ints = [] + for alpha in [0, 0.25, 0.5, 0.75, 1]: + z_ = (1 - alpha) * z[:25] + alpha * th.roll(z, -1, 0)[:25] + x_ints.append(plvae.decode(z_).cpu().numpy()) + + # Generate random designs + z_rand = z_mean.unsqueeze(0).repeat([25, 1]) + z_rand[:, idx[:n_active]] += z_std[:n_active] * th.randn_like(z_rand[:, idx[:n_active]]) + x_rand = plvae.decode(z_rand).cpu().numpy() + + # Get performance predictions on training data + pz_train = z[:, :perf_dim] + p_pred_scaled = plvae.predictor(th.cat([pz_train, c_train.to(device)], dim=-1)) + + # Inverse transform to get true-scale values for plotting + p_actual = p_scaler.inverse_transform(p_train_scaled.cpu().numpy()).flatten() + p_predicted = p_scaler.inverse_transform(p_pred_scaled.cpu().numpy()).flatten() + + # Move tensors to CPU for plotting + z_std_cpu = z_std.cpu().numpy() + xs_cpu = xs.cpu().numpy() + + # Plot 1: Latent dimension statistics + plt.figure(figsize=(12, 6)) + plt.subplot(211) + plt.bar(np.arange(len(z_std_cpu)), z_std_cpu) + plt.yscale("log") + plt.xlabel("Latent dimension index") + plt.ylabel("Standard deviation") + plt.title(f"Number of principal components = {n_active}") + plt.subplot(212) + plt.bar(np.arange(n_active), z_std_cpu[:n_active]) + plt.yscale("log") + plt.xlabel("Latent dimension index") + plt.ylabel("Standard deviation") + plt.savefig(f"images/dim_{batches_done}.png") + plt.close() + + # Plot 2: Interpolated designs + fig, axs = plt.subplots(25, 6, figsize=(12, 25)) + for i_row, j in product(range(25), range(5)): + axs[i_row, j + 1].imshow(x_ints[j][i_row].reshape(design_shape)) + axs[i_row, j + 
1].axis("off") + axs[i_row, j + 1].set_aspect("equal") + for ax, alpha in zip(axs[0, 1:], [0, 0.25, 0.5, 0.75, 1]): + ax.set_title(rf"$\alpha$ = {alpha}") + for i_row in range(25): + axs[i_row, 0].imshow(xs_cpu[i_row].reshape(design_shape)) + axs[i_row, 0].axis("off") + axs[i_row, 0].set_aspect("equal") + axs[0, 0].set_title("groundtruth") + fig.tight_layout() + plt.savefig(f"images/interp_{batches_done}.png") + plt.close() + + # Plot 3: Random designs from latent space + fig, axs = plt.subplots(5, 5, figsize=(15, 7.5)) + for k, (i_row, j) in enumerate(product(range(5), range(5))): + axs[i_row, j].imshow(x_rand[k].reshape(design_shape)) + axs[i_row, j].axis("off") + axs[i_row, j].set_aspect("equal") + fig.tight_layout() + plt.suptitle("Gaussian random designs from latent space") + plt.savefig(f"images/norm_{batches_done}.png") + plt.close() + + # Plot 4: Predicted vs actual performance + plt.figure(figsize=(8, 8)) + plt.scatter(p_actual, p_predicted, alpha=0.5, s=20) + min_val = min(p_actual.min(), p_predicted.min()) + max_val = max(p_actual.max(), p_predicted.max()) + plt.plot([min_val, max_val], [min_val, max_val], "r--", linewidth=2, label="1:1 line") + plt.xlabel("Actual Performance") + plt.ylabel("Predicted Performance") + mse_value = np.mean((p_actual - p_predicted) ** 2) + plt.title(f"MSE: {mse_value:.4e}") + plt.grid(visible=True, alpha=0.3) + plt.legend() + plt.axis("equal") + plt.tight_layout() + plt.savefig(f"images/perf_pred_vs_actual_{batches_done}.png") + plt.close() + + # Log plots to wandb + wandb.log( + { + "dim_plot": wandb.Image(f"images/dim_{batches_done}.png"), + "interp_plot": wandb.Image(f"images/interp_{batches_done}.png"), + "norm_plot": wandb.Image(f"images/norm_{batches_done}.png"), + "perf_pred_vs_actual": wandb.Image(f"images/perf_pred_vs_actual_{batches_done}.png"), + } + ) + + # ---- Validation ---- + with th.no_grad(): + plvae.eval() + val_rec = val_perf = val_vol = 0.0 + n = 0 + for batch_v in val_loader: + x_v = 
batch_v[0].to(device) + c_v = batch_v[1].to(device) + p_v = batch_v[2].to(device) + vlosses = plvae.loss((x_v, c_v, p_v)) + bsz = x_v.size(0) + val_rec += vlosses[0].item() * bsz + val_perf += vlosses[1].item() * bsz + val_vol += vlosses[2].item() * bsz + n += bsz + val_rec /= n + val_perf /= n + val_vol /= n + + # Trigger pruning check at end of epoch + plvae.epoch_report(epoch=epoch, callbacks=[], batch=None, loss=losses, pbar=None) + + if args.track: + val_log_dict = { + "epoch": epoch, + "val_rec": val_rec, + "val_perf": val_perf, + "val_vol_loss": val_vol, + } + wandb.log(val_log_dict, commit=True) + + th.cuda.empty_cache() + plvae.train() + + # Save models at end of training + if args.save_model and epoch == args.n_epochs - 1: + ckpt_plvae = { + "epoch": epoch, + "encoder": plvae.encoder.state_dict(), + "decoder": plvae.decoder.state_dict(), + "predictor": plvae.predictor.state_dict(), + "optimizer": plvae.optim.state_dict(), + "args": vars(args), + } + th.save(ckpt_plvae, "plvae.pth") + if args.track: + artifact = wandb.Artifact(f"{args.problem_id}_{args.algo}", type="model") + artifact.add_file("plvae.pth") + wandb.log_artifact(artifact, aliases=[f"seed_{args.seed}"]) + + if args.track: + wandb.finish() From 7c8b32d23f34b65a981f907d20108b34490ecee7 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Thu, 29 Jan 2026 23:36:47 +0100 Subject: [PATCH 2/4] ruff fixes for PR --- engiopt/lvae_2d/__init__.py | 0 engiopt/lvae_2d/aes.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 engiopt/lvae_2d/__init__.py diff --git a/engiopt/lvae_2d/__init__.py b/engiopt/lvae_2d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/engiopt/lvae_2d/aes.py b/engiopt/lvae_2d/aes.py index e0e3ef1..c34cb21 100644 --- a/engiopt/lvae_2d/aes.py +++ b/engiopt/lvae_2d/aes.py @@ -268,7 +268,7 @@ def _plummet_prune(self, z_std: torch.Tensor) -> torch.Tensor: Boolean mask where True indicates dimensions to prune. 
""" # Sort variances in descending order - srt, idx = torch.sort(z_std, descending=True) + srt, _idx = torch.sort(z_std, descending=True) # Compute log-space drops log_srt = (srt + 1e-12).log() From 210ff35a325d5d5b6588ee535a2158698d356b96 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Mon, 2 Feb 2026 14:41:51 +0100 Subject: [PATCH 3/4] bugfix for performance visualization --- engiopt/lvae_2d/plvae_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engiopt/lvae_2d/plvae_2d.py b/engiopt/lvae_2d/plvae_2d.py index 1989dda..c0de8be 100644 --- a/engiopt/lvae_2d/plvae_2d.py +++ b/engiopt/lvae_2d/plvae_2d.py @@ -457,7 +457,7 @@ def volume_weight_schedule( # noqa: PLR0913 # Get performance predictions on training data pz_train = z[:, :perf_dim] - p_pred_scaled = plvae.predictor(th.cat([pz_train, c_train.to(device)], dim=-1)) + p_pred_scaled = plvae.predictor(th.cat([pz_train, c_train_scaled.to(device)], dim=-1)) # Inverse transform to get true-scale values for plotting p_actual = p_scaler.inverse_transform(p_train_scaled.cpu().numpy()).flatten() From d087b8ecfad6272c964522f92ce8eb0b1b295b58 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Mon, 2 Feb 2026 18:12:17 +0100 Subject: [PATCH 4/4] modified default hyperparameters --- engiopt/lvae_2d/lvae_2d.py | 6 +++--- engiopt/lvae_2d/plvae_2d.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/engiopt/lvae_2d/lvae_2d.py b/engiopt/lvae_2d/lvae_2d.py index efde4ed..bdbd032 100644 --- a/engiopt/lvae_2d/lvae_2d.py +++ b/engiopt/lvae_2d/lvae_2d.py @@ -51,7 +51,7 @@ class Args: """Interval for sampling designs during training.""" # Training parameters - n_epochs: int = 2000 + n_epochs: int = 5000 """Number of training epochs.""" batch_size: int = 128 """Batch size for training.""" @@ -63,7 +63,7 @@ class Args: """Dimensionality of the latent space (overestimate).""" w_reconstruction: float = 1.0 """Weight for reconstruction loss.""" - w_volume: float = 0.001 + w_volume: float = 
0.01 """Weight for volume loss.""" # Pruning parameters @@ -73,7 +73,7 @@ class Args: """Threshold for plummet pruning strategy.""" # Volume weight warmup - volume_warmup_epochs: int = 0 + volume_warmup_epochs: int = 100 """Epochs to polynomially ramp volume weight from 0 to w_volume. 0 disables warmup.""" volume_warmup_degree: float = 2.0 """Polynomial degree for volume weight warmup (1.0=linear, 2.0=quadratic).""" diff --git a/engiopt/lvae_2d/plvae_2d.py b/engiopt/lvae_2d/plvae_2d.py index c0de8be..2e184d1 100644 --- a/engiopt/lvae_2d/plvae_2d.py +++ b/engiopt/lvae_2d/plvae_2d.py @@ -57,7 +57,7 @@ class Args: """Interval for sampling designs during training.""" # Training parameters - n_epochs: int = 2000 + n_epochs: int = 5000 """Number of training epochs.""" batch_size: int = 128 """Batch size for training.""" @@ -73,7 +73,7 @@ class Args: """Weight for reconstruction loss.""" w_performance: float = 0.1 """Weight for performance loss.""" - w_volume: float = 0.001 + w_volume: float = 0.01 """Weight for volume loss.""" # Pruning parameters @@ -83,7 +83,7 @@ class Args: """Threshold for plummet pruning strategy.""" # Volume weight warmup - volume_warmup_epochs: int = 0 + volume_warmup_epochs: int = 100 """Epochs to polynomially ramp volume weight from 0 to w_volume. 0 disables warmup.""" volume_warmup_degree: float = 2.0 """Polynomial degree for volume weight warmup (1.0=linear, 2.0=quadratic).""" @@ -93,7 +93,7 @@ class Args: """Dimensions to resize input images to before encoding/decoding.""" predictor_hidden_dims: tuple[int, ...] = (256, 128) """Hidden dimensions for the MLP predictor.""" - conditional_predictor: bool = True + conditional_predictor: bool = False """Whether to include conditions in performance prediction (True) or use only latent codes (False)."""