diff --git a/core/layer/AbstractProbLayer.py b/core/layer/AbstractProbLayer.py
index b33b092..6b819dd 100644
--- a/core/layer/AbstractProbLayer.py
+++ b/core/layer/AbstractProbLayer.py
@@ -12,6 +12,8 @@ class AbstractProbLayer(nn.Module, ABC):
     _bias_dist: AbstractVariable
     _prior_weight_dist: AbstractVariable
     _prior_bias_dist: AbstractVariable
+    _sampled_weight: Tensor
+    _sampled_bias: Tensor
 
     def probabilistic(self, mode: bool = True):
         if not isinstance(mode, bool):
@@ -32,4 +34,6 @@ def sample_from_distribution(self) -> Tuple[Tensor, Tensor]:
             sampled_bias = self._bias_dist.mu if self._bias_dist else None
         else:
             raise ValueError('Only training with probabilistic mode is allowed')
+        self._sampled_weight = sampled_weight
+        self._sampled_bias = sampled_bias
         return sampled_weight, sampled_bias
diff --git a/core/objective/AbstractSamplingObjective.py b/core/objective/AbstractSamplingObjective.py
new file mode 100644
index 0000000..72ba065
--- /dev/null
+++ b/core/objective/AbstractSamplingObjective.py
@@ -0,0 +1,14 @@
+from torch import Tensor
+from typing import List
+
+from abc import ABC, abstractmethod
+
+
+class AbstractSamplingObjective(ABC):
+    def __init__(self, kl_penalty: float, n: int) -> None:
+        self._kl_penalty = kl_penalty
+        self.n: int = n
+
+    @abstractmethod
+    def calculate(self, losses: List[Tensor], kl: Tensor, num_samples: float) -> Tensor:
+        pass
diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py
new file mode 100644
index 0000000..034faa1
--- /dev/null
+++ b/core/objective/IWAEObjective.py
@@ -0,0 +1,112 @@
+import logging, math
+from typing import Dict, Optional
+
+import torch, torch.distributions as dists, wandb
+from torch import nn, Tensor
+
+from core.model import bounded_call
+from core.layer.utils import get_torch_layers
+
+
+class IWAEObjective:
+    def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None:
+        self.k = n
+        self.kl_penalty = kl_penalty  # usually 1 / |D|
+        self.temperature = temperature
+
+    # -------- helpers to compute log p(w) and log q(w) -------------------
+    @staticmethod
+    def _log_prior(model: nn.Module, eps: float = 1e-6) -> Tensor:
+        device = next(model.parameters()).device
+        dtype = next(model.parameters()).dtype
+        s = torch.zeros(1, device=device, dtype=dtype)
+
+        for _, l in get_torch_layers(model):
+            s += dists.Normal(l._prior_weight_dist.mu,
+                              l._prior_weight_dist.sigma + eps
+                              ).log_prob(l._sampled_weight).sum()
+            s += dists.Normal(l._prior_bias_dist.mu,
+                              l._prior_bias_dist.sigma + eps
+                              ).log_prob(l._sampled_bias).sum()
+        return s
+
+    @staticmethod
+    def _log_post(model: nn.Module, eps: float = 1e-6) -> Tensor:
+        device = next(model.parameters()).device
+        dtype = next(model.parameters()).dtype
+        s = torch.zeros(1, device=device, dtype=dtype)
+
+        for _, l in get_torch_layers(model):
+            s += dists.Normal(l._weight_dist.mu,
+                              l._weight_dist.sigma + eps
+                              ).log_prob(l._sampled_weight).sum()
+            s += dists.Normal(l._bias_dist.mu,
+                              l._bias_dist.sigma + eps
+                              ).log_prob(l._sampled_bias).sum()
+        return s
+
+    # --------------------------------------------------------------------
+    def calculate(
+        self,
+        model: nn.Module,
+        data: Tensor,
+        target: Tensor,
+        epoch: int,
+        batch_idx: int,
+        dataset_size: int,
+        pmin: Optional[float] = None,
+        wandb_params: Optional[Dict] = None,
+    ) -> Tensor:
+
+
+        batch_size = data.size(0)
+        scale = dataset_size / batch_size  # N / |B|
+        log_ws = []  # list[k] of scalars
+
+        kl_pen = self.kl_penalty    # use the configured penalty rather than shadowing it with 1 / dataset_size
+        temp = self.temperature     # use the configured temperature rather than a hard-coded 1.0
+
+        for l in range(self.k):
+            # each forward pass re-samples w, so this computes log p(x | w_l)
+            logits = bounded_call(model, data, pmin) if pmin is not None else model(data)
+
+            if torch.isnan(logits).any() or torch.isinf(logits).any():
+                logging.warning(f"NaN/Inf in logits at epoch {epoch}, batch {batch_idx}")
+                logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits))
+
+            log_px = dists.Categorical(logits=logits).log_prob(target)  # (batch,)
+            log_lik = scale * log_px.sum()  # scalar
+
+            # global KL part
+            kl = (self._log_prior(model) - self._log_post(model)) * kl_pen
+            log_w = log_lik + temp * kl  # scalar
+            log_ws.append(log_w)
+
+            # -------------------- per-sample logging --------------------
+            if wandb_params and wandb_params.get("log_wandb", False):
+                tag = wandb_params["name_wandb"]
+                wandb.log({
+                    f"{tag}/epoch": epoch,
+                    f"{tag}/batch": batch_idx,
+                    f"{tag}/sample": l,
+                    f"{tag}/log_likelihood": log_lik.detach(),
+                    f"{tag}/kl": kl.detach(),
+                    f"{tag}/log_weight": log_w.detach(),
+                })
+
+        # ----------- PB-IWAE loss (one scalar) ---------------------------
+        log_ws_tensor = torch.stack(log_ws)  # (k,)
+        loss = -(torch.logsumexp(log_ws_tensor, dim=0) - math.log(self.k))
+
+        # ----------- final logging --------------------------------------
+        if wandb_params and wandb_params.get("log_wandb", False):
+            wandb.log({f"{wandb_params['name_wandb']}/iwae_loss": loss})
+
+        if batch_idx == 0:
+            logging.info(
+                f"[Epoch {epoch:03d} | Batch {batch_idx:04d}] "
+                f"IWAE-loss {loss.item():.4f} "
+                f"| mean log_px {(log_px.mean()).item():.4f} "
+                f"| KL {kl.item():.2f}"
+            )
+        return loss
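# ---------------------------------------------------------------------------
# A standalone sanity check (illustration only, not part of the diff): the
# loss returned above, -(logsumexp(log_ws) - log k), is exactly
# -log( (1/k) * sum_l exp(log_w_l) ), i.e. the negative log of the average
# importance weight. A quick numerical verification:
import math
import torch

log_ws = torch.tensor([-2.3, -1.7, -2.0])
k = log_ws.numel()
loss = -(torch.logsumexp(log_ws, dim=0) - math.log(k))
reference = -torch.log(torch.exp(log_ws).mean())
assert torch.allclose(loss, reference)
# ---------------------------------------------------------------------------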
diff --git a/core/objective/NaiveIWAEObjective.py b/core/objective/NaiveIWAEObjective.py
new file mode 100644
index 0000000..67ed381
--- /dev/null
+++ b/core/objective/NaiveIWAEObjective.py
@@ -0,0 +1,12 @@
+import torch
+from torch import Tensor
+from typing import List
+
+from core.objective import AbstractSamplingObjective
+
+
+class NaiveIWAEObjective(AbstractSamplingObjective):
+    def calculate(self, losses: List[Tensor], kl: Tensor, num_samples: float) -> Tensor:
+        assert self.n == len(losses)
+        loss = sum(losses) / self.n
+        return loss + self._kl_penalty * (kl / num_samples)
diff --git a/core/objective/__init__.py b/core/objective/__init__.py
index 1287f39..29dad9c 100644
--- a/core/objective/__init__.py
+++ b/core/objective/__init__.py
@@ -1,6 +1,9 @@
 from core.objective.AbstractObjective import AbstractObjective
+from core.objective.AbstractSamplingObjective import AbstractSamplingObjective
 from core.objective.BBBObjective import BBBObjective
 from core.objective.FQuadObjective import FQuadObjective
 from core.objective.FClassicObjective import FClassicObjective
 from core.objective.McAllisterObjective import McAllisterObjective
 from core.objective.TolstikhinObjective import TolstikhinObjective
+from core.objective.NaiveIWAEObjective import NaiveIWAEObjective
+from core.objective.IWAEObjective import IWAEObjective
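# ---------------------------------------------------------------------------
# Usage sketch (illustration only; assumes core.objective is importable): the
# "naive" variant above simply averages the k re-sampled losses and adds the
# penalised KL term, with no importance weighting.
import torch
from core.objective import NaiveIWAEObjective

objective = NaiveIWAEObjective(kl_penalty=0.001, n=3)
losses = [torch.tensor(0.9), torch.tensor(1.1), torch.tensor(1.0)]
value = objective.calculate(losses, kl=torch.tensor(50.0), num_samples=10_000)
# value == mean(losses) + 0.001 * 50 / 10_000 == 1.000005
# ---------------------------------------------------------------------------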
diff --git a/core/training.py b/core/training.py
index cae3ecb..8619650 100644
--- a/core/training.py
+++ b/core/training.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter
 
 from core.distribution.utils import compute_kl, DistributionT
-from core.objective import AbstractObjective
+from core.objective import AbstractObjective, AbstractSamplingObjective, IWAEObjective
 from core.model import bounded_call
 
 
@@ -24,24 +24,61 @@ def train(model: nn.Module,
           wandb_params: Dict = None,
           ):
     criterion = torch.nn.NLLLoss()
-    optimizer = torch.optim.SGD(model.parameters(),
-                                lr=parameters['lr'],
-                                momentum=parameters['momentum'])
+    # optimizer = torch.optim.SGD(model.parameters(),
+    #                             lr=parameters['lr'],
+    #                             momentum=parameters['momentum'])
+    # Adam used experimentally in place of SGD
+    optimizer = torch.optim.Adam(model.parameters(),
+                                 lr=parameters['lr'])
     if 'seed' in parameters:
         torch.manual_seed(parameters['seed'])
+
+    loss = None
+    kl = None
+    objective_value = None
+
+    dataset_size = len(train_loader.dataset)  # total training-set size, needed by IWAEObjective
+
     for epoch in range(parameters['epochs']):
         for i, (data, target) in tqdm(enumerate(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
-            if 'pmin' in parameters:
-                output = bounded_call(model, data, parameters['pmin'])
+            if isinstance(objective, AbstractObjective):
+                if 'pmin' in parameters:
+                    output = bounded_call(model, data, parameters['pmin'])
+                else:
+                    output = model(data)
+                loss = criterion(output, target)
+                kl = compute_kl(posterior, prior)
+                objective_value = objective.calculate(loss, kl, parameters['num_samples'])
+            elif isinstance(objective, AbstractSamplingObjective):
+                losses = []
+                for j in range(objective.n):
+                    if 'pmin' in parameters:
+                        output = bounded_call(model, data, parameters['pmin'])
+                    else:
+                        output = model(data)
+                    losses.append(criterion(output, target))
+                kl = compute_kl(posterior, prior)
+                objective_value = objective.calculate(losses, kl, parameters['num_samples'])
+                loss = sum(losses) / objective.n
+            elif isinstance(objective, IWAEObjective):
+                objective_value = objective.calculate(model,
+                                                      data,
+                                                      target,
+                                                      epoch=epoch,
+                                                      batch_idx=i,
+                                                      dataset_size=dataset_size,
+                                                      pmin=parameters.get('pmin', None),
+                                                      wandb_params=wandb_params)
+                with torch.no_grad():
+                    loss = criterion(model(data), target)
+                    kl = compute_kl(posterior, prior)
             else:
-                output = model(data)
-            kl = compute_kl(posterior, prior)
-            loss = criterion(output, target)
-            objective_value = objective.calculate(loss, kl, parameters['num_samples'])
+                raise ValueError(f'Invalid objective type: {type(objective)}')
             objective_value.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
             optimizer.step()
             logging.info(f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl/parameters['num_samples']}")
             if wandb_params is not None and wandb_params["log_wandb"]:
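# ---------------------------------------------------------------------------
# Toy illustration (hypothetical stand-ins, not part of the diff): in
# probabilistic mode every forward pass re-samples the weights, so the n
# losses collected in the sampling branch above genuinely differ.
import torch

def model(x):            # stand-in for a stochastic forward pass
    return x + torch.randn(())

def criterion(out, tgt):
    return (out - tgt).pow(2).mean()

data, target = torch.zeros(4), torch.zeros(4)
losses = [criterion(model(data), target) for _ in range(10)]
loss = sum(losses) / len(losses)   # same averaging as the sampling branch uses
# ---------------------------------------------------------------------------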
diff --git a/scripts/generic_train.py b/scripts/generic_train.py
index a654f72..6005f0d 100644
--- a/scripts/generic_train.py
+++ b/scripts/generic_train.py
@@ -2,7 +2,7 @@
 import torch
 import logging
 
-from core.split_strategy import PBPSplitStrategy
+from core.split_strategy import FaultySplitStrategy
 from core.distribution.utils import from_copy, from_zeros, from_random
 from core.distribution import GaussianVariable
 from core.training import train
@@ -20,18 +20,18 @@
 logging.basicConfig(level=logging.INFO)
 
 config = {
-    'log_wandb': True,
+    'log_wandb': False,
     'mcsamples': 1000,
     'pmin': 1e-5,
-    'sigma': 0.03,
+    'sigma': 0.01,
     'factory': {
         'losses': ['nll_loss', 'scaled_nll_loss', '01_loss'],
         'metrics': ['accuracy_micro_metric', 'accuracy_macro_metric', 'f1_micro_metric', 'f1_macro_metric'],
         'bounds': ['kl', 'mcallister'],
-        'data_loader': {'name': 'cifar10',
-                        'params': {'dataset_path': './data/cifar10'}
-                        },
+        # 'data_loader': {'name': 'cifar10',
+        #                 'params': {'dataset_path': './data/cifar10'}
+        #                 },
         # 'model': {'name': 'resnet',
         #           'params': {'num_channels': 3}
         #           },
@@ -40,31 +40,37 @@
         #           'params': {'input_dim': 32*32*3,
         #                      'hidden_dim': 100,
         #                      'output_dim': 10}
         #           },
-        'model': {'name': 'conv',
-                  'params': {'in_channels': 3, 'dataset': 'cifar10'}
-                  },
+        # 'model': {'name': 'conv',
+        #           'params': {'in_channels': 3, 'dataset': 'cifar10'}
+        #           },
         # 'model': {'name': 'conv15',
        #           'params': {'in_channels': 3, 'dataset': 'cifar10'}
         #           },
-        # 'data_loader': {'name': 'mnist',
-        #                 'params': {'dataset_path': './data/mnist'}
-        #                 },
-        # 'model': {'name': 'nn',
-        #           'params': {'input_dim': 28*28,
-        #                      'hidden_dim': 100,
-        #                      'output_dim': 10}
-        #           },
+        'data_loader': {'name': 'mnist',
+                        'params': {'dataset_path': './data/mnist'}
+                        },
+        'model': {'name': 'nn',
+                  'params': {'input_dim': 28*28,
+                             'hidden_dim': 100,
+                             'output_dim': 10}
+                  },
         # 'model': {'name': 'conv',
         #           'params': {'in_channels': 1, 'dataset': 'mnist'}
         #           },
-        'prior_objective': {'name': 'fquad',
+        # 'prior_objective': {'name': 'bbb',
+        #                     'params': {'kl_penalty': 0.001,
+        #                                # 'delta': 0.025
+        #                                }
+        #                     },
+        'prior_objective': {'name': 'iwae',  # 'iwae' or 'naive_iwae' from the objective factory
                             'params': {'kl_penalty': 0.001,
-                                       'delta': 0.025
+                                       'n': 10,
+                                       'temperature': 1e-4,
                                        }
                             },
-        'posterior_objective': {'name': 'fquad',
+        'posterior_objective': {'name': 'bbb',
                                 'params': {'kl_penalty': 1.0,
-                                           'delta': 0.025
+                                           # 'delta': 0.025
                                            }
                                 },
     },
@@ -89,15 +95,15 @@
     },
     'prior': {
         'training': {
-            'lr': 0.001,
-            'momentum': 0.95,
+            'lr': 0.0005,
+            'momentum': 0.9,
             'epochs': 100,
             'seed': 1135,
         }
     },
     'posterior': {
         'training': {
-            'lr': 0.001,
+            'lr': 0.0001,
             'momentum': 0.9,
             'epochs': 1,
             'seed': 1135,
@@ -135,11 +141,11 @@ def main():
     loader = data_loader_factory.create(config["factory"]["data_loader"]["name"],
                                         **config["factory"]["data_loader"]["params"])
 
-    strategy = PBPSplitStrategy(prior_type=config['split_strategy']['prior_type'],
-                                train_percent=config['split_strategy']['train_percent'],
-                                val_percent=config['split_strategy']['val_percent'],
-                                prior_percent=config['split_strategy']['prior_percent'],
-                                self_certified=config['split_strategy']['self_certified'])
+    strategy = FaultySplitStrategy(prior_type=config['split_strategy']['prior_type'],
+                                   train_percent=config['split_strategy']['train_percent'],
+                                   val_percent=config['split_strategy']['val_percent'],
+                                   prior_percent=config['split_strategy']['prior_percent'],
+                                   self_certified=config['split_strategy']['self_certified'])
     strategy.split(loader, split_config=config['split_config'])
 
     # Model
@@ -148,14 +154,19 @@ def main():
     model = model_factory.create(config["factory"]["model"]["name"], **config["factory"]["model"]["params"])
 
     torch.manual_seed(config['dist_init']['seed'])
-    prior_prior = from_zeros(model=model,
-                             rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
-                             distribution=GaussianVariable,
-                             requires_grad=False)
-    prior = from_random(model=model,
-                        rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
-                        distribution=GaussianVariable,
-                        requires_grad=True)
+    # prior_prior = from_zeros(model=model,
+    #                          rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
+    #                          distribution=GaussianVariable,
+    #                          requires_grad=False)
+    prior_prior = from_random(model=model,
+                              rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
+                              distribution=GaussianVariable,
+                              requires_grad=False)
+    # prior = from_random(model=model,
+    #                     rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
+    #                     distribution=GaussianVariable,
+    #                     requires_grad=True)
+    prior = from_copy(dist=prior_prior, distribution=GaussianVariable, requires_grad=True)
 
     dnn_to_probnn(model, prior, prior_prior)
     model.to(device)
@@ -183,15 +194,15 @@ def main():
                  wandb_params={'log_wandb': config["log_wandb"],
                                'name_wandb': 'Prior Train'})
 
-    if strategy.test_loader is not None:
-        _ = evaluate_metrics(model=model,
-                             metrics=metrics,
-                             test_loader=strategy.test_loader,
-                             num_samples_metric=config["mcsamples"],
-                             device=device,
-                             pmin=config["pmin"],
-                             wandb_params={'log_wandb': config["log_wandb"],
-                                           'name_wandb': 'Prior Evaluation'})
+    # if strategy.test_loader is not None:
+    #     _ = evaluate_metrics(model=model,
+    #                          metrics=metrics,
+    #                          test_loader=strategy.test_loader,
+    #                          num_samples_metric=config["mcsamples"],
+    #                          device=device,
+    #                          pmin=config["pmin"],
+    #                          wandb_params={'log_wandb': config["log_wandb"],
+    #                                        'name_wandb': 'Prior Evaluation'})
 
     _ = certify_risk(model=model,
                      bounds=bounds,
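# ---------------------------------------------------------------------------
# Side note (standalone check, not part of the diff): rho = log(exp(sigma) - 1)
# used above is the inverse of softplus, so softplus(rho) recovers the desired
# initial sigma for each GaussianVariable.
import torch

sigma = torch.tensor([0.01])
rho = torch.log(torch.exp(sigma) - 1)
assert torch.allclose(torch.nn.functional.softplus(rho), sigma)
# ---------------------------------------------------------------------------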
diff --git a/scripts/generic_train_2.py b/scripts/generic_train_2.py
index fe1ec0b..900a325 100644
--- a/scripts/generic_train_2.py
+++ b/scripts/generic_train_2.py
@@ -57,14 +57,19 @@
         # 'model': {'name': 'conv',
         #           'params': {'in_channels': 1, 'dataset': 'mnist'}
         #           },
-        'prior_objective': {'name': 'fquad',
+        # 'prior_objective': {'name': 'bbb',
+        #                     'params': {'kl_penalty': 0.001,
+        #                                # 'delta': 0.025
+        #                                }
+        #                     },
+        'prior_objective': {'name': 'bbb',
                             'params': {'kl_penalty': 0.001,
-                                       'delta': 0.025
+                                       # 'delta': 0.025
                                        }
                             },
-        'posterior_objective': {'name': 'fquad',
+        'posterior_objective': {'name': 'bbb',
                                 'params': {'kl_penalty': 1.0,
-                                           'delta': 0.025
+                                           # 'delta': 0.025
                                            }
                                 },
     },
@@ -148,13 +153,14 @@ def main():
     model = model_factory.create(config["factory"]["model"]["name"], **config["factory"]["model"]["params"])
 
     torch.manual_seed(config['dist_init']['seed'])
-    prior_prior = from_random(model=model,
+    prior_prior = from_zeros(model=model,
                              rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
                              distribution=GaussianVariable,
                              requires_grad=False)
-    prior = from_copy(dist=prior_prior,
-                      distribution=GaussianVariable,
-                      requires_grad=True)
+    prior = from_random(model=model,
+                        rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
+                        distribution=GaussianVariable,
+                        requires_grad=True)
 
     dnn_to_probnn(model, prior, prior_prior)
     model.to(device)
@@ -191,18 +197,18 @@ def main():
     #                          pmin=config["pmin"],
     #                          wandb_params={'log_wandb': config["log_wandb"],
     #                                        'name_wandb': 'Prior Evaluation'})
-    #
-    # _ = certify_risk(model=model,
-    #                  bounds=bounds,
-    #                  losses=losses,
-    #                  posterior=prior,
-    #                  prior=prior_prior,
-    #                  bound_loader=strategy.bound_loader,
-    #                  num_samples_loss=config["mcsamples"],
-    #                  device=device,
-    #                  pmin=config["pmin"],
-    #                  wandb_params={'log_wandb': config["log_wandb"],
-    #                                'name_wandb': 'Prior Bound'})
+
+    _ = certify_risk(model=model,
+                     bounds=bounds,
+                     losses=losses,
+                     posterior=prior,
+                     prior=prior_prior,
+                     bound_loader=strategy.bound_loader,
+                     num_samples_loss=config["mcsamples"],
+                     device=device,
+                     pmin=config["pmin"],
+                     wandb_params={'log_wandb': config["log_wandb"],
+                                   'name_wandb': 'Prior Bound'})
 
     posterior_prior = from_copy(dist=prior,
                                 distribution=GaussianVariable,
diff --git a/scripts/generic_train_3.py b/scripts/generic_train_3.py
new file mode 100644
index 0000000..c830eeb
--- /dev/null
+++ b/scripts/generic_train_3.py
@@ -0,0 +1,271 @@
+import wandb
+import torch
+import logging
+
+from core.split_strategy import PBPSplitStrategy
+from core.distribution.utils import from_copy, from_zeros, from_random
+from core.distribution import GaussianVariable
+from core.training import train
+from core.model import dnn_to_probnn, update_dist
+from core.risk import certify_risk
+from core.metric import evaluate_metrics
+
+from scripts.utils.factory import (LossFactory,
+                                   MetricFactory,
+                                   BoundFactory,
+                                   DataLoaderFactory,
+                                   ModelFactory,
+                                   ObjectiveFactory)
+
+logging.basicConfig(level=logging.INFO)
+
+config = {
+    'log_wandb': True,
+    'mcsamples': 1000,
+    'pmin': 1e-5,
+    'sigma': 0.03,
+    'factory':
+    {
+        'losses': ['nll_loss', 'scaled_nll_loss', '01_loss'],
+        'metrics': ['accuracy_micro_metric', 'accuracy_macro_metric', 'f1_micro_metric', 'f1_macro_metric'],
+        'bounds': ['kl', 'mcallister'],
+        'data_loader': {'name': 'cifar10',
+                        'params': {'dataset_path': './data/cifar10'}
+                        },
+        # 'model': {'name': 'resnet',
+        #           'params': {'num_channels': 3}
+        #           },
+        # 'model': {'name': 'nn',
+        #           'params': {'input_dim': 32*32*3,
+        #                      'hidden_dim': 100,
+        #                      'output_dim': 10}
+        #           },
+        'model': {'name': 'conv',
+                  'params': {'in_channels': 3, 'dataset': 'cifar10'}
+                  },
+        # 'model': {'name': 'conv15',
+        #           'params': {'in_channels': 3, 'dataset': 'cifar10'}
+        #           },
+        # 'data_loader': {'name': 'mnist',
+        #                 'params': {'dataset_path': './data/mnist'}
+        #                 },
+        # 'model': {'name': 'nn',
+        #           'params': {'input_dim': 28*28,
+        #                      'hidden_dim': 100,
+        #                      'output_dim': 10}
+        #           },
+        # 'model': {'name': 'conv',
+        #           'params': {'in_channels': 1, 'dataset': 'mnist'}
+        #           },
+        # 'prior_objective': {'name': 'bbb',
+        #                     'params': {'kl_penalty': 0.001,
+        #                                # 'delta': 0.025
+        #                                }
+        #                     },
+        'prior_objective': {'name': 'fquad',
+                            'params': {'kl_penalty': 0.001,
+                                       'delta': 0.025,
+                                       }
+                            },
+        'posterior_objective': {'name': 'bbb',
+                                'params': {'kl_penalty': 1.0,
+                                           # 'delta': 0.025
+                                           }
+                                },
+    },
+    'bound': {
+        'delta': 0.025,
+        'delta_test': 0.01,
+    },
+    'split_config': {
+        'seed': 111,
+        'dataset_loader_seed': 112,
+        'batch_size': 250,
+    },
+    'dist_init': {
+        'seed': 110,
+    },
+    'split_strategy': {
+        'prior_type': 'learnt',
+        'train_percent': 1.0,
+        'val_percent': 0.05,
+        'prior_percent': .5,
+        'self_certified': True,
+    },
+    'prior': {
+        'training': {
+            'lr': 0.001,
+            'momentum': 0.95,
+            'epochs': 100,
+            'seed': 1135,
+        }
+    },
+    'posterior': {
+        'training': {
+            'lr': 0.001,
+            'momentum': 0.9,
+            'epochs': 1,
+            'seed': 1135,
+        }
+    }
+}
+
+
+def main():
+    if config['log_wandb']:
+        wandb.init(project='pbb-framework', config=config)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print("Device ", device)
+    # Losses
+    logging.info(f'Selected losses: {config["factory"]["losses"]}')
+    loss_factory = LossFactory()
+    losses = {loss_name: loss_factory.create(loss_name) for loss_name in config["factory"]["losses"]}
+
+    # Metrics
+    logging.info(f'Selected metrics: {config["factory"]["metrics"]}')
+    metric_factory = MetricFactory()
+    metrics = {metric_name: metric_factory.create(metric_name) for metric_name in config["factory"]["metrics"]}
+
+    # Bound
+    logging.info(f'Selected bounds: {config["factory"]["bounds"]}')
+    bound_factory = BoundFactory()
+    bounds = {bound_name: bound_factory.create(bound_name,
+                                               bound_delta=config['bound']['delta'],
+                                               loss_delta=config['bound']['delta_test'])
+              for bound_name in config["factory"]["bounds"]}
+
+    # Data
+    logging.info(f'Selected data loader: {config["factory"]["data_loader"]}')
+    data_loader_factory = DataLoaderFactory()
+    loader = data_loader_factory.create(config["factory"]["data_loader"]["name"],
+                                        **config["factory"]["data_loader"]["params"])
+
+    strategy = PBPSplitStrategy(prior_type=config['split_strategy']['prior_type'],
+                                train_percent=config['split_strategy']['train_percent'],
+                                val_percent=config['split_strategy']['val_percent'],
+                                prior_percent=config['split_strategy']['prior_percent'],
+                                self_certified=config['split_strategy']['self_certified'])
+    strategy.split(loader, split_config=config['split_config'])
+
+    # Model
+    logging.info(f'Selected model: {config["factory"]["model"]["name"]}')
+    model_factory = ModelFactory()
+    model = model_factory.create(config["factory"]["model"]["name"], **config["factory"]["model"]["params"])
+
+    torch.manual_seed(config['dist_init']['seed'])
+    prior_prior = from_zeros(model=model,
+                             rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
+                             distribution=GaussianVariable,
+                             requires_grad=False)
+    prior = from_random(model=model,
+                        rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1),
+                        distribution=GaussianVariable,
+                        requires_grad=True)
+    dnn_to_probnn(model, prior, prior_prior)
+    model.to(device)
+
+    # Training prior
+    train_params = {
+        'lr': config['prior']['training']['lr'],
+        'momentum': config['prior']['training']['momentum'],
+        'epochs': config['prior']['training']['epochs'],
+        'seed': config['prior']['training']['seed'],
+        'num_samples': strategy.prior_loader.batch_size * len(strategy.prior_loader),
+    }
+    logging.info(f'Selected objective: {config["factory"]["prior_objective"]["name"]}')
+    objective_factory = ObjectiveFactory()
+    objective = objective_factory.create(config["factory"]["prior_objective"]["name"],
+                                         **config["factory"]["prior_objective"]["params"])
+
+    train(model=model,
+          posterior=prior,
+          prior=prior_prior,
+          objective=objective,
+          train_loader=strategy.prior_loader,
+          val_loader=strategy.val_loader,
+          parameters=train_params,
+          device=device,
+          wandb_params={'log_wandb': config["log_wandb"],
+                        'name_wandb': 'Prior Train'})
+
+    # if strategy.test_loader is not None:
+    #     _ = evaluate_metrics(model=model,
+    #                          metrics=metrics,
+    #                          test_loader=strategy.test_loader,
+    #                          num_samples_metric=config["mcsamples"],
+    #                          device=device,
+    #                          pmin=config["pmin"],
+    #                          wandb_params={'log_wandb': config["log_wandb"],
+    #                                        'name_wandb': 'Prior Evaluation'})
+
+    _ = certify_risk(model=model,
+                     bounds=bounds,
+                     losses=losses,
+                     posterior=prior,
+                     prior=prior_prior,
+                     bound_loader=strategy.bound_loader,
+                     num_samples_loss=config["mcsamples"],
+                     device=device,
+                     pmin=config["pmin"],
+                     wandb_params={'log_wandb': config["log_wandb"],
+                                   'name_wandb': 'Prior Bound'})
+
+    posterior_prior = from_copy(dist=prior,
+                                distribution=GaussianVariable,
+                                requires_grad=False)
+    posterior = from_copy(dist=prior,
+                          distribution=GaussianVariable,
+                          requires_grad=True)
+    update_dist(model, weight_dist=posterior, prior_weight_dist=posterior_prior)
+    model.to(device)
+
+    # Train posterior
+    train_params = {
+        'lr': config['posterior']['training']['lr'],
+        'momentum': config['posterior']['training']['momentum'],
+        'epochs': config['posterior']['training']['epochs'],
+        'seed': config['posterior']['training']['seed'],
+        'num_samples': strategy.posterior_loader.batch_size * len(strategy.posterior_loader),
+    }
+
+    logging.info(f'Selected objective: {config["factory"]["posterior_objective"]["name"]}')
+    objective = objective_factory.create(config["factory"]["posterior_objective"]["name"],
+                                         **config["factory"]["posterior_objective"]["params"])
+
+    train(model=model,
+          posterior=posterior,
+          prior=posterior_prior,
+          objective=objective,
+          train_loader=strategy.posterior_loader,
+          val_loader=strategy.val_loader,
+          parameters=train_params,
+          device=device,
+          wandb_params={'log_wandb': config["log_wandb"],
+                        'name_wandb': 'Posterior Train'})
+
+    # if strategy.test_loader is not None:
+    #     _ = evaluate_metrics(model=model,
+    #                          metrics=metrics,
+    #                          test_loader=strategy.test_loader,
+    #                          num_samples_metric=config["mcsamples"],
+    #                          device=device,
+    #                          pmin=config["pmin"],
+    #                          wandb_params={'log_wandb': config["log_wandb"],
+    #                                        'name_wandb': 'Posterior Evaluation'})
+
+    _ = certify_risk(model=model,
+                     bounds=bounds,
+                     losses=losses,
+                     posterior=posterior,
+                     prior=posterior_prior,
+                     bound_loader=strategy.bound_loader,
+                     num_samples_loss=config["mcsamples"],
+                     device=device,
+                     pmin=config["pmin"],
+                     wandb_params={'log_wandb': config["log_wandb"],
+                                   'name_wandb': 'Posterior Bound'})
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/utils/factory/ObjectiveFactory.py b/scripts/utils/factory/ObjectiveFactory.py
index 8dbf3dc..bec654d 100644
--- a/scripts/utils/factory/ObjectiveFactory.py
+++ b/scripts/utils/factory/ObjectiveFactory.py
@@ -1,4 +1,11 @@
-from core.objective import AbstractObjective, FClassicObjective, McAllisterObjective, FQuadObjective, BBBObjective, TolstikhinObjective
+from core.objective import (AbstractObjective,
+                            FClassicObjective,
+                            McAllisterObjective,
+                            FQuadObjective,
+                            BBBObjective,
+                            TolstikhinObjective,
+                            NaiveIWAEObjective,
+                            IWAEObjective)
 
 from scripts.utils.factory import AbstractFactory
 
@@ -11,3 +18,5 @@ def __init__(self) -> None:
         self.register_creator("fquad", FQuadObjective)
         self.register_creator("mcallister", McAllisterObjective)
         self.register_creator("tolstikhin", TolstikhinObjective)
+        self.register_creator("naive_iwae", NaiveIWAEObjective)
+        self.register_creator("iwae", IWAEObjective)
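# ---------------------------------------------------------------------------
# Usage sketch (illustration only; assumes the registrations above are in
# place): the new objectives can now be created by name, exactly as the
# training scripts do via the factory.
from scripts.utils.factory import ObjectiveFactory

factory = ObjectiveFactory()
iwae = factory.create('iwae', kl_penalty=0.001, n=10, temperature=1e-4)
naive_iwae = factory.create('naive_iwae', kl_penalty=0.001, n=10)
# ---------------------------------------------------------------------------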