-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathengine.py
More file actions
executable file
·116 lines (79 loc) · 3.38 KB
/
engine.py
File metadata and controls
executable file
·116 lines (79 loc) · 3.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import lightning.pytorch as pl
import torch
class FeatureExtractor(pl.LightningModule):
    """Lightning wrapper that trains, validates and tests a classification model.

    The module is deliberately thin: the network, loss, optimizer and LR
    scheduler are all injected, so this class only wires them into the
    Lightning train/val/test loops and handles metric logging.

    Args:
        model: the ``torch.nn.Module`` to train (expects MNIST-shaped input,
            see ``example_input_array``).
        loss_function: callable ``(y_hat, y) -> loss`` (e.g. CrossEntropyLoss).
        optimizer: an already-constructed optimizer bound to ``model``'s params.
        scheduler: an LR scheduler; must expose an ``update`` attribute
            ("step" or "epoch") consumed by ``configure_optimizers``.
    """

    def __init__(self, model, loss_function, optimizer, scheduler):
        super().__init__()
        # ⚡ model
        self.model = model
        print(self.model)
        # ⚡ loss
        self.loss_function = loss_function
        # ⚡ optimizer
        self.optimizer = optimizer
        # ⚡ scheduler
        self.scheduler = scheduler  # **kwargs: **config['scheduler_config']
        # save hyperparameters (skip the model itself — it can be huge)
        self.save_hyperparameters(ignore=['model'])
        # ⚡⚡⚡ debugging - print input output layer ⚡⚡⚡
        # torch.zeros (not torch.Tensor) so the dummy input is deterministic;
        # torch.Tensor(...) returns *uninitialized* memory that may contain
        # NaN/inf and corrupt the model-summary forward trace.
        # NOTE(review): assumes (batch=64, 1, 28, 28) MNIST-like input — confirm.
        self.example_input_array = torch.zeros(64, 1, 28, 28)
        # per-epoch accumulators for manual metric aggregation
        self.training_step_outputs = []  # not used, but kept for future implementation
        self.validation_step_outputs = []

    # ===============================================================
    # ⚡⚡ Train
    # ===============================================================
    def training_step(self, batch, batch_idx):
        """Run one optimization step and log the training loss."""
        x, y = batch
        # preprocess
        # inference
        y_hat = self.model(x)
        # post processing
        # calculate loss
        loss = self.loss_function(y_hat, y)
        self.training_step_outputs.append(loss)
        # Logging to TensorBoard
        self.log("loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self):
        self.training_step_outputs.clear()  # free memory

    # ===============================================================
    # ⚡⚡ Validation
    # ===============================================================
    def validation_step(self, batch, batch_idx):
        """Compute val loss and stash correct/total counts for epoch accuracy."""
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_function(y_hat, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        # argmax over the class dimension; assumes y_hat is (batch, classes)
        correct = (y_hat.argmax(1) == y).type(torch.float).sum().item()
        size = x.shape[0]
        validation_step_output = {'correct': correct, 'size': size}
        self.validation_step_outputs.append(validation_step_output)
        return validation_step_output

    def on_validation_epoch_end(self):
        """Aggregate per-step counts into epoch accuracy and log it."""
        correct_score = sum(dic['correct'] for dic in self.validation_step_outputs)
        total_size = sum(dic['size'] for dic in self.validation_step_outputs)
        # guard against an empty epoch (e.g. limit_val_batches=0 or an empty
        # test loader via on_test_epoch_end) — avoids ZeroDivisionError
        if total_size > 0:
            acc = correct_score / total_size
            self.log("val_ACC", acc * 100, on_epoch=True, prog_bar=True, sync_dist=True)
        self.validation_step_outputs.clear()  # free memory

    # ===============================================================
    # ⚡⚡ test
    # ===============================================================
    def test_step(self, batch, batch_idx):
        # test reuses the validation logic (same loss/accuracy bookkeeping)
        return self.validation_step(batch, batch_idx)

    def on_test_epoch_end(self):
        return self.on_validation_epoch_end()

    def forward(self, x):
        """Plain inference pass; also used by Lightning for the example-input trace."""
        y_hat = self.model(x)
        return y_hat

    def configure_optimizers(self):
        """Return the injected optimizer plus its scheduler config for Lightning."""
        return {
            "optimizer": self.optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler,
                # ⚡⚡ "step" or "epoch" — project convention: the scheduler
                # object carries its own update cadence in `.update`
                'interval': self.scheduler.update,
                "monitor": "val_loss",
                "frequency": 1,
                "name": "lr_log",
            },
        }