-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
214 lines (194 loc) · 6.38 KB
/
main.py
File metadata and controls
214 lines (194 loc) · 6.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import argparse
import torch as T
import torch.optim as optim
import torch.nn.functional as F
from DiffusionTest import train
from DiffusionTest.model import SmallUnetWithEmb
from DiffusionTest.diffusion import Diffusion
from DiffusionTest.loader import get_loaders
from DiffusionTest.utils import save_images
def parse_args(argv=None):
    """Build the command-line interface and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing a list lets tests or
            other scripts drive the CLI programmatically.

    Returns:
        argparse.Namespace with the attributes ``device``, ``training``,
        ``saving_path``, ``version``, ``dataset``, ``img_size``,
        ``batch_size``, ``start_epoch``, ``epochs``, ``save_every_n``,
        ``learning_rate`` and ``conditional``.

    Raises:
        SystemExit: raised by argparse on invalid arguments (e.g. a dataset
            outside the allowed choices) or when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(
        description="Train a diffusion model or generate images."
    )

    # Arguments related to device and training mode
    parser.add_argument(
        "--device",
        type=str,
        # Evaluated once, when the parser is built.
        default="cuda" if T.cuda.is_available() else "cpu",
        help="Device to run the model on, 'cuda' or 'cpu'. Default is based on availability.",
    )
    parser.add_argument(
        "--training",
        action="store_true",
        default=False,
        help="Flag to train the model. If not set, images will be generated instead.",
    )

    # Arguments related to saving, dataset and model configurations
    parser.add_argument(
        "--saving_path",
        type=str,
        default="/teamspace/studios/this_studio/runs",
        help="Path where model checkpoints and images are saved.",
    )
    parser.add_argument(
        "--version",
        type=int,
        default=12,
        help="Version number for saving models and images.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["celeb", "bridge", "cifar10", "fashion"],
        default="celeb",
        help="Dataset to use. Default is 'celeb'.",
    )
    parser.add_argument(
        "--img_size",
        type=int,
        default=64,
        help="Image size to use for training and generation. Default is 64x64.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="Batch size for the dataloader. Default is 64.",
    )

    # Arguments related to training specifics
    parser.add_argument(
        "--start_epoch",
        type=int,
        default=0,
        help="The starting epoch number for resuming training. Default is 0.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1001,
        help="Maximum number of epochs for training. Default is 1001.",
    )
    parser.add_argument(
        "--save_every_n",
        type=int,
        default=2,
        help="Frequency (in epochs) to save the model checkpoint and generate images. Default is every 2 epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate for the optimizer. Default is 1e-4.",
    )
    parser.add_argument(
        "--conditional",
        action="store_true",
        default=False,
        help="Flag to use conditional generation. Default is False.",
    )

    # argv=None -> argparse falls back to sys.argv[1:].
    return parser.parse_args(argv)
# Human-readable label names, indexed by class id, for the datasets that
# support class-conditional generation. The number of classes the model is
# conditioned on is taken from the length of the matching list.
_CLASS_NAMES = {
    "fashion": [
        "T-shirt/top",  # 0
        "Trouser",  # 1
        "Pullover",  # 2
        "Dress",  # 3
        "Coat",  # 4
        "Sandal",  # 5
        "Shirt",  # 6
        "Sneaker",  # 7
        "Bag",  # 8
        "Ankle boot",  # 9
    ],
    "cifar10": [
        "Airplane",  # 0
        "Automobile",  # 1
        "Bird",  # 2
        "Cat",  # 3
        "Deer",  # 4
        "Dog",  # 5
        "Frog",  # 6
        "Horse",  # 7
        "Ship",  # 8
        "Truck",  # 9
    ],
}


def main():
    """Parse the CLI, build the model, then either train it or sample images.

    With ``--training`` the train/validation loaders for the chosen dataset
    are built and everything is handed to ``train``. Otherwise the model
    (restored from ``epoch_{start_epoch - 1}.pt`` when ``--start_epoch`` > 0)
    is used to generate images, which are written with ``save_images``.
    """
    args = parse_args()
    device = T.device(args.device)
    print(f"Using device: {device}")
    print(f"Mode is {'training' if args.training else 'generating'}")

    saving_path = args.saving_path
    training = args.training
    version = args.version
    ds = args.dataset
    img_size = args.img_size
    batch_size = args.batch_size
    start_epoch = args.start_epoch
    epochs = args.epochs
    conditional = args.conditional
    save_every_n = args.save_every_n
    lr = args.learning_rate

    # "fashion" uses single-channel images; every other dataset uses 3 channels.
    chans = 1 if ds == "fashion" else 3

    class_names = None
    n_classes = None
    if conditional:
        class_names = _CLASS_NAMES.get(ds)
        if class_names is None:
            # No label set for this dataset: fall back to unconditional mode,
            # but say so instead of silently ignoring the flag.
            print(
                f"Warning: --conditional is not supported for dataset '{ds}'; "
                "running unconditionally."
            )
            conditional = False
        else:
            n_classes = len(class_names)

    model = SmallUnetWithEmb(img_channels=chans, n_classes=n_classes).to(device)
    if start_epoch > 0:
        # Load the weights saved for epoch start_epoch - 1 of this version.
        # NOTE: T.load unpickles the file -- only load checkpoints you trust.
        checkpoint_path = (
            f"{saving_path}/version{version}/models/epoch_{start_epoch - 1}.pt"
        )
        model.load_state_dict(T.load(checkpoint_path, map_location=device))
    print("Num params: ", sum(p.numel() for p in model.parameters()))

    diff = Diffusion(device=device, img_channels=chans, img_size=img_size)

    # NOTE(review): neither save_images nor train receives saving_path; confirm
    # they write under the same root the checkpoint above is loaded from.
    if not training:
        print("Generating images")
        if start_epoch == 0:
            print(
                "Warning: no checkpoint loaded (start_epoch=0); "
                "sampling from an untrained model."
            )
        if n_classes is not None:
            print("Generating images with conditional generation")
            # Two samples per class, grouped by class id: [0, 0, 1, 1, ...].
            classes = T.arange(n_classes, device=device).repeat_interleave(2)
            imgs = diff.generate_sample(
                model, n_images=len(classes), labels=classes
            )
            sample_names = [class_names[cls] for cls in classes.cpu().tolist()]
            save_images(
                imgs,
                version=version,
                epoch_n=start_epoch,
                classes_list=sample_names,
            )
        else:
            print("Generating images without conditional generation")
            imgs = diff.generate_sample(model, n_images=10)
            save_images(imgs, version=version, epoch_n=start_epoch)
    else:
        print("Training model")
        print(f"Using dataset: {ds} with image size: {img_size}")
        print(f"Using batch size: {batch_size}")
        print(f"Is conditional: {conditional}")
        print(f"Number of classes: {n_classes}")
        # NOTE(review): the optimizer starts fresh even when resuming from a
        # checkpoint (only model weights are restored above).
        optimizer = optim.Adam(model.parameters(), lr=lr)
        train_loader, val_loader = get_loaders(ds, batch_size, img_size)
        train(
            model=model,
            diff=diff,
            optimizer=optimizer,
            criterion=F.mse_loss,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            apply_ema=False,
            epochs=epochs,
            version=version,
            start_epoch=start_epoch,
            n_classes=n_classes,
            save_every_n=save_every_n,
        )


if __name__ == "__main__":
    main()