-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
71 lines (53 loc) · 2.09 KB
/
test.py
File metadata and controls
71 lines (53 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import yaml
import argparse
from datasets import load_dataset
from models import load_model
from utils import *
import torch
import matplotlib.pyplot as plt
from einops import rearrange
import cv2
def get_args_parser() -> argparse.ArgumentParser:
    """Build the parent argument parser for the test script.

    The parser is created with ``add_help=False`` so it can be passed via
    ``parents=[...]`` to another ``ArgumentParser`` without the ``-h/--help``
    option being defined twice.

    Returns:
        A parser exposing a single ``--config`` option (default
        ``'vae.mnist'``).
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        '--config',
        type=str,
        default='vae.mnist',
        # The entry point of this script opens configs/test.<config>.yaml.
        help='config name; resolved to configs/test.<config>.yaml '
             '(default: vae.mnist)',
    )
    return parser
def main(cfg):
    """Run a reconstruction or generation test with a saved model.

    Args:
        cfg: Parsed configuration mapping. Keys read here: ``expr``,
            ``device``, ``model`` (keyword arguments for ``load_model``;
            its ``name`` entry is printed), ``save_path``, ``load_weights``,
            ``task`` (``'recon'`` or ``'gen'``) and ``data`` (keyword
            arguments for ``load_dataset``; its ``dataset`` entry is used in
            the output file name).

    Raises:
        ValueError: If ``cfg['task']`` is neither ``'recon'`` nor ``'gen'``.
        OSError: If the result image could not be written.
    """
    # Function-scope import: the file's import block does not import ``os``
    # explicitly, so do not rely on ``from utils import *`` providing it.
    import os

    print(f"=================[{cfg['expr']}]=================")

    # Device Setting
    device = _select_device(cfg['device'])

    # Load Model
    model_cfg = cfg['model']
    model = load_model(**model_cfg).to(device)
    # map_location makes a checkpoint saved on a GPU loadable when this run
    # falls back to the CPU.
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only
    # load checkpoints from a trusted source.
    ckpt = torch.load(
        os.path.join(cfg['save_path'], cfg['load_weights']),
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(ckpt['model'])
    print(f"Load Model {model_cfg['name']}")

    # Select Task
    task = cfg['task']
    if task == 'recon':
        data_cfg = cfg['data']
        ds = load_dataset(**data_cfg)
        # Sample index 200 is the fixed test input; the second element of
        # the dataset item is not used.
        x, _ = ds[200]
        x = x.unsqueeze(0).to(device)
        model.eval()
        with torch.no_grad():
            # The model returns three values; only the first is used here.
            x_prime, _, __ = model(x)
        # Input on the left, model output on the right.
        side_by_side = cv2.hconcat([_to_image(x), _to_image(x_prime)])
        _write_image(
            f"assets/test_recon_{cfg['data']['dataset']}.jpg", side_by_side)
    elif task == 'gen':
        model.eval()
        with torch.no_grad():
            x_prime = model.sample(1, device)
        _write_image(
            f"assets/test_gen_{cfg['data']['dataset']}.jpg",
            _to_image(x_prime))
    else:
        # Previously an unknown task fell through and the script exited
        # without producing anything or reporting a problem.
        raise ValueError(
            f"Unknown task {task!r}; expected 'recon' or 'gen'")


def _select_device(requested):
    """Return *requested* if it is not 'cpu' and CUDA is available, else 'cpu'."""
    if requested != 'cpu' and torch.cuda.is_available():
        return requested
    return 'cpu'


def _to_image(batch):
    """Turn a single-item ``(1, c, h, w)`` tensor into an ``(h, w, c)`` array times 255."""
    # NOTE(review): the 255 scaling presumably assumes values in [0, 1] --
    # confirm against the model/dataset.
    return rearrange(batch, '1 c h w -> h w c').detach().cpu().numpy() * 255.


def _write_image(path, image):
    """Write *image* to *path* with OpenCV, creating the parent directory first."""
    import os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(path, image):
        raise OSError(f"Failed to write image to {path}")
if __name__ == '__main__':
    # Script entry point: parse --config, load configs/test.<config>.yaml
    # and hand the parsed mapping to main().
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    args = parser.parse_args()
    # Read the YAML as UTF-8 explicitly instead of relying on the
    # platform-dependent default encoding.
    with open(f'configs/test.{args.config}.yaml', encoding='utf-8') as f:
        # NOTE(review): full_load accepts more than plain YAML data types;
        # if the configs only use plain mappings/lists/scalars, consider
        # yaml.safe_load -- confirm before switching.
        cfg = yaml.full_load(f)
    main(cfg)