-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
91 lines (58 loc) · 1.95 KB
/
train.py
File metadata and controls
91 lines (58 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import copy
import os
import numpy as np
import torch
from tqdm import tqdm
from model import Net
from loader import DataLoader
from val import eval
"""
Time series analysis on the UNSW-NB15 dataset:
Deep Learning for Intrusion Detection Systems (IDSs) in Time Series Data
"""
@torch.no_grad()
def update_ema(model, ema_model, alpha=0.999):
    """Fold the current weights of ``model`` into ``ema_model``.

    Every parameter of ``ema_model`` is overwritten in place with
    ``alpha * ema_param + (1 - alpha) * param``. Parameters are paired by
    their position in ``parameters()``, so both modules are expected to
    share the same architecture.

    Args:
        model: Module supplying the freshly trained weights (not modified).
        ema_model: Module holding the running average (modified in place).
        alpha: Fraction of the previous average that is kept on each call.

    Note:
        Only ``parameters()`` are averaged; module buffers are left as-is.
    """
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        # Update the existing storage instead of swapping it out with
        # ``Tensor.set_``: no temporaries are allocated per step and any
        # alias/view of the EMA weights stays valid.
        ema_p.mul_(alpha).add_(p, alpha=1 - alpha)
# Seed PyTorch and NumPy with the same fixed value for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# ----- device: use CUDA when PyTorch reports it as available -----
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ----- model and its EMA shadow -----
model = Net(input_size=180)
# The EMA copy is taken before compile() and before any training step.
ema_model = copy.deepcopy(model)

param_count = sum(p.numel() for p in model.parameters())
print(f"Parameters: {param_count//1000}k")

# compile() is provided by Net; only its third return value is used below.
_, prog, vollo_stats = model.compile()
print(f"{vollo_stats=}")
print(model)

model = model.to(device)
ema_model = ema_model.to(device)

# ----- optimizer: AdamW with its default hyper-parameters -----
optimizer = torch.optim.AdamW(model.parameters())

# ----- data: the loader is told which device to use -----
loader = DataLoader(device=device)
# Train for 10 epochs. After every epoch the learning rate is decayed and the
# EMA weights are scored on the dev split.
# NOTE(review): the nesting below is reconstructed from a listing that had
# lost its indentation -- confirm against the original file.
for epoch in range(10):
    progress = tqdm(loader.iter("train"), leave=False, total=loader.len("train"))
    for x, y in progress:
        optimizer.zero_grad()
        probs = model(x)

        # The first column of y is attack/!attack
        labels = y[:, :, :1]
        loss = torch.nn.functional.binary_cross_entropy(probs, labels)
        loss.backward()
        optimizer.step()
        progress.set_description(f"Loss: {loss.item():.4f}")

        # Track the running average of the weights after each optimizer step.
        update_ema(model, ema_model)

    # Multiply each group's learning rate by 0.9, never going below 1e-5.
    for group in optimizer.param_groups:
        group["lr"] = max(1e-5, group["lr"] * 0.9)

    # `eval` here is the function imported from val, not the builtin.
    print(f"Dev set - epoch {epoch}:")
    dev_metrics = eval(ema_model, loader.iter("dev", drop_last=False), device)
    for name, value in dev_metrics.items():
        print(f"\t{name}: {value}")
# Report the EMA weights' metrics on the test split once training is done.
# `eval` is the function imported from val, not the builtin.
print("Test set:")
test_metrics = eval(ema_model, loader.iter("test", drop_last=False), device)
for name, value in test_metrics.items():
    print(f"\t{name}: {value}")

# Save the model
# NOTE(review): the metrics above come from ema_model, but the checkpoint
# stores the weights of `model` -- confirm this is intended.
os.makedirs("build", exist_ok=True)
state = model.state_dict()
torch.save(state, "build/model.pt")