4 changes: 4 additions & 0 deletions core/layer/AbstractProbLayer.py
@@ -12,6 +12,8 @@ class AbstractProbLayer(nn.Module, ABC):
    _bias_dist: AbstractVariable
    _prior_weight_dist: AbstractVariable
    _prior_bias_dist: AbstractVariable
    _sampled_weight: Tensor
    _sampled_bias: Tensor

    def probabilistic(self, mode: bool = True):
        if not isinstance(mode, bool):
@@ -32,4 +34,6 @@ def sample_from_distribution(self) -> Tuple[Tensor, Tensor]:
            sampled_bias = self._bias_dist.mu if self._bias_dist else None
        else:
            raise ValueError('Only training with probabilistic mode is allowed')
        self._sampled_weight = sampled_weight
        self._sampled_bias = sampled_bias
        return sampled_weight, sampled_bias
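
Context for the two cached tensors: objectives such as the IWAE one below must evaluate log p(w) and log q(w) at exactly the weights drawn during the forward pass, hence the sample is stored on the layer. A minimal self-contained sketch of that pattern (the tensors are illustrative, not code from this PR):

import torch
import torch.distributions as dists

mu, sigma = torch.zeros(3), torch.ones(3)
q = dists.Normal(mu, sigma)
w = q.rsample()                # reparameterised sample, as in a probabilistic forward pass
log_q_w = q.log_prob(w).sum()  # evaluating log q(w) requires keeping the sample w around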
14 changes: 14 additions & 0 deletions core/objective/AbstractSamplingObjective.py
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from typing import List

from torch import Tensor


class AbstractSamplingObjective(ABC):
    def __init__(self, kl_penalty: float, n: int) -> None:
        self._kl_penalty = kl_penalty
        self.n: int = n

    @abstractmethod
    def calculate(self, losses: List[Tensor], kl: Tensor, num_samples: float) -> Tensor:
        pass
112 changes: 112 additions & 0 deletions core/objective/IWAEObjective.py
@@ -0,0 +1,112 @@
import logging
import math
from typing import Dict, Optional

import torch
import torch.distributions as dists
import wandb
from torch import nn, Tensor

from core.model import bounded_call
from core.layer.utils import get_torch_layers


class IWAEObjective:
    def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None:
        self.k = n                      # number of importance samples per batch
        self.kl_penalty = kl_penalty    # usually 1 / |D|
        self.temperature = temperature

    # -------- helpers to compute log p(w) and log q(w) -------------------
    @staticmethod
    def _log_prior(model: nn.Module, eps: float = 1e-6) -> Tensor:
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        s = torch.zeros(1, device=device, dtype=dtype)

        for _, layer in get_torch_layers(model):
            s += dists.Normal(layer._prior_weight_dist.mu,
                              layer._prior_weight_dist.sigma + eps
                              ).log_prob(layer._sampled_weight).sum()
            s += dists.Normal(layer._prior_bias_dist.mu,
                              layer._prior_bias_dist.sigma + eps
                              ).log_prob(layer._sampled_bias).sum()
        return s

    @staticmethod
    def _log_post(model: nn.Module, eps: float = 1e-6) -> Tensor:
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        s = torch.zeros(1, device=device, dtype=dtype)

        for _, layer in get_torch_layers(model):
            s += dists.Normal(layer._weight_dist.mu,
                              layer._weight_dist.sigma + eps
                              ).log_prob(layer._sampled_weight).sum()
            s += dists.Normal(layer._bias_dist.mu,
                              layer._bias_dist.sigma + eps
                              ).log_prob(layer._sampled_bias).sum()
        return s

    # --------------------------------------------------------------------
    def calculate(
        self,
        model: nn.Module,
        data: Tensor,
        target: Tensor,
        epoch: int,
        batch_idx: int,
        dataset_size: int,
        pmin: Optional[float] = None,
        wandb_params: Optional[Dict] = None,
    ) -> Tensor:
        batch_size = data.size(0)
        scale = dataset_size / batch_size   # N / |B|
        log_ws = []                         # k scalar log-weights

        kl_pen = self.kl_penalty            # set in the constructor; typically 1 / |D|
        temp = self.temperature

        for sample_idx in range(self.k):
            # sample w and compute log p(x|w)
            logits = bounded_call(model, data, pmin) if pmin is not None else model(data)

            if torch.isnan(logits).any() or torch.isinf(logits).any():
                logging.warning(f"NaN/Inf in logits at epoch {epoch}, batch {batch_idx}")
                logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits))

            log_px = dists.Categorical(logits=logits).log_prob(target)   # (batch,)
            log_lik = scale * log_px.sum()                               # scalar

            # log p(w) - log q(w): a one-sample estimate of -KL(q || p), scaled
            kl = (self._log_prior(model) - self._log_post(model)) * kl_pen
            log_w = log_lik + temp * kl   # scalar
            log_ws.append(log_w)

            # -------------------- per-sample logging --------------------
            if wandb_params and wandb_params.get("log_wandb", False):
                tag = wandb_params["name_wandb"]
                wandb.log({
                    f"{tag}/epoch": epoch,
                    f"{tag}/batch": batch_idx,
                    f"{tag}/sample": sample_idx,
                    f"{tag}/log_likelihood": log_lik.detach(),
                    f"{tag}/kl": kl.detach(),
                    f"{tag}/log_weight": log_w.detach(),
                })

        # ----------- PB-IWAE loss (one scalar) ---------------------------
        log_ws_tensor = torch.stack(log_ws)   # (k,)
        loss = -(torch.logsumexp(log_ws_tensor, dim=0) - math.log(self.k))

        # ----------- final logging ---------------------------------------
        if wandb_params and wandb_params.get("log_wandb", False):
            wandb.log({f"{wandb_params['name_wandb']}/iwae_loss": loss.detach()})

        if batch_idx == 0:
            logging.info(
                f"[Epoch {epoch:03d} | Batch {batch_idx:04d}] "
                f"IWAE-loss {loss.item():.4f} "
                f"| mean log_px {log_px.mean().item():.4f} "
                f"| KL {kl.item():.2f}"
            )
        return loss
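
The returned value is the negative k-sample IWAE bound, -log((1/k) * sum_i exp(log_w_i)), computed via logsumexp for numerical stability. A self-contained numeric sketch (the three log-weights are made-up values):

import math
import torch

log_ws = torch.tensor([-2.3, -1.9, -2.8])   # k = 3 hypothetical log-weights
loss = -(torch.logsumexp(log_ws, dim=0) - math.log(log_ws.numel()))
# equal to -torch.log(torch.exp(log_ws).mean()), but stable for large-magnitude weights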
12 changes: 12 additions & 0 deletions core/objective/NaiveIWAEObjective.py
@@ -0,0 +1,12 @@
from typing import List

from torch import Tensor

from core.objective import AbstractSamplingObjective


class NaiveIWAEObjective(AbstractSamplingObjective):
    def calculate(self, losses: List[Tensor], kl: Tensor, num_samples: float) -> Tensor:
        assert self.n == len(losses)
        loss = sum(losses) / self.n
        return loss + self._kl_penalty * (kl / num_samples)
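
A usage sketch with placeholder values (the constructor arguments, losses, and KL below are made up): the result is mean(losses) + kl_penalty * kl / num_samples. Note that despite the name, this variant averages the per-sample losses rather than importance-weighting them.

import torch
from core.objective import NaiveIWAEObjective

obj = NaiveIWAEObjective(kl_penalty=1.0, n=3)
losses = [torch.tensor(0.9), torch.tensor(1.1), torch.tensor(1.0)]
kl = torch.tensor(120.0)
value = obj.calculate(losses, kl, num_samples=60000.0)
# value == 1.0 + 1.0 * 120.0 / 60000.0 == 1.002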
3 changes: 3 additions & 0 deletions core/objective/__init__.py
@@ -1,6 +1,9 @@
from core.objective.AbstractObjective import AbstractObjective
from core.objective.AbstractSamplingObjective import AbstractSamplingObjective
from core.objective.BBBObjective import BBBObjective
from core.objective.FQuadObjective import FQuadObjective
from core.objective.FClassicObjective import FClassicObjective
from core.objective.McAllisterObjective import McAllisterObjective
from core.objective.TolstikhinObjective import TolstikhinObjective
from core.objective.NaiveIWAEObjective import NaiveIWAEObjective
from core.objective.IWAEObjective import IWAEObjective
57 changes: 47 additions & 10 deletions core/training.py
@@ -9,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter

from core.distribution.utils import compute_kl, DistributionT
from core.objective import AbstractObjective, AbstractSamplingObjective, IWAEObjective
from core.model import bounded_call


@@ -24,24 +24,61 @@ def train(model: nn.Module,
          wandb_params: Dict = None,
          ):
    criterion = torch.nn.NLLLoss()
    # Previous default, kept for reference:
    # optimizer = torch.optim.SGD(model.parameters(),
    #                             lr=parameters['lr'],
    #                             momentum=parameters['momentum'])

    # Experimental: Adam instead of SGD, to compare convergence
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=parameters['lr'])
    if 'seed' in parameters:
        torch.manual_seed(parameters['seed'])

    loss = None
    kl = None
    objective_value = None

    dataset_size = len(train_loader.dataset)   # needed for the IWAE N/|B| scaling

    for epoch in range(parameters['epochs']):
        for i, (data, target) in tqdm(enumerate(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            if isinstance(objective, AbstractObjective):
                if 'pmin' in parameters:
                    output = bounded_call(model, data, parameters['pmin'])
                else:
                    output = model(data)
                loss = criterion(output, target)
                kl = compute_kl(posterior, prior)
                objective_value = objective.calculate(loss, kl, parameters['num_samples'])
            elif isinstance(objective, AbstractSamplingObjective):
                losses = []
                for _ in range(objective.n):
                    if 'pmin' in parameters:
                        output = bounded_call(model, data, parameters['pmin'])
                    else:
                        output = model(data)
                    losses.append(criterion(output, target))
                kl = compute_kl(posterior, prior)
                objective_value = objective.calculate(losses, kl, parameters['num_samples'])
                loss = sum(losses) / objective.n
            elif isinstance(objective, IWAEObjective):
                objective_value = objective.calculate(model,
                                                      data,
                                                      target,
                                                      epoch=epoch,
                                                      batch_idx=i,
                                                      dataset_size=dataset_size,
                                                      pmin=parameters.get('pmin', None),
                                                      wandb_params=wandb_params)
                with torch.no_grad():
                    loss = criterion(model(data), target)
                kl = compute_kl(posterior, prior)
            else:
                raise ValueError(f'Invalid objective type: {type(objective)}')
            objective_value.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            logging.info(f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl/parameters['num_samples']}")
            if wandb_params is not None and wandb_params["log_wandb"]: