-
Notifications
You must be signed in to change notification settings - Fork 6
Final version of the PixelCNN++ implementation #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
1a6e50f
innit with contents from other neural network modules that are shared
JonasETHZ cc90a5f
changed __innit____ to __init__
JonasETHZ 11c285a
start of implementing conditional inputs for
JonasETHZ acc01de
model mostly implemented, needs testing
JonasETHZ ca91de7
loss function added
JonasETHZ b44d771
loss function added
JonasETHZ 94be7a6
removed unnecessary file
JonasETHZ 2e8e5d9
partially debugged
JonasETHZ 9900866
working pixel_cnn_pp_2d.py version
JonasETHZ 52b0caa
working pixel_cnn_pp_2d.py model, added function to save model every …
JonasETHZ ead782b
finished up evaluate_pixel_cnn_pp_2d.py and fixed parameter handling …
JonasETHZ 7081abd
Added multiple GPU support for sampling
JonasETHZ bd34475
increased number of generated samples to 25
JonasETHZ f5c475a
multi GPU v2
JonasETHZ 66a1e8f
added info message for multi-GPU sampling
JonasETHZ d690781
model with max settings and 1 GPU
JonasETHZ 1ef0254
fixed scaling (input/output) and removed multi GPU sampling
JonasETHZ eeb1812
sample interval 600
JonasETHZ 811167a
updated evaluate_pixel_cnn_pp_2d.py
JonasETHZ e3081c3
added some debugging lines
JonasETHZ c5a143a
added sampling in batches for pixel_cnn_pp_2d
JonasETHZ b73b420
removed comments and time.sleep()
JonasETHZ 93e34ec
added info text
JonasETHZ c246690
another print statement
JonasETHZ e8fbf36
lowered sampling batch size to reduce memory consumption during evalu…
JonasETHZ fd2bb14
nr of samples changed to 25
JonasETHZ b2858de
evaluation preparation and removal of some debug prints
JonasETHZ 9497b4e
Final version of PixelCNN++ implementation
JonasETHZ f29cca1
Fixed typos
JonasETHZ 23b9091
Merge branch 'main' into main
JonasETHZ ea6fe9a
Formatting with ruff
JonasETHZ b2abfe9
Merge branch 'main' of https://github.com/JonasETHZ/EngiOpt
JonasETHZ c2f726b
replaced NIN by NetworkInNetwork
JonasETHZ File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| """Evaluation of the PixelCNN++ model.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| import os | ||
|
|
||
| from engibench.utils.all_problems import BUILTIN_PROBLEMS | ||
| import numpy as np | ||
| import pandas as pd | ||
| import torch as th | ||
| import tyro | ||
| import wandb | ||
|
|
||
| from engiopt import metrics | ||
| from engiopt.dataset_sample_conditions import sample_conditions | ||
| from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import PixelCNNpp | ||
| from engiopt.pixel_cnn_pp_2d.pixel_cnn_pp_2d import sample_from_discretized_mix_logistic | ||
|
|
||
|
|
||
| @dataclass | ||
| class Args: | ||
| """Command-line arguments for a single-seed PixelCNN++ 2D evaluation.""" | ||
|
|
||
| problem_id: str = "beams2d" | ||
| """Problem identifier.""" | ||
| seed: int = 1 | ||
| """Random seed to run.""" | ||
| sampling_batch_size: int = 5 | ||
| """Batch size to use during sampling.""" | ||
| wandb_project: str = "engiopt" | ||
| """Wandb project name.""" | ||
| wandb_entity: str | None = None | ||
| """Wandb entity name.""" | ||
| n_samples: int = 50 | ||
| """Number of generated samples per seed.""" | ||
| sigma: float = 10.0 | ||
| """Kernel bandwidth for MMD and DPP metrics.""" | ||
| output_csv: str = "pixel_cnn_pp_2d_{problem_id}_metrics.csv" | ||
| """Output CSV path template; may include {problem_id}.""" | ||
|
|
||
|
|
||
if __name__ == "__main__":
    # Evaluation entry point: load a trained PixelCNN++ checkpoint from a wandb
    # artifact, sample designs autoregressively for a set of test conditions,
    # compute metrics against the sampled dataset designs, and append one result
    # row to a CSV file.
    args = tyro.cli(Args)

    seed = args.seed
    problem = BUILTIN_PROBLEMS[args.problem_id]()
    problem.reset(seed=seed)

    # Seeding for reproducibility
    th.manual_seed(seed)
    th.backends.cudnn.deterministic = True

    # Select device
    if th.backends.mps.is_available():
        device = th.device("mps")
    elif th.cuda.is_available():
        device = th.device("cuda")
    else:
        device = th.device("cpu")

    # Set up testing conditions
    conditions_tensor, sampled_conditions, sampled_designs_np, selected_indices = sample_conditions(
        problem=problem,
        n_samples=args.n_samples,
        device=device,
        seed=seed,
    )

    # adapt to PixelCNN++ input shape requirements (two trailing singleton dims)
    conditions_tensor = conditions_tensor.unsqueeze(-1).unsqueeze(-1)
    design_shape = (problem.design_space.shape[0], problem.design_space.shape[1])

    # Set up PixelCNN++ model
    if args.wandb_entity is not None:
        artifact_path = f"{args.wandb_entity}/{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}"
    else:
        artifact_path = f"{args.wandb_project}/{args.problem_id}_pixel_cnn_pp_2d_model:seed_{seed}"

    api = wandb.Api()
    artifact = api.artifact(artifact_path, type="model")

    class RunRetrievalError(ValueError):
        """Raised when the run that logged the model artifact cannot be retrieved."""

        def __init__(self):
            super().__init__("Failed to retrieve the run")

    # The training run's config supplies the hyperparameters needed to rebuild the model.
    run = artifact.logged_by()
    if run is None or not hasattr(run, "config"):
        raise RunRetrievalError

    artifact_dir = artifact.download()
    ckpt_path = os.path.join(artifact_dir, "model.pth")
    ckpt = th.load(ckpt_path, map_location=device)

    nr_logistic_mix = run.config["nr_logistic_mix"]
    model = PixelCNNpp(
        nr_resnet=run.config["nr_resnet"],
        nr_filters=run.config["nr_filters"],
        nr_logistic_mix=nr_logistic_mix,
        resnet_nonlinearity=run.config["resnet_nonlinearity"],
        dropout_p=run.config["dropout_p"],
        input_channels=1,
        nr_conditions=conditions_tensor.shape[1],
    )

    model.load_state_dict(ckpt["model"])
    model.eval()
    model.to(device)

    # Sample designs in batches
    all_batches: list[th.Tensor] = []

    # Sampling is pure inference: one forward pass per pixel. Disabling autograd
    # avoids recording a graph for every one of those passes.
    with th.no_grad():
        for start in range(0, args.n_samples, args.sampling_batch_size):
            end = min(args.n_samples, start + args.sampling_batch_size)
            b = end - start  # the last batch may be smaller than sampling_batch_size

            # prepare batch-local tensors on the same device as the model
            batch_conds = conditions_tensor[start:end]
            data = th.zeros((b, 1, *design_shape), device=device)

            # Autoregressive pixel sampling for this batch: only pixel (i, j) of
            # each freshly sampled image is kept, then the model is re-run on the
            # updated canvas.
            for i in range(design_shape[0]):
                for j in range(design_shape[1]):
                    out = model(data, batch_conds)
                    out_sample = sample_from_discretized_mix_logistic(out, nr_logistic_mix)
                    data[:, :, i, j] = out_sample[:, :, i, j]

            # move completed batch to CPU to free GPU memory and store
            all_batches.append(data.cpu())

    # concatenate all batches on CPU
    gen_designs = th.cat(all_batches, dim=0)

    gen_designs_np = gen_designs.detach().cpu().numpy()
    gen_designs_np = gen_designs_np.reshape(args.n_samples, *problem.design_space.shape)  # remove channel dim
    gen_designs_np = np.clip(gen_designs_np, 1e-3, 1.0)

    # Compute metrics
    metrics_dict = metrics.metrics(
        problem,
        gen_designs_np,
        sampled_designs_np,
        sampled_conditions,
        sigma=args.sigma,
    )

    # Add metadata to metrics
    metrics_dict.update(
        {
            "seed": seed,
            "problem_id": args.problem_id,
            "model_id": "pixel_cnn_pp_2d",
            "n_samples": args.n_samples,
            "sigma": args.sigma,
        }
    )

    # Append result row to CSV; the header is written only when the file is first created
    metrics_df = pd.DataFrame([metrics_dict])
    out_path = args.output_csv.format(problem_id=args.problem_id)
    write_header = not os.path.exists(out_path)
    metrics_df.to_csv(out_path, mode="a", header=write_header, index=False)

    print(f"Seed {seed} done; appended to {out_path}")
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.