Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
482c85d
add initial vqgan files
arthurdrake1 Sep 23, 2025
be14a17
add windows support and add vqgan description
arthurdrake1 Sep 23, 2025
1398f41
add vqgan args and some new image transform functions
arthurdrake1 Sep 23, 2025
e334900
add stage 1 submodules
arthurdrake1 Sep 23, 2025
8df08db
clean up stage 1 modules
arthurdrake1 Sep 24, 2025
e415de7
add remaining cleaned up stage 1 modules
arthurdrake1 Sep 24, 2025
97efc89
add vqgan stage 1 main class and resolve arg names
arthurdrake1 Sep 24, 2025
f9d565d
add vqgan stage 2 modules
arthurdrake1 Sep 25, 2025
5b9625f
add initial training loops and conditions augmentations
arthurdrake1 Sep 27, 2025
2be8e23
Merge branch 'main' into vqgan
arthurdrake1 Sep 28, 2025
c0e4ae0
add training tracking
arthurdrake1 Sep 28, 2025
93cc9cd
add image logging to training
arthurdrake1 Sep 29, 2025
ff937fa
simplify args and transforms; remove unneeded transformer code
arthurdrake1 Sep 30, 2025
e7bafce
minor plotting fixes
arthurdrake1 Sep 30, 2025
b845834
add initial vqgan eval script
arthurdrake1 Sep 30, 2025
c543345
minor eval conditions fixes
arthurdrake1 Sep 30, 2025
ef8aa72
ruff fixes and add dependencies
arthurdrake1 Sep 30, 2025
992a3fd
add early stopping arg for transformer
arthurdrake1 Oct 2, 2025
f0b36d8
fix early stopping model saving
arthurdrake1 Oct 3, 2025
a0437b4
clarify var names, add comments, other minor fixes
arthurdrake1 Oct 18, 2025
a29fd33
move helper blocks and utils to their own file
arthurdrake1 Oct 18, 2025
d88a48b
remove default comments from args
arthurdrake1 Oct 22, 2025
5577396
remove unused args
arthurdrake1 Oct 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
*.py text eol=lf
*.json text eol=lf
*.yml text eol=lf
*.yaml text eol=lf
*.sh text eol=lf
*.md text eol=lf
*.txt text eol=lf
*.ipynb text eol=lf

*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.pdf binary
*.pkl binary
*.npy binary
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ share/python-wheels/
*.egg
MANIFEST

# Windows local environment files
pyvenv.cfg
Scripts/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
Expand Down Expand Up @@ -124,6 +128,7 @@ celerybeat.pid
# Environments
.env
.venv
engiopt_env/
env/
venv/
ENV/
Expand Down Expand Up @@ -156,6 +161,8 @@ cython_debug/

