Vqgan #49
Merged
23 commits
482c85d add initial vqgan files (arthurdrake1)
be14a17 add windows support and add vqgan description (arthurdrake1)
1398f41 add vqgan args and some new image transform functions (arthurdrake1)
e334900 add stage 1 submodules (arthurdrake1)
8df08db clean up stage 1 modules (arthurdrake1)
e415de7 add remaining cleaned up stage 1 modules (arthurdrake1)
97efc89 add vqgan stage 1 main class and resolve arg names (arthurdrake1)
f9d565d add vqgan stage 2 modules (arthurdrake1)
5b9625f add initial training loops and conditions augmentations (arthurdrake1)
2be8e23 Merge branch 'main' into vqgan (arthurdrake1)
c0e4ae0 add training tracking (arthurdrake1)
93cc9cd add image logging to training (arthurdrake1)
ff937fa simplify args and transforms; remove unneeded transformer code (arthurdrake1)
e7bafce minor plotting fixes (arthurdrake1)
b845834 add initial vqgan eval script (arthurdrake1)
c543345 minor eval conditions fixes (arthurdrake1)
ef8aa72 ruff fixes and add dependencies (arthurdrake1)
992a3fd add early stopping arg for transformer (arthurdrake1)
f0b36d8 fix early stopping model saving (arthurdrake1)
a0437b4 clarify var names, add comments, other minor fixes (arthurdrake1)
a29fd33 move helper blocks and utils to their own file (arthurdrake1)
d88a48b remove default comments from args (arthurdrake1)
5577396 remove unused args (arthurdrake1)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| *.py text eol=lf | ||
| *.json text eol=lf | ||
| *.yml text eol=lf | ||
| *.yaml text eol=lf | ||
| *.sh text eol=lf | ||
| *.md text eol=lf | ||
| *.txt text eol=lf | ||
| *.ipynb text eol=lf | ||
| | ||
| *.png binary | ||
| *.jpg binary | ||
| *.jpeg binary | ||
| *.gif binary | ||
| *.pdf binary | ||
| *.pkl binary | ||
| *.npy binary |
Empty file.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| """Evaluation for the VQGAN.""" | ||
| | ||
| from __future__ import annotations | ||
| | ||
| import dataclasses | ||
| import os | ||
| | ||
| from engibench.utils.all_problems import BUILTIN_PROBLEMS | ||
| import numpy as np | ||
| import pandas as pd | ||
| import torch as th | ||
| import tyro | ||
| import wandb | ||
| | ||
| from engiopt import metrics | ||
| from engiopt.dataset_sample_conditions import sample_conditions | ||
| from engiopt.transforms import drop_constant | ||
| from engiopt.transforms import normalize | ||
| from engiopt.transforms import resize_to | ||
| from engiopt.vqgan.vqgan import VQGAN | ||
| from engiopt.vqgan.vqgan import VQGANTransformer | ||
| | ||
| | ||
| @dataclasses.dataclass | ||
| class Args: | ||
|     """Command-line arguments for a single-seed VQGAN 2D evaluation.""" | ||
| | ||
|     problem_id: str = "beams2d" | ||
|     """Problem identifier.""" | ||
|     seed: int = 1 | ||
|     """Random seed to run.""" | ||
|     wandb_project: str = "engiopt" | ||
|     """Wandb project name.""" | ||
|     wandb_entity: str \| None = None | ||
|     """Wandb entity name.""" | ||
|     n_samples: int = 50 | ||
|     """Number of generated samples per seed.""" | ||
|     sigma: float = 10.0 | ||
|     """Kernel bandwidth for MMD and DPP metrics.""" | ||
|     output_csv: str = "vqgan_{problem_id}_metrics.csv" | ||
|     """Output CSV path template; may include {problem_id}.""" | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     args = tyro.cli(Args) | ||
| | ||
|     seed = args.seed | ||
|     problem = BUILTIN_PROBLEMS[args.problem_id]() | ||
|     problem.reset(seed=seed) | ||
| | ||
|     # Reproducibility | ||
|     th.manual_seed(seed) | ||
|     rng = np.random.default_rng(seed) | ||
|     th.backends.cudnn.deterministic = True | ||
| | ||
|     if th.backends.mps.is_available(): | ||
|         device = th.device("mps") | ||
|     elif th.cuda.is_available(): | ||
|         device = th.device("cuda") | ||
|     else: | ||
|         device = th.device("cpu") | ||
| | ||
|     ### Set Up Transformer ### | ||
| | ||
|     # Restores the pytorch model from wandb | ||
|     if args.wandb_entity is not None: | ||
|         artifact_path_cvqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" | ||
|         artifact_path_vqgan = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" | ||
|         artifact_path_transformer = ( | ||
|             f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" | ||
|         ) | ||
|     else: | ||
|         artifact_path_cvqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_cvqgan:seed_{seed}" | ||
|         artifact_path_vqgan = f"{args.wandb_project}/{args.problem_id}_vqgan_vqgan:seed_{seed}" | ||
|         artifact_path_transformer = f"{args.wandb_project}/{args.problem_id}_vqgan_transformer:seed_{seed}" | ||
| | ||
|     api = wandb.Api() | ||
|     artifact_cvqgan = api.artifact(artifact_path_cvqgan, type="model") | ||
|     artifact_vqgan = api.artifact(artifact_path_vqgan, type="model") | ||
|     artifact_transformer = api.artifact(artifact_path_transformer, type="model") | ||
| | ||
|     class RunRetrievalError(ValueError): | ||
|         def __init__(self): | ||
|             super().__init__("Failed to retrieve the run") | ||
| | ||
|     run = artifact_transformer.logged_by() | ||
|     if run is None or not hasattr(run, "config"): | ||
|         raise RunRetrievalError | ||
|     artifact_dir_cvqgan = artifact_cvqgan.download() | ||
|     artifact_dir_vqgan = artifact_vqgan.download() | ||
|     artifact_dir_transformer = artifact_transformer.download() | ||
| | ||
|     ckpt_path_cvqgan = os.path.join(artifact_dir_cvqgan, "cvqgan.pth") | ||
|     ckpt_path_vqgan = os.path.join(artifact_dir_vqgan, "vqgan.pth") | ||
|     ckpt_path_transformer = os.path.join(artifact_dir_transformer, "transformer.pth") | ||
|     ckpt_cvqgan = th.load(ckpt_path_cvqgan, map_location=th.device(device), weights_only=False) | ||
|     ckpt_vqgan = th.load(ckpt_path_vqgan, map_location=th.device(device), weights_only=False) | ||
|     ckpt_transformer = th.load(ckpt_path_transformer, map_location=th.device(device), weights_only=False) | ||
| | ||
|     vqgan = VQGAN( | ||
|         device=device, | ||
|         is_c=False, | ||
|         encoder_channels=run.config["encoder_channels"], | ||
|         encoder_start_resolution=run.config["image_size"], | ||
|         encoder_attn_resolutions=run.config["encoder_attn_resolutions"], | ||
|         encoder_num_res_blocks=run.config["encoder_num_res_blocks"], | ||
|         decoder_channels=run.config["decoder_channels"], | ||
|         decoder_start_resolution=run.config["latent_size"], | ||
|         decoder_attn_resolutions=run.config["decoder_attn_resolutions"], | ||
|         decoder_num_res_blocks=run.config["decoder_num_res_blocks"], | ||
|         image_channels=run.config["image_channels"], | ||
|         latent_dim=run.config["latent_dim"], | ||
|         num_codebook_vectors=run.config["num_codebook_vectors"], | ||
|     ) | ||
|     vqgan.load_state_dict(ckpt_vqgan["vqgan"]) | ||
|     vqgan.eval()  # Set to evaluation mode | ||
|     vqgan.to(device) | ||
| | ||
|     cvqgan = VQGAN( | ||
|         device=device, | ||
|         is_c=True, | ||
|         cond_feature_map_dim=run.config["cond_feature_map_dim"], | ||
|         cond_dim=run.config["cond_dim"], | ||
|         cond_hidden_dim=run.config["cond_hidden_dim"], | ||
|         cond_latent_dim=run.config["cond_latent_dim"], | ||
|         cond_codebook_vectors=run.config["cond_codebook_vectors"], | ||
|     ) | ||
|     cvqgan.load_state_dict(ckpt_cvqgan["cvqgan"]) | ||
|     cvqgan.eval()  # Set to evaluation mode | ||
|     cvqgan.to(device) | ||
| | ||
|     model = VQGANTransformer( | ||
|         conditional=run.config["conditional"], | ||
|         vqgan=vqgan, | ||
|         cvqgan=cvqgan, | ||
|         image_size=run.config["image_size"], | ||
|         decoder_channels=run.config["decoder_channels"], | ||
|         cond_feature_map_dim=run.config["cond_feature_map_dim"], | ||
|         num_codebook_vectors=run.config["num_codebook_vectors"], | ||
|         n_layer=run.config["n_layer"], | ||
|         n_head=run.config["n_head"], | ||
|         n_embd=run.config["n_embd"], | ||
|         dropout=run.config["dropout"], | ||
|     ) | ||
|     model.load_state_dict(ckpt_transformer["transformer"]) | ||
|     model.eval()  # Set to evaluation mode | ||
|     model.to(device) | ||
| | ||
|     ### Set up testing conditions ### | ||
|     _, sampled_conditions, sampled_designs_np, _ = sample_conditions( | ||
|         problem=problem, n_samples=args.n_samples, device=device, seed=seed | ||
|     ) | ||
| | ||
|     # Clean up conditions based on model training settings and convert back to tensor | ||
|     sampled_conditions_new = sampled_conditions.select(range(len(sampled_conditions))) | ||
|     conditions = sampled_conditions_new.column_names | ||
| | ||
|     # Drop constant condition columns if enabled | ||
|     if run.config["drop_constant_conditions"]: | ||
|         sampled_conditions_new, conditions = drop_constant(sampled_conditions_new, sampled_conditions_new.column_names) | ||
| | ||
|     # Normalize condition columns if enabled | ||
|     if run.config["normalize_conditions"]: | ||
|         sampled_conditions_new, mean, std = normalize(sampled_conditions_new, conditions) | ||
| | ||
|     # Convert to tensor | ||
|     conditions_tensor = th.stack([th.as_tensor(sampled_conditions_new[c][:]).float() for c in conditions], dim=1).to(device) | ||
| | ||
|     # Set the start-of-sequence tokens for the transformer using the CVQGAN to discretize the conditions if enabled | ||
|     if run.config["conditional"]: | ||
|         c = model.encode_to_z(x=conditions_tensor, is_c=True)[1] | ||
|     else: | ||
|         c = th.ones(args.n_samples, 1, dtype=th.int64, device=device) * model.sos_token | ||
| | ||
|     # Generate a batch of designs | ||
|     latent_designs = model.sample( | ||
|         x=th.empty(args.n_samples, 0, dtype=th.int64, device=device), c=c, steps=(run.config["latent_size"] ** 2) | ||
|     ) | ||
|     gen_designs = resize_to( | ||
|         data=model.z_to_image(latent_designs), h=problem.design_space.shape[0], w=problem.design_space.shape[1] | ||
|     ) | ||
|     gen_designs_np = gen_designs.detach().cpu().numpy() | ||
|     gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape) | ||
| | ||
|     # Clip to boundaries for running THIS IS PROBLEM DEPENDENT | ||
|     gen_designs_np = np.clip(gen_designs_np, 1e-3, 1) | ||
| | ||
|     # Compute metrics | ||
|     metrics_dict = metrics.metrics( | ||
|         problem, | ||
|         gen_designs_np, | ||
|         sampled_designs_np, | ||
|         sampled_conditions, | ||
|         sigma=args.sigma, | ||
|     ) | ||
| | ||
|     metrics_dict.update( | ||
|         { | ||
|             "seed": seed, | ||
|             "problem_id": args.problem_id, | ||
|             "model_id": "vqgan", | ||
|             "n_samples": args.n_samples, | ||
|             "sigma": args.sigma, | ||
|         } | ||
|     ) | ||
| | ||
|     # Append result row to CSV | ||
|     metrics_df = pd.DataFrame([metrics_dict]) | ||
|     out_path = args.output_csv.format(problem_id=args.problem_id) | ||
|     write_header = not os.path.exists(out_path) | ||
|     metrics_df.to_csv(out_path, mode="a", header=write_header, index=False) | ||
| | ||
|     print(f"Seed {seed} done; appended to {out_path}") | ||
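The evaluation script builds the same three artifact paths twice, once with and once without the wandb entity prefix. As a minimal sketch of how that branching could be collapsed, the snippet below uses a hypothetical helper, `artifact_path`, which is not part of this PR; the `"team"` entity in the usage note is likewise only an illustrative value. The path layout itself is copied from the f-strings in the diff.

```python
from __future__ import annotations


def artifact_path(entity: str | None, project: str, problem_id: str, component: str, seed: int) -> str:
    """Hypothetical helper: build a wandb artifact path like the ones in the eval script."""
    # Same layout as the diff: <project>/<problem_id>_vqgan_<component>:seed_<seed>
    name = f"{project}/{problem_id}_vqgan_{component}:seed_{seed}"
    # Prefix the entity only when one was given on the command line.
    return name if entity is None else f"{entity}/{name}"


# One path per model component loaded by the script (defaults taken from Args).
paths = {
    component: artifact_path(None, "engiopt", "beams2d", component, 1)
    for component in ("cvqgan", "vqgan", "transformer")
}
```

With an entity set, `artifact_path("team", "engiopt", "beams2d", "transformer", 1)` would produce the entity-prefixed form used in the first branch of the script.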
I don't get what this does
This copies over the original `sampled_conditions` into a new dataset that can then be normalized/cleaned prior to feeding into the model. The reason we need this copy is that (to the best of my knowledge) the `metrics.metrics()` call later requires the original conditions for its calculations.
Fair! The name is a bit cryptic then, though.
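The copy-then-clean pattern discussed in this thread can be illustrated with a small, dependency-free sketch. It does not use the real `drop_constant` or `normalize` from `engiopt.transforms`, whose signatures are not shown in full here; `drop_constant_columns` and `normalize_columns` are hypothetical stand-ins, the column names `volfrac` and `rmin` are made up for illustration, and a plain dict of lists stands in for the dataset object. The variable name `model_conditions` is just one possible less cryptic alternative to `sampled_conditions_new`.

```python
import copy
import statistics


def drop_constant_columns(columns):
    """Hypothetical stand-in: keep only columns with more than one distinct value."""
    return {name: values for name, values in columns.items() if len(set(values)) > 1}


def normalize_columns(columns):
    """Hypothetical stand-in: z-score each column using the population standard deviation."""
    normalized = {}
    for name, values in columns.items():
        mean = statistics.fmean(values)
        std = statistics.pstdev(values) or 1.0  # guard against division by zero
        normalized[name] = [(v - mean) / std for v in values]
    return normalized


# Original conditions: left untouched so they can be passed to the metrics later.
sampled_conditions = {"volfrac": [0.2, 0.4, 0.6], "rmin": [2.0, 2.0, 2.0]}

# Work on a copy so that cleaning can never alter the original conditions.
model_conditions = copy.deepcopy(sampled_conditions)
model_conditions = normalize_columns(drop_constant_columns(model_conditions))
```

After this runs, `sampled_conditions` still holds both raw columns, while `model_conditions` contains only the normalized non-constant column that would be fed to the model.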