-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
66 lines (53 loc) · 1.94 KB
/
evaluate.py
File metadata and controls
66 lines (53 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import cv2
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from pathlib import Path
from os import listdir
from os.path import isfile, join
from model import get_generator
from util import arguments
from util.transforms import ToNumpyRGB256
from dataset.InpaintDataset import SCDataset
from util.metrics import PSNR, SSIM
torch.backends.cudnn.benchmark = True


def _list_image_stems(directory, stride=50):
    """Return stems of the regular files in *directory*, sorted, keeping every *stride*-th.

    ``os.listdir`` yields entries in arbitrary, filesystem-dependent order, so
    the names are sorted before subsampling; otherwise ``[::stride]`` would
    select a different evaluation subset on different machines.

    Args:
        directory: Directory (str or Path) to scan; sub-directories are skipped.
        stride: Keep every ``stride``-th file of the sorted listing.

    Returns:
        List of file names without their extension.
    """
    names = sorted(f for f in listdir(directory) if isfile(join(directory, f)))
    return [Path(f).stem for f in names][::stride]


def _to_pixel_range(t):
    """Squeeze singleton dims and map values from [-1, 1] to [0, 255] (stays float)."""
    return (t.squeeze() + 1) * 255 / 2


if __name__ == "__main__":
    args = arguments.parse_arguments()
    args.data = Path(args.data)

    net_G = get_generator(args)
    # Load onto the CPU first so a checkpoint saved from any GPU index loads
    # cleanly; the model is moved to the current CUDA device right after.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    net_G.load_state_dict(torch.load(args.load_G, map_location="cpu"))
    net_G = net_G.cuda().eval()

    imgpath = args.data / "images_256"
    collpath = args.data / "collages"
    respath = args.data / "results"
    # Output directories are created here but nothing below writes to them.
    collpath.mkdir(exist_ok=True)
    respath.mkdir(exist_ok=True)

    files = _list_image_stems(imgpath)
    dataset = SCDataset(args.data, files)

    # Running means: each per-sample value is divided by n_elems as it is added.
    psnr = 0.0
    ssim = 0.0
    l2 = 0.0
    P = PSNR()
    S = SSIM()
    L = torch.nn.MSELoss()
    n_elems = len(dataset)

    # Pure inference: disable autograd so no graph is built per sample.
    with torch.no_grad():
        for item in tqdm(dataset):
            # Add a batch dimension of 1 and move every input to the GPU.
            image, colormap, sketch, mask = (
                item["image"].unsqueeze(0).cuda(),
                item["colormap"].unsqueeze(0).cuda(),
                item["sketch"].unsqueeze(0).cuda(),
                item["mask"].unsqueeze(0).cuda(),
            )
            # Where mask is 1 the network sees the real image; where it is 0 it
            # sees the colour map and sketch hints instead. Concatenated along
            # dim 1 (presumably the channel axis, NCHW -- confirm).
            generator_input = torch.cat(
                (image * mask, colormap * (1 - mask), sketch * (1 - mask), mask), dim=1
            )
            _coarse_image, refined_image = net_G(generator_input)
            # Keep the original pixels where mask is 1; use the refined output elsewhere.
            completed_image = refined_image * (1 - mask) + image * mask

            metrics_input = _to_pixel_range(completed_image), _to_pixel_range(image)
            psnr += (P(*metrics_input) / n_elems).item()
            ssim += (S(*metrics_input) / n_elems).item()
            # NOTE(review): MSE is taken on the [-1, 1] tensors and halved (and
            # scaled by 100 when printed); this looks like a reporting
            # convention -- confirm against the reference numbers being matched.
            l2 += L(completed_image, image).item() / n_elems / 2

    print("PSNR:", np.round(psnr, 4))
    print("SSIM:", np.round(ssim, 4))
    print("L2: ", np.round(l2 * 100, 4))