wandb/*
images/*
logs/*
*.csv
# Editors
.idea/
.vscode/
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ As much as we can, we follow the [CleanRL](https://github.com/vwxyzjn/cleanrl) p
[gan_bezier](engiopt/gan_bezier/) | Inverse Design | 1D | ❌ | GAN + Bezier layer
[gan_cnn_2d](engiopt/gan_cnn_2d/) | Inverse Design | 2D | ❌ | GAN + CNN
[surrogate_model](engiopt/surrogate_model/) | Surrogate Model | 1D | ❌ | MLP
[vqgan](engiopt/vqgan) | Inverse Design | 2D | ✅ | VQVAE + Transformer

## Dashboards
The integration with WandB allows us to access live dashboards of our runs (on the cluster or not). We also upload the trained models there. You can access some of our runs at https://wandb.ai/engibench/engiopt.
Expand Down
7 changes: 6 additions & 1 deletion engiopt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import multiprocessing
import os
import sys
import traceback
from typing import Any, TYPE_CHECKING

Expand All @@ -21,7 +22,11 @@
from engibench import OptiStep
from engibench.core import Problem

multiprocessing.set_start_method("fork")

# Windows has no "fork" start method, so fall back to "spawn" there.
start_method = "spawn" if sys.platform == "win32" else "fork"
multiprocessing.set_start_method(start_method, force=True)


def mmd(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
Expand Down
45 changes: 45 additions & 0 deletions engiopt/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from collections.abc import Callable

from datasets import Dataset
from engibench.core import Problem
from gymnasium import spaces
import torch as th
import torch.nn.functional as f


def flatten_dict_factory(problem: Problem, device: th.device) -> Callable:
Expand All @@ -21,3 +23,46 @@ def flatten_dict(x):
return th.stack(flattened)

return flatten_dict


def resize_to(data: th.Tensor, h: int, w: int, mode: str = "bicubic") -> th.Tensor:
    """Resize a batch of 2D maps to the desired (h, w).

    Accepts either (B, C, H, W) or (B, H, W) input; a 3D input gains a
    singleton channel axis so interpolation always sees a 4D tensor.
    """
    batched_without_channels = 3
    if data.ndim == batched_without_channels:
        # Insert a channel dimension: (B, H, W) -> (B, 1, H, W).
        data = data.unsqueeze(1)
    return f.interpolate(data, size=(h, w), mode=mode)


def normalize(ds: Dataset, condition_names: list[str]) -> tuple[Dataset, th.Tensor, th.Tensor]:
    """Normalize specified condition columns with global mean/std."""
    # Gather the requested columns into one (N, C) float tensor on CPU.
    stacked = th.stack([th.as_tensor(ds[name][:]).float() for name in condition_names], dim=1)
    mean = stacked.mean(dim=0)
    # Clamp guards against division by zero for (near-)constant columns.
    std = stacked.std(dim=0).clamp(min=1e-8)

    def _standardize(batch):
        # HF datasets expects numpy arrays back from map().
        return {
            name: ((th.as_tensor(batch[name][:]).float() - mean[i]) / std[i]).numpy()
            for i, name in enumerate(condition_names)
        }

    ds = ds.map(_standardize, batched=True)
    return ds, mean, std


def drop_constant(ds: Dataset, condition_names: list[str]) -> tuple[Dataset, list[str]]:
    """Drop constant condition columns (std=0) from dataset."""
    stacked = th.stack([th.as_tensor(ds[name][:]).float() for name in condition_names], dim=1)
    col_std = stacked.std(dim=0)

    kept = [name for i, name in enumerate(condition_names) if col_std[i] > 0]
    dropped = [name for i, name in enumerate(condition_names) if col_std[i] == 0]

    if dropped:
        # Surface the removal at runtime so vanished conditions are traceable.
        print(f"Warning: Dropping constant condition columns (std=0): {dropped}")

    # Strip the constant columns from the dataset itself.
    return ds.remove_columns(dropped), kept
Empty file added engiopt/vqgan/__init__.py
Empty file.
213 changes: 213 additions & 0 deletions engiopt/vqgan/evaluate_vqgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Evaluation for the VQGAN."""

from __future__ import annotations

import dataclasses
import os

from engibench.utils.all_problems import BUILTIN_PROBLEMS
import numpy as np
import pandas as pd
import torch as th
import tyro
import wandb

from engiopt import metrics
from engiopt.dataset_sample_conditions import sample_conditions
from engiopt.transforms import drop_constant
from engiopt.transforms import normalize
from engiopt.transforms import resize_to
from engiopt.vqgan.vqgan import VQGAN
from engiopt.vqgan.vqgan import VQGANTransformer


@dataclasses.dataclass
class Args:
    """Command-line arguments for a single-seed VQGAN 2D evaluation.

    The bare-string docstring under each field is surfaced by ``tyro`` as
    that flag's CLI help text, so keep them in sync with the field meaning.
    """

    problem_id: str = "beams2d"
    """Problem identifier."""
    seed: int = 1
    """Random seed to run."""
    wandb_project: str = "engiopt"
    """Wandb project name."""
    wandb_entity: str | None = None
    """Wandb entity name."""
    n_samples: int = 50
    """Number of generated samples per seed."""
    sigma: float = 10.0
    """Kernel bandwidth for MMD and DPP metrics."""
    output_csv: str = "vqgan_{problem_id}_metrics.csv"
    """Output CSV path template; may include {problem_id}."""


if __name__ == "__main__":
    args = tyro.cli(Args)

    seed = args.seed
    # Instantiate the selected EngiBench problem and seed it alongside torch/numpy.
    problem = BUILTIN_PROBLEMS[args.problem_id]()
    problem.reset(seed=seed)

    # Reproducibility
    th.manual_seed(seed)
    rng = np.random.default_rng(seed)
    th.backends.cudnn.deterministic = True

    # Device preference: Apple MPS, then CUDA, falling back to CPU.
    if th.backends.mps.is_available():
        device = th.device("mps")
    elif th.cuda.is_available():
        device = th.device("cuda")
    else:
        device = th.device("cpu")

    ### Set Up Transformer ###

    # Restores the pytorch model from wandb.
    # Artifact names follow "<problem>_vqgan_<part>:seed_<seed>"; the entity
    # prefix is omitted when no wandb entity is configured.
    if args.wandb_entity is not None:
        artifact_path_cvqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}"
        artifact_path_vqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}"
        artifact_path_transformer = (
            f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}"
        )
    else:
        artifact_path_cvqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}"
        artifact_path_vqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}"
        artifact_path_transformer = f"{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}"

    api = wandb.Api()
    artifact_cvqgan = api.artifact(artifact_path_cvqgan, type="model")
    artifact_vqgan = api.artifact(artifact_path_vqgan, type="model")
    artifact_transformer = api.artifact(artifact_path_transformer, type="model")

    class RunRetrievalError(ValueError):
        """Raised when the wandb run that logged the transformer artifact cannot be retrieved."""

        def __init__(self):
            super().__init__("Failed to retrieve the run")

    # The training run's config supplies every hyperparameter needed to rebuild the models.
    run = artifact_transformer.logged_by()
    if run is None or not hasattr(run, "config"):
        raise RunRetrievalError
    artifact_dir_cvqgan = artifact_cvqgan.download()
    artifact_dir_vqgan = artifact_vqgan.download()
    artifact_dir_transformer = artifact_transformer.download()

    ckpt_path_cvqgan = os.path.join(artifact_dir_cvqgan, "cvqgan.pth")
    ckpt_path_vqgan = os.path.join(artifact_dir_vqgan, "vqgan.pth")
    ckpt_path_transformer = os.path.join(artifact_dir_transformer, "transformer.pth")
    # NOTE(review): weights_only=False unpickles arbitrary objects — acceptable only
    # because these artifacts come from our own wandb project; confirm they stay trusted.
    ckpt_cvqgan = th.load(ckpt_path_cvqgan, map_location=th.device(device), weights_only=False)
    ckpt_vqgan = th.load(ckpt_path_vqgan, map_location=th.device(device), weights_only=False)
    ckpt_transformer = th.load(ckpt_path_transformer, map_location=th.device(device), weights_only=False)

    # Stage-1 VQGAN that encodes/decodes designs (is_c=False).
    vqgan = VQGAN(
        device=device,
        is_c=False,
        encoder_channels=run.config["encoder_channels"],
        encoder_start_resolution=run.config["image_size"],
        encoder_attn_resolutions=run.config["encoder_attn_resolutions"],
        encoder_num_res_blocks=run.config["encoder_num_res_blocks"],
        decoder_channels=run.config["decoder_channels"],
        decoder_start_resolution=run.config["latent_size"],
        decoder_attn_resolutions=run.config["decoder_attn_resolutions"],
        decoder_num_res_blocks=run.config["decoder_num_res_blocks"],
        image_channels=run.config["image_channels"],
        latent_dim=run.config["latent_dim"],
        num_codebook_vectors=run.config["num_codebook_vectors"],
    )
    vqgan.load_state_dict(ckpt_vqgan["vqgan"])
    vqgan.eval()  # Set to evaluation mode
    vqgan.to(device)

    # Condition VQGAN (is_c=True) used to discretize condition vectors.
    cvqgan = VQGAN(
        device=device,
        is_c=True,
        cond_feature_map_dim=run.config["cond_feature_map_dim"],
        cond_dim=run.config["cond_dim"],
        cond_hidden_dim=run.config["cond_hidden_dim"],
        cond_latent_dim=run.config["cond_latent_dim"],
        cond_codebook_vectors=run.config["cond_codebook_vectors"],
    )
    cvqgan.load_state_dict(ckpt_cvqgan["cvqgan"])
    cvqgan.eval()  # Set to evaluation mode
    cvqgan.to(device)

    # Stage-2 transformer that samples codebook indices over both VQGANs.
    model = VQGANTransformer(
        conditional=run.config["conditional"],
        vqgan=vqgan,
        cvqgan=cvqgan,
        image_size=run.config["image_size"],
        decoder_channels=run.config["decoder_channels"],
        cond_feature_map_dim=run.config["cond_feature_map_dim"],
        num_codebook_vectors=run.config["num_codebook_vectors"],
        n_layer=run.config["n_layer"],
        n_head=run.config["n_head"],
        n_embd=run.config["n_embd"],
        dropout=run.config["dropout"],
    )
    model.load_state_dict(ckpt_transformer["transformer"])
    model.eval()  # Set to evaluation mode
    model.to(device)

    ### Set up testing conditions ###
    _, sampled_conditions, sampled_designs_np, _ = sample_conditions(
        problem=problem, n_samples=args.n_samples, device=device, seed=seed
    )

    # Clean up conditions based on model training settings and convert back to tensor.
    # This is a working copy: it gets normalized/cleaned for the model, while the
    # untouched sampled_conditions is still needed by metrics.metrics() later.
    sampled_conditions_new = sampled_conditions.select(range(len(sampled_conditions)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get what this does

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This copies over the original sampled_conditions into a new dataset that can then be normalized/cleaned prior to feeding into the model. The reason we need this copy is (to the best of my knowledge) the metrics.metrics() call later requires the original conditions for its calculations.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair! The name is a bit cryptic then, though.

    conditions = sampled_conditions_new.column_names

    # Drop constant condition columns if enabled
    if run.config["drop_constant_conditions"]:
        sampled_conditions_new, conditions = drop_constant(sampled_conditions_new, sampled_conditions_new.column_names)

    # Normalize condition columns if enabled (statistics come from the sampled set itself)
    if run.config["normalize_conditions"]:
        sampled_conditions_new, mean, std = normalize(sampled_conditions_new, conditions)

    # Convert to tensor of shape (n_samples, n_conditions)
    conditions_tensor = th.stack([th.as_tensor(sampled_conditions_new[c][:]).float() for c in conditions], dim=1).to(device)

    # Set the start-of-sequence tokens for the transformer using the CVQGAN to discretize the conditions if enabled
    if run.config["conditional"]:
        c = model.encode_to_z(x=conditions_tensor, is_c=True)[1]
    else:
        # Unconditional sampling: every sequence starts with the plain SOS token.
        c = th.ones(args.n_samples, 1, dtype=th.int64, device=device) * model.sos_token

    # Generate a batch of designs (latent_size**2 autoregressive steps per design)
    latent_designs = model.sample(
        x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), c=c, steps=(run.config["latent_size"] ** 2)
    )
    # Decode the latents to images, then resize to the problem's native design resolution.
    gen_designs = resize_to(
        data=model.z_to_image(latent_designs), h=problem.design_space.shape[0], w=problem.design_space.shape[1]
    )
    gen_designs_np = gen_designs.detach().cpu().numpy()
    gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape)

    # Clip to valid design bounds before simulation.
    # NOTE(review): the [1e-3, 1] range is problem dependent — revisit for other problems.
    gen_designs_np = np.clip(gen_designs_np, 1e-3, 1)

    # Compute metrics against the original (un-normalized) conditions and designs
    metrics_dict = metrics.metrics(
        problem,
        gen_designs_np,
        sampled_designs_np,
        sampled_conditions,
        sigma=args.sigma,
    )

    # Tag the row with this run's identifying settings.
    metrics_dict.update(
        {
            "seed": seed,
            "problem_id": args.problem_id,
            "model_id": "vqgan",
            "n_samples": args.n_samples,
            "sigma": args.sigma,
        }
    )

    # Append result row to CSV (header only written when the file does not exist yet)
    metrics_df = pd.DataFrame([metrics_dict])
    out_path = args.output_csv.format(problem_id=args.problem_id)
    write_header = not os.path.exists(out_path)
    metrics_df.to_csv(out_path, mode="a", header=write_header, index=False)

    print(f"Seed {seed} done; appended to {out_path}")
Loading
Loading