diff --git a/engiopt/vqgan/evaluate_vqgan.py b/engiopt/vqgan/evaluate_vqgan.py index 40c19ff..bc2d3ed 100644 --- a/engiopt/vqgan/evaluate_vqgan.py +++ b/engiopt/vqgan/evaluate_vqgan.py @@ -84,8 +84,10 @@ def __init__(self): super().__init__("Failed to retrieve the run") run = artifact_transformer.logged_by() - if run is None or not hasattr(run, "config"): + if run is None: raise RunRetrievalError + run = api.run(f"{run.entity}/{run.project}/{run.id}") + artifact_dir_cvqgan = artifact_cvqgan.download() artifact_dir_vqgan = artifact_vqgan.download() artifact_dir_transformer = artifact_transformer.download() diff --git a/engiopt/vqgan/vqgan.py b/engiopt/vqgan/vqgan.py index efdc2fa..c34e27f 100644 --- a/engiopt/vqgan/vqgan.py +++ b/engiopt/vqgan/vqgan.py @@ -680,8 +680,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu # Now we assume the dataset is of shape (N, C, H, W) and work from there image_channels = training_ds["optimal_upsampled"][:].shape[1] latent_size = args.image_size // (2 ** (len(args.encoder_channels) - 2)) - conditions = problem.conditions_keys + conditions = problem.conditions_keys # Optionally drop condition columns that are constant like overhang_constraint in beams2d if args.drop_constant_conditions: training_ds, conditions = drop_constant(training_ds, conditions) @@ -787,6 +787,8 @@ def log_images(self, x: th.Tensor, c: th.Tensor, top_k: int | None = None) -> tu wandb.define_metric("epoch_transformer", step_metric="transformer_step") if args.early_stopping: wandb.define_metric("transformer_val_loss", step_metric="transformer_step") + wandb.config["image_channels"] = image_channels + wandb.config["latent_size"] = latent_size vqgan = VQGAN( device=device,