-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
135 lines (107 loc) · 4.21 KB
/
Copy pathtrain.py
File metadata and controls
135 lines (107 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Training utilities for :class:`CausalConvLSTM`.
Implements an L2-regularised Binary Cross-Entropy objective. The L2 term
``(lambda / 2) * sum(theta_j^2)`` is realised via the optimizer's ``weight_decay``,
which is the standard, numerically-stable way to add L2 regularisation in PyTorch.
The device (CPU/GPU) is auto-detected so the code never crashes on CPU-only machines,
and the log-key vocabulary is persisted next to the weights so inference indices match
training exactly.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
import torch
import torch.nn as nn
from src.models.causal_lstm import CausalConvLSTM
# Reserved token for unknown / out-of-vocabulary log keys at inference time; build_vocab assigns it index 0.
UNK_TOKEN = "<UNK>"
def get_device() -> torch.device:
    """Pick the compute device: ``cuda`` when PyTorch reports it as available, else ``cpu``."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def build_vocab(sequences: Sequence[Sequence[str]]) -> Dict[str, int]:
    """Build a deterministic log-key -> index vocabulary.

    Index 0 is reserved for :data:`UNK_TOKEN`. All other keys are numbered
    from 1 in sorted order, so the mapping is reproducible across runs and the
    indices are contiguous (``0 .. len(vocab) - 1``).

    Args:
        sequences: Sequences of log keys to collect the vocabulary from.

    Returns:
        Mapping from log key to integer index, always containing
        :data:`UNK_TOKEN` at index 0.
    """
    # Drop a literal UNK_TOKEN found in the data: re-inserting it in the loop
    # would overwrite the reserved entry, leaving index 0 unused and pushing
    # the largest index to len(vocab), one past the end of the index range.
    keys = sorted({key for seq in sequences for key in seq} - {UNK_TOKEN})
    vocab: Dict[str, int] = {UNK_TOKEN: 0}
    vocab.update((key, index) for index, key in enumerate(keys, start=1))
    return vocab
def encode_window(window: Sequence[str], vocab: Dict[str, int]) -> List[int]:
    """Translate log keys into vocabulary indices.

    Keys missing from ``vocab`` are replaced by the index stored under
    ``<UNK>``; ``vocab`` must therefore contain that entry.
    """
    # Looked up once, before iterating, so a vocabulary without the
    # placeholder entry fails immediately even for an empty window.
    fallback_index = vocab[UNK_TOKEN]
    return [vocab[key] if key in vocab else fallback_index for key in window]
@dataclass
class TrainConfig:
    """Bundle of hyper-parameters read by the training loop."""

    # Number of passes over the full set of training windows.
    epochs: int = 30
    # Number of windows taken per mini-batch slice.
    batch_size: int = 16
    # Step size handed to the optimizer as ``lr``.
    learning_rate: float = 5e-3
    # L2 strength, handed to the optimizer as ``weight_decay``.
    l2_lambda: float = 0.0001
    # Value passed to ``torch.manual_seed`` before training starts.
    seed: int = 42
def train_model(
    model: CausalConvLSTM,
    windows: List[List[int]],
    labels: List[float],
    config: TrainConfig,
    device: torch.device | None = None,
) -> List[float]:
    """Train ``model`` with L2-regularised BCE and return per-epoch losses.

    The returned numbers are plain BCE values: the L2 term is applied by the
    optimizer through ``weight_decay`` and is not added to the reported loss.

    Args:
        model: The model to train (moved to ``device`` in-place and left in
            training mode).
        windows: List of equal-length index windows.
        labels: Binary anomaly labels (0.0 normal, 1.0 anomaly), one per window.
        config: Training hyper-parameters.
        device: Target device; auto-detected when ``None``.

    Returns:
        One value per epoch: the mean of that epoch's mini-batch BCE losses.
        When ``windows`` is empty no batch runs and every epoch reports ``0.0``.

    Raises:
        ValueError: If ``windows`` and ``labels`` differ in length, or if
            ``config.batch_size`` is not a positive number.
    """
    # Fail fast on inconsistent inputs. Without these checks extra labels are
    # silently ignored, missing labels surface as an indexing error mid-epoch,
    # and a negative batch size yields an empty batch range, i.e. a run that
    # "succeeds" without ever updating the model.
    if len(windows) != len(labels):
        raise ValueError(
            f"windows and labels must have the same length, "
            f"got {len(windows)} and {len(labels)}"
        )
    if config.batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {config.batch_size}")

    if device is None:
        device = get_device()
    torch.manual_seed(config.seed)
    model.to(device)
    model.train()

    x = torch.tensor(windows, dtype=torch.long, device=device)
    y = torch.tensor(labels, dtype=torch.float32, device=device)

    # NOTE(review): BCELoss expects probabilities with the same shape as the
    # target, so the model is assumed to return sigmoid outputs shaped like
    # ``y[idx]`` - confirm against CausalConvLSTM.
    criterion = nn.BCELoss()
    # weight_decay = lambda implements the (lambda/2) * sum(theta^2) L2 penalty:
    # Adam adds lambda * theta to each gradient, which is that term's gradient.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.l2_lambda,
    )

    n = x.shape[0]
    losses: List[float] = []
    for _ in range(config.epochs):
        # Fresh shuffle every epoch; reproducible thanks to the seed above.
        perm = torch.randperm(n, device=device)
        batch_losses: List[float] = []
        for start in range(0, n, config.batch_size):
            idx = perm[start : start + config.batch_size]
            optimizer.zero_grad()
            preds = model(x[idx])
            loss = criterion(preds, y[idx])
            loss.backward()
            optimizer.step()
            batch_losses.append(float(loss.item()))
        # max(..., 1) keeps the empty-dataset case from dividing by zero.
        losses.append(sum(batch_losses) / max(len(batch_losses), 1))
    return losses
def save_artifacts(
    model: CausalConvLSTM,
    vocab: Dict[str, int],
    out_dir: str | Path,
) -> Tuple[Path, Path]:
    """Persist model weights and the vocabulary together.

    ``out_dir`` is created (including parents) when it does not exist. The
    weights are stored as the model's ``state_dict``; the vocabulary is stored
    as UTF-8 encoded JSON.

    Args:
        model: Model whose ``state_dict`` is saved.
        vocab: Log-key -> index mapping to store next to the weights.
        out_dir: Target directory.

    Returns:
        The paths to ``model.pt`` and ``vocab.json``.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    weights_path = out / "model.pt"
    vocab_path = out / "vocab.json"
    torch.save(model.state_dict(), weights_path)
    # ensure_ascii=False emits non-ASCII keys verbatim, so the encoding is
    # pinned to UTF-8: the platform default depends on the locale and may be
    # unable to encode them or differ on the machine that loads the file.
    vocab_path.write_text(
        json.dumps(vocab, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return weights_path, vocab_path


def load_vocab(vocab_path: str | Path) -> Dict[str, int]:
    """Load a vocabulary previously saved by :func:`save_artifacts`.

    The file is read as UTF-8, matching how :func:`save_artifacts` writes it.
    Keys are coerced to ``str`` and values to ``int``.
    """
    data = json.loads(Path(vocab_path).read_text(encoding="utf-8"))
    return {str(k): int(v) for k, v in data.items()}