From 49019515ee6168c47c2a58c2fb4254ba47f411fe Mon Sep 17 00:00:00 2001 From: yauhenii Date: Sun, 11 Aug 2024 13:46:37 +0200 Subject: [PATCH 1/9] implement basic iwae --- core/objective/AbstractSamplingObjective.py | 14 +++++++++ core/objective/NaiveIWAEObjective.py | 12 +++++++ core/objective/__init__.py | 2 ++ core/training.py | 28 ++++++++++++----- scripts/generic_train.py | 35 ++++++++++++--------- scripts/utils/factory/ObjectiveFactory.py | 3 +- 6 files changed, 71 insertions(+), 23 deletions(-) create mode 100644 core/objective/AbstractSamplingObjective.py create mode 100644 core/objective/NaiveIWAEObjective.py diff --git a/core/objective/AbstractSamplingObjective.py b/core/objective/AbstractSamplingObjective.py new file mode 100644 index 0000000..72ba065 --- /dev/null +++ b/core/objective/AbstractSamplingObjective.py @@ -0,0 +1,14 @@ +from torch import Tensor +from typing import List + +from abc import ABC, abstractmethod + + +class AbstractSamplingObjective(ABC): + def __init__(self, kl_penalty: float, n: int) -> None: + self._kl_penalty = kl_penalty + self.n: int = n + + @abstractmethod + def calculate(self, losses: List[Tensor], kl: Tensor, num_samples: float) -> Tensor: + pass diff --git a/core/objective/NaiveIWAEObjective.py b/core/objective/NaiveIWAEObjective.py new file mode 100644 index 0000000..67ed381 --- /dev/null +++ b/core/objective/NaiveIWAEObjective.py @@ -0,0 +1,12 @@ +import torch +from torch import Tensor +from typing import List + +from core.objective import AbstractSamplingObjective + + +class NaiveIWAEObjective(AbstractSamplingObjective): + def calculate(self, losses: List[Tensor], kl: Tensor, num_samples: float) -> Tensor: + assert self.n == len(losses) + loss = sum(losses) / self.n + return loss + self._kl_penalty * (kl / num_samples) diff --git a/core/objective/__init__.py b/core/objective/__init__.py index 1287f39..a6e2f45 100644 --- a/core/objective/__init__.py +++ b/core/objective/__init__.py @@ -1,6 +1,8 @@ from core.objective.AbstractObjective import AbstractObjective +from core.objective.AbstractSamplingObjective import AbstractSamplingObjective from core.objective.BBBObjective import BBBObjective from core.objective.FQuadObjective import FQuadObjective from core.objective.FClassicObjective import FClassicObjective from core.objective.McAllisterObjective import McAllisterObjective from core.objective.TolstikhinObjective import TolstikhinObjective +from core.objective.NaiveIWAEObjective import NaiveIWAEObjective diff --git a/core/training.py b/core/training.py index cae3ecb..6733a3a 100644 --- a/core/training.py +++ b/core/training.py @@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter from core.distribution.utils import compute_kl, DistributionT -from core.objective import AbstractObjective +from core.objective import AbstractObjective, AbstractSamplingObjective from core.model import bounded_call @@ -34,13 +34,27 @@ def train(model: nn.Module, for i, (data, target) in tqdm(enumerate(train_loader)): data, target = data.to(device), target.to(device) optimizer.zero_grad() - if 'pmin' in parameters: - output = bounded_call(model, data, parameters['pmin']) + if isinstance(objective, AbstractObjective): + if 'pmin' in parameters: + output = bounded_call(model, data, parameters['pmin']) + else: + output = model(data) + loss = criterion(output, target) + kl = compute_kl(posterior, prior) + objective_value = objective.calculate(loss, kl, parameters['num_samples']) + elif isinstance(objective, AbstractSamplingObjective): + losses = [] + for i in range(objective.n): + if 'pmin' in parameters: + output = bounded_call(model, data, parameters['pmin']) + else: + output = model(data) + losses.append(criterion(output, target)) + kl = compute_kl(posterior, prior) + objective_value = objective.calculate(losses, kl, parameters['num_samples']) + loss = sum(losses) / objective.n else: - output = model(data) - kl = compute_kl(posterior, prior) - loss = criterion(output, target) - objective_value = objective.calculate(loss, kl, parameters['num_samples']) + raise ValueError(f'Invalid objective type: {type(objective)}') objective_value.backward() optimizer.step() logging.info(f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl/parameters['num_samples']}") diff --git a/scripts/generic_train.py b/scripts/generic_train.py index a654f72..37a620b 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO) config = { - 'log_wandb': True, + 'log_wandb': False, 'mcsamples': 1000, 'pmin': 1e-5, 'sigma': 0.03, @@ -57,14 +57,19 @@ # 'model': {'name': 'conv', # 'params': {'in_channels': 1, 'dataset': 'mnist'} # }, - 'prior_objective': {'name': 'fquad', + # 'prior_objective': {'name': 'bbb', + # 'params': {'kl_penalty': 0.001, + # # 'delta': 0.025 + # } + # }, + 'prior_objective': {'name': 'naive_iwae', 'params': {'kl_penalty': 0.001, - 'delta': 0.025 + 'n': 10, } }, - 'posterior_objective': {'name': 'fquad', + 'posterior_objective': {'name': 'bbb', 'params': {'kl_penalty': 1.0, - 'delta': 0.025 + # 'delta': 0.025 } }, }, @@ -91,7 +96,7 @@ 'training': { 'lr': 0.001, 'momentum': 0.95, - 'epochs': 100, + 'epochs': 5, 'seed': 1135, } }, @@ -183,15 +188,15 @@ def main(): wandb_params={'log_wandb': config["log_wandb"], 'name_wandb': 'Prior Train'}) - if strategy.test_loader is not None: - _ = evaluate_metrics(model=model, - metrics=metrics, - test_loader=strategy.test_loader, - num_samples_metric=config["mcsamples"], - device=device, - pmin=config["pmin"], - wandb_params={'log_wandb': config["log_wandb"], - 'name_wandb': 'Prior Evaluation'}) + # if strategy.test_loader is not None: + # _ = evaluate_metrics(model=model, + # metrics=metrics, + # test_loader=strategy.test_loader, + # num_samples_metric=config["mcsamples"], + # device=device, + # pmin=config["pmin"], + # wandb_params={'log_wandb': config["log_wandb"], + # 'name_wandb': 'Prior Evaluation'}) _ = certify_risk(model=model, bounds=bounds, diff --git a/scripts/utils/factory/ObjectiveFactory.py b/scripts/utils/factory/ObjectiveFactory.py index 8dbf3dc..f8ba750 100644 --- a/scripts/utils/factory/ObjectiveFactory.py +++ b/scripts/utils/factory/ObjectiveFactory.py @@ -1,4 +1,4 @@ -from core.objective import AbstractObjective, FClassicObjective, McAllisterObjective, FQuadObjective, BBBObjective, TolstikhinObjective +from core.objective import AbstractObjective, FClassicObjective, McAllisterObjective, FQuadObjective, BBBObjective, TolstikhinObjective, NaiveIWAEObjective from scripts.utils.factory import AbstractFactory @@ -11,3 +11,4 @@ def __init__(self) -> None: self.register_creator("fquad", FQuadObjective) self.register_creator("mcallister", McAllisterObjective) self.register_creator("tolstikhin", TolstikhinObjective) + self.register_creator("naive_iwae", NaiveIWAEObjective) From 5887b826b3d0dd0dcb5128f860ddc41399154c55 Mon Sep 17 00:00:00 2001 From: yauhenii Date: Sun, 11 Aug 2024 13:52:29 +0200 Subject: [PATCH 2/9] prepare experiment --- scripts/generic_train.py | 4 +- scripts/generic_train_2.py | 46 ++++--- scripts/generic_train_3.py | 271 +++++++++++++++++++++++++++++++++++++ 3 files changed, 299 insertions(+), 22 deletions(-) create mode 100644 scripts/generic_train_3.py diff --git a/scripts/generic_train.py b/scripts/generic_train.py index 37a620b..2fb65cf 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO) config = { - 'log_wandb': False, + 'log_wandb': True, 'mcsamples': 1000, 'pmin': 1e-5, 'sigma': 0.03, @@ -96,7 +96,7 @@ 'training': { 'lr': 0.001, 'momentum': 0.95, - 'epochs': 5, + 'epochs': 100, 'seed': 1135, } }, diff --git a/scripts/generic_train_2.py b/scripts/generic_train_2.py index fe1ec0b..900a325 100644 --- a/scripts/generic_train_2.py +++ b/scripts/generic_train_2.py @@ -57,14 +57,19 @@ # 'model': {'name': 'conv', # 'params': {'in_channels': 1, 'dataset': 'mnist'} # }, - 'prior_objective': {'name': 'fquad', + # 'prior_objective': {'name': 'bbb', + # 'params': {'kl_penalty': 0.001, + # # 'delta': 0.025 + # } + # }, + 'prior_objective': {'name': 'bbb', 'params': {'kl_penalty': 0.001, - 'delta': 0.025 + # 'delta': 0.025 } }, - 'posterior_objective': {'name': 'fquad', + 'posterior_objective': {'name': 'bbb', 'params': {'kl_penalty': 1.0, - 'delta': 0.025 + # 'delta': 0.025 } }, }, @@ -148,13 +153,14 @@ def main(): model = model_factory.create(config["factory"]["model"]["name"], **config["factory"]["model"]["params"]) torch.manual_seed(config['dist_init']['seed']) - prior_prior = from_random(model=model, + prior_prior = from_zeros(model=model, rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), distribution=GaussianVariable, requires_grad=False) - prior = from_copy(dist=prior_prior, - distribution=GaussianVariable, - requires_grad=True) + prior = from_random(model=model, + rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), + distribution=GaussianVariable, + requires_grad=True) dnn_to_probnn(model, prior, prior_prior) model.to(device) @@ -191,18 +197,18 @@ def main(): # pmin=config["pmin"], # wandb_params={'log_wandb': config["log_wandb"], # 'name_wandb': 'Prior Evaluation'}) - # - # _ = certify_risk(model=model, - # bounds=bounds, - # losses=losses, - # posterior=prior, - # prior=prior_prior, - # bound_loader=strategy.bound_loader, - # num_samples_loss=config["mcsamples"], - # device=device, - # pmin=config["pmin"], - # wandb_params={'log_wandb': config["log_wandb"], - # 'name_wandb': 'Prior Bound'}) + + _ = certify_risk(model=model, + bounds=bounds, + losses=losses, + posterior=prior, + prior=prior_prior, + bound_loader=strategy.bound_loader, + num_samples_loss=config["mcsamples"], + device=device, + pmin=config["pmin"], + wandb_params={'log_wandb': config["log_wandb"], + 'name_wandb': 'Prior Bound'}) posterior_prior = from_copy(dist=prior, distribution=GaussianVariable, diff --git a/scripts/generic_train_3.py b/scripts/generic_train_3.py new file mode 100644 index 0000000..c830eeb --- /dev/null +++ b/scripts/generic_train_3.py @@ -0,0 +1,271 @@ +import wandb +import torch +import logging + +from core.split_strategy import PBPSplitStrategy +from core.distribution.utils import from_copy, from_zeros, from_random +from core.distribution import GaussianVariable +from core.training import train +from core.model import dnn_to_probnn, update_dist +from core.risk import certify_risk +from core.metric import evaluate_metrics + +from scripts.utils.factory import (LossFactory, + MetricFactory, + BoundFactory, + DataLoaderFactory, + ModelFactory, + ObjectiveFactory) + +logging.basicConfig(level=logging.INFO) + +config = { + 'log_wandb': True, + 'mcsamples': 1000, + 'pmin': 1e-5, + 'sigma': 0.03, + 'factory': + { + 'losses': ['nll_loss', 'scaled_nll_loss', '01_loss'], + 'metrics': ['accuracy_micro_metric', 'accuracy_macro_metric', 'f1_micro_metric', 'f1_macro_metric'], + 'bounds': ['kl', 'mcallister'], + 'data_loader': {'name': 'cifar10', + 'params': {'dataset_path': './data/cifar10'} + }, + # 'model': {'name': 'resnet', + # 'params': {'num_channels': 3} + # }, + # 'model': {'name': 'nn', + # 'params': {'input_dim': 32*32*3, + # 'hidden_dim': 100, + # 'output_dim': 10} + # }, + 'model': {'name': 'conv', + 'params': {'in_channels': 3, 'dataset': 'cifar10'} + }, + # 'model': {'name': 'conv15', + # 'params': {'in_channels': 3, 'dataset': 'cifar10'} + # }, + # 'data_loader': {'name': 'mnist', + # 'params': {'dataset_path': './data/mnist'} + # }, + # 'model': {'name': 'nn', + # 'params': {'input_dim': 28*28, + # 'hidden_dim': 100, + # 'output_dim': 10} + # }, + # 'model': {'name': 'conv', + # 'params': {'in_channels': 1, 'dataset': 'mnist'} + # }, + # 'prior_objective': {'name': 'bbb', + # 'params': {'kl_penalty': 0.001, + # # 'delta': 0.025 + # } + # }, + 'prior_objective': {'name': 'fquad', + 'params': {'kl_penalty': 0.001, + 'delta': 0.025, + } + }, + 'posterior_objective': {'name': 'bbb', + 'params': {'kl_penalty': 1.0, + # 'delta': 0.025 + } + }, + }, + 'bound': { + 'delta': 0.025, + 'delta_test': 0.01, + }, + 'split_config': { + 'seed': 111, + 'dataset_loader_seed': 112, + 'batch_size': 250, + }, + 'dist_init': { + 'seed': 110, + }, + 'split_strategy': { + 'prior_type': 'learnt', + 'train_percent': 1.0, + 'val_percent': 0.05, + 'prior_percent': .5, + 'self_certified': True, + }, + 'prior': { + 'training': { + 'lr': 0.001, + 'momentum': 0.95, + 'epochs': 100, + 'seed': 1135, + } + }, + 'posterior': { + 'training': { + 'lr': 0.001, + 'momentum': 0.9, + 'epochs': 1, + 'seed': 1135, + } + } +} + + +def main(): + if config['log_wandb']: + wandb.init(project='pbb-framework', config=config) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("Device ", device) + # Losses + logging.info(f'Selected losses: {config["factory"]["losses"]}') + loss_factory = LossFactory() + losses = {loss_name: loss_factory.create(loss_name) for loss_name in config["factory"]["losses"]} + + # Metrics + logging.info(f'Select metrics: {config["factory"]["metrics"]}') + metric_factory = MetricFactory() + metrics = {metric_name: metric_factory.create(metric_name) for metric_name in config["factory"]["metrics"]} + + # Bound + logging.info(f'Selected bounds: {config["factory"]["bounds"]}') + bound_factory = BoundFactory() + bounds = {bound_name: bound_factory.create(bound_name, + bound_delta=config['bound']['delta'], + loss_delta=config['bound']['delta_test']) + for bound_name in config["factory"]["bounds"]} + + # Data + logging.info(f'Selected data loader: {config["factory"]["data_loader"]}') + data_loader_factory = DataLoaderFactory() + loader = data_loader_factory.create(config["factory"]["data_loader"]["name"], + **config["factory"]["data_loader"]["params"]) + + strategy = PBPSplitStrategy(prior_type=config['split_strategy']['prior_type'], + train_percent=config['split_strategy']['train_percent'], + val_percent=config['split_strategy']['val_percent'], + prior_percent=config['split_strategy']['prior_percent'], + self_certified=config['split_strategy']['self_certified']) + strategy.split(loader, split_config=config['split_config']) + + # Model + logging.info(f'Select model: {config["factory"]["model"]["name"]}') + model_factory = ModelFactory() + model = model_factory.create(config["factory"]["model"]["name"], **config["factory"]["model"]["params"]) + + torch.manual_seed(config['dist_init']['seed']) + prior_prior = from_zeros(model=model, + rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), + distribution=GaussianVariable, + requires_grad=False) + prior = from_random(model=model, + rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), + distribution=GaussianVariable, + requires_grad=True) + dnn_to_probnn(model, prior, prior_prior) + model.to(device) + + # Training prior + train_params = { + 'lr': config['prior']['training']['lr'], + 'momentum': config['prior']['training']['momentum'], + 'epochs': config['prior']['training']['epochs'], + 'seed': config['prior']['training']['seed'], + 'num_samples': strategy.prior_loader.batch_size * len(strategy.prior_loader), + } + logging.info(f'Select objective: {config["factory"]["prior_objective"]["name"]}') + objective_factory = ObjectiveFactory() + objective = objective_factory.create(config["factory"]["prior_objective"]["name"], + **config["factory"]["prior_objective"]["params"]) + + train(model=model, + posterior=prior, + prior=prior_prior, + objective=objective, + train_loader=strategy.prior_loader, + val_loader=strategy.val_loader, + parameters=train_params, + device=device, + wandb_params={'log_wandb': config["log_wandb"], + 'name_wandb': 'Prior Train'}) + + # if strategy.test_loader is not None: + # _ = evaluate_metrics(model=model, + # metrics=metrics, + # test_loader=strategy.test_loader, + # num_samples_metric=config["mcsamples"], + # device=device, + # pmin=config["pmin"], + # wandb_params={'log_wandb': config["log_wandb"], + # 'name_wandb': 'Prior Evaluation'}) + + _ = certify_risk(model=model, + bounds=bounds, + losses=losses, + posterior=prior, + prior=prior_prior, + bound_loader=strategy.bound_loader, + num_samples_loss=config["mcsamples"], + device=device, + pmin=config["pmin"], + wandb_params={'log_wandb': config["log_wandb"], + 'name_wandb': 'Prior Bound'}) + + posterior_prior = from_copy(dist=prior, + distribution=GaussianVariable, + requires_grad=False) + posterior = from_copy(dist=prior, + distribution=GaussianVariable, + requires_grad=True) + update_dist(model, weight_dist=posterior, prior_weight_dist=posterior_prior) + model.to(device) + + # Train posterior + train_params = { + 'lr': config['posterior']['training']['lr'], + 'momentum': config['posterior']['training']['momentum'], + 'epochs': config['posterior']['training']['epochs'], + 'seed': config['posterior']['training']['seed'], + 'num_samples': strategy.posterior_loader.batch_size * len(strategy.posterior_loader), + } + + logging.info(f'Select objective: {config["factory"]["posterior_objective"]["name"]}') + objective = objective_factory.create(config["factory"]["posterior_objective"]["name"], + **config["factory"]["posterior_objective"]["params"]) + + train(model=model, + posterior=posterior, + prior=posterior_prior, + objective=objective, + train_loader=strategy.posterior_loader, + val_loader=strategy.val_loader, + parameters=train_params, + device=device, + wandb_params={'log_wandb': config["log_wandb"], + 'name_wandb': 'Posterior Train'}) + + # if strategy.test_loader is not None: + # _ = evaluate_metrics(model=model, + # metrics=metrics, + # test_loader=strategy.test_loader, + # num_samples_metric=config["mcsamples"], + # device=device, + # pmin=config["pmin"], + # wandb_params={'log_wandb': config["log_wandb"], + # 'name_wandb': 'Posterior Evaluation'}) + + _ = certify_risk(model=model, + bounds=bounds, + losses=losses, + posterior=posterior, + prior=posterior_prior, + bound_loader=strategy.bound_loader, + num_samples_loss=config["mcsamples"], + device=device, + pmin=config["pmin"], + wandb_params={'log_wandb': config["log_wandb"], + 'name_wandb': 'Posterior Bound'}) + + +if __name__ == '__main__': + main() + From 78bc122a908bda860bbb47b542955baa38d0b5d7 Mon Sep 17 00:00:00 2001 From: yauhenii Date: Sat, 17 Aug 2024 16:18:56 +0200 Subject: [PATCH 3/9] try --- core/layer/AbstractProbLayer.py | 4 ++ core/objective/IWAEObjective.py | 63 +++++++++++++++++++++++ core/objective/__init__.py | 1 + core/training.py | 6 ++- scripts/generic_train.py | 55 ++++++++++---------- scripts/utils/factory/ObjectiveFactory.py | 10 +++- 6 files changed, 110 insertions(+), 29 deletions(-) create mode 100644 core/objective/IWAEObjective.py diff --git a/core/layer/AbstractProbLayer.py b/core/layer/AbstractProbLayer.py index b33b092..6b819dd 100644 --- a/core/layer/AbstractProbLayer.py +++ b/core/layer/AbstractProbLayer.py @@ -12,6 +12,8 @@ class AbstractProbLayer(nn.Module, ABC): _bias_dist: AbstractVariable _prior_weight_dist: AbstractVariable _prior_bias_dist: AbstractVariable + _sampled_weight: Tensor + _sampled_bias: Tensor def probabilistic(self, mode: bool = True): if not isinstance(mode, bool): @@ -32,4 +34,6 @@ def sample_from_distribution(self) -> Tuple[Tensor, Tensor]: sampled_bias = self._bias_dist.mu if self._bias_dist else None else: raise ValueError('Only training with probabilistic mode is allowed') + self._sampled_weight = sampled_weight + self._sampled_bias = sampled_bias return sampled_weight, sampled_bias diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py new file mode 100644 index 0000000..eccc498 --- /dev/null +++ b/core/objective/IWAEObjective.py @@ -0,0 +1,63 @@ +import torch +from torch import Tensor, nn +from typing import List +import numpy as np +import torch.distributions as dists + +from core.model import bounded_call +from core.layer.utils import get_torch_layers + + +class IWAEObjective: + def __init__(self, kl_penalty: float, n: int) -> None: + self._kl_penalty = kl_penalty + self.n: int = n + self.criterion = torch.nn.NLLLoss() + + def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float = None) -> Tensor: + + log_losses = [] + + for i in range(self.n): + + if pmin is not None: + p_x_g_w = bounded_call(model, data, pmin) + else: + p_x_g_w = model(data) + + # log_loss_i = torch.sum(p_x_g_w, dim=1) + # log_loss_i = self.criterion(p_x_g_w, target) + log_loss_i = dists.Categorical(logits=p_x_g_w).log_prob(target) + + log_p_w_total = 0 + log_q_w_g_x_total = 0 + eps = 1e-6 + norm = False + + for l_name, l in get_torch_layers(model): + sampled_weight = l._sampled_weight + sampled_bias = l._sampled_bias + + log_p_w_weight = dists.Normal(l._prior_weight_dist.mu, l._prior_weight_dist.sigma + eps).log_prob(sampled_weight) + log_p_w_bias = dists.Normal(l._prior_bias_dist.mu, l._prior_bias_dist.sigma + eps).log_prob(sampled_bias) + + if norm: + log_p_w_total += (log_p_w_weight.sum() / torch.prod(torch.tensor(log_p_w_weight.shape)) + + log_p_w_bias.sum() / torch.prod(torch.tensor(log_p_w_bias.shape))) + else: + log_p_w_total += log_p_w_weight.sum() + log_p_w_bias.sum() + + log_q_w_g_x_weight = dists.Normal(l._weight_dist.mu, l._weight_dist.sigma + eps).log_prob(sampled_weight) + log_q_w_g_x_bias = dists.Normal(l._bias_dist.mu, l._bias_dist.sigma + eps).log_prob(sampled_bias) + + if norm: + log_q_w_g_x_total += (log_q_w_g_x_weight.sum() / torch.prod(torch.tensor(log_q_w_g_x_weight.shape)) + + log_q_w_g_x_bias.sum() / torch.prod(torch.tensor(log_q_w_g_x_bias.shape))) + else: + log_q_w_g_x_total += log_q_w_g_x_weight.sum() + log_q_w_g_x_bias.sum() + + log_loss_i = log_loss_i + log_p_w_total.repeat(len(log_loss_i)) - log_q_w_g_x_total.repeat(len(log_loss_i)) + log_losses.append(log_loss_i) + loss = - (torch.logsumexp(torch.stack(log_losses), dim=0) - np.log(self.n)).mean() + # loss = -log_losses[0].mean() + return loss diff --git a/core/objective/__init__.py b/core/objective/__init__.py index a6e2f45..29dad9c 100644 --- a/core/objective/__init__.py +++ b/core/objective/__init__.py @@ -6,3 +6,4 @@ from core.objective.McAllisterObjective import McAllisterObjective from core.objective.TolstikhinObjective import TolstikhinObjective from core.objective.NaiveIWAEObjective import NaiveIWAEObjective +from core.objective.IWAEObjective import IWAEObjective diff --git a/core/training.py b/core/training.py index 6733a3a..1311f0e 100644 --- a/core/training.py +++ b/core/training.py @@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter from core.distribution.utils import compute_kl, DistributionT -from core.objective import AbstractObjective, AbstractSamplingObjective +from core.objective import AbstractObjective, AbstractSamplingObjective, IWAEObjective from core.model import bounded_call @@ -53,6 +53,10 @@ def train(model: nn.Module, kl = compute_kl(posterior, prior) objective_value = objective.calculate(losses, kl, parameters['num_samples']) loss = sum(losses) / objective.n + elif isinstance(objective, IWAEObjective): + objective_value = objective.calculate(model, data, target, pmin=parameters.get('pmin', None)) + loss = criterion(model(data), target) + kl = compute_kl(posterior, prior) else: raise ValueError(f'Invalid objective type: {type(objective)}') objective_value.backward() diff --git a/scripts/generic_train.py b/scripts/generic_train.py index 2fb65cf..78bcc52 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -2,7 +2,7 @@ import torch import logging -from core.split_strategy import PBPSplitStrategy +from core.split_strategy import FaultySplitStrategy from core.distribution.utils import from_copy, from_zeros, from_random from core.distribution import GaussianVariable from core.training import train @@ -20,18 +20,18 @@ logging.basicConfig(level=logging.INFO) config = { - 'log_wandb': True, + 'log_wandb': False, 'mcsamples': 1000, 'pmin': 1e-5, - 'sigma': 0.03, + 'sigma': 0.01, 'factory': { 'losses': ['nll_loss', 'scaled_nll_loss', '01_loss'], 'metrics': ['accuracy_micro_metric', 'accuracy_macro_metric', 'f1_micro_metric', 'f1_macro_metric'], 'bounds': ['kl', 'mcallister'], - 'data_loader': {'name': 'cifar10', - 'params': {'dataset_path': './data/cifar10'} - }, + # 'data_loader': {'name': 'cifar10', + # 'params': {'dataset_path': './data/cifar10'} + # }, # 'model': {'name': 'resnet', # 'params': {'num_channels': 3} # }, @@ -40,20 +40,20 @@ # 'hidden_dim': 100, # 'output_dim': 10} # }, - 'model': {'name': 'conv', - 'params': {'in_channels': 3, 'dataset': 'cifar10'} - }, + # 'model': {'name': 'conv', + # 'params': {'in_channels': 3, 'dataset': 'cifar10'} + # }, # 'model': {'name': 'conv15', # 'params': {'in_channels': 3, 'dataset': 'cifar10'} # }, - # 'data_loader': {'name': 'mnist', - # 'params': {'dataset_path': './data/mnist'} - # }, - # 'model': {'name': 'nn', - # 'params': {'input_dim': 28*28, - # 'hidden_dim': 100, - # 'output_dim': 10} - # }, + 'data_loader': {'name': 'mnist', + 'params': {'dataset_path': './data/mnist'} + }, + 'model': {'name': 'nn', + 'params': {'input_dim': 28*28, + 'hidden_dim': 100, + 'output_dim': 10} + }, # 'model': {'name': 'conv', # 'params': {'in_channels': 1, 'dataset': 'mnist'} # }, @@ -62,7 +62,7 @@ # # 'delta': 0.025 # } # }, - 'prior_objective': {'name': 'naive_iwae', + 'prior_objective': {'name': 'iwae', 'params': {'kl_penalty': 0.001, 'n': 10, } @@ -140,11 +140,11 @@ def main(): loader = data_loader_factory.create(config["factory"]["data_loader"]["name"], **config["factory"]["data_loader"]["params"]) - strategy = PBPSplitStrategy(prior_type=config['split_strategy']['prior_type'], - train_percent=config['split_strategy']['train_percent'], - val_percent=config['split_strategy']['val_percent'], - prior_percent=config['split_strategy']['prior_percent'], - self_certified=config['split_strategy']['self_certified']) + strategy = FaultySplitStrategy(prior_type=config['split_strategy']['prior_type'], + train_percent=config['split_strategy']['train_percent'], + val_percent=config['split_strategy']['val_percent'], + prior_percent=config['split_strategy']['prior_percent'], + self_certified=config['split_strategy']['self_certified']) strategy.split(loader, split_config=config['split_config']) # Model @@ -157,10 +157,11 @@ def main(): rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), distribution=GaussianVariable, requires_grad=False) - prior = from_random(model=model, - rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), - distribution=GaussianVariable, - requires_grad=True) + # prior = from_random(model=model, + # rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), + # distribution=GaussianVariable, + # requires_grad=True) + prior = from_copy(dist=prior_prior, distribution=GaussianVariable) dnn_to_probnn(model, prior, prior_prior) model.to(device) diff --git a/scripts/utils/factory/ObjectiveFactory.py b/scripts/utils/factory/ObjectiveFactory.py index f8ba750..bec654d 100644 --- a/scripts/utils/factory/ObjectiveFactory.py +++ b/scripts/utils/factory/ObjectiveFactory.py @@ -1,4 +1,11 @@ -from core.objective import AbstractObjective, FClassicObjective, McAllisterObjective, FQuadObjective, BBBObjective, TolstikhinObjective, NaiveIWAEObjective +from core.objective import (AbstractObjective, + FClassicObjective, + McAllisterObjective, + FQuadObjective, + BBBObjective, + TolstikhinObjective, + NaiveIWAEObjective, + IWAEObjective) from scripts.utils.factory import AbstractFactory @@ -12,3 +19,4 @@ def __init__(self) -> None: self.register_creator("mcallister", McAllisterObjective) self.register_creator("tolstikhin", TolstikhinObjective) self.register_creator("naive_iwae", NaiveIWAEObjective) + self.register_creator("iwae", IWAEObjective) From b044bfe8d6c7620ce81e0ebee0b0ccf0a93ee470 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sat, 24 Aug 2024 15:18:06 +0200 Subject: [PATCH 4/9] changing the prior sampling --- core/objective/IWAEObjective.py | 3 ++- scripts/generic_train.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index eccc498..1037df1 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -27,6 +27,7 @@ def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float # log_loss_i = torch.sum(p_x_g_w, dim=1) # log_loss_i = self.criterion(p_x_g_w, target) + temperature = 0.0001 log_loss_i = dists.Categorical(logits=p_x_g_w).log_prob(target) log_p_w_total = 0 @@ -56,7 +57,7 @@ def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float else: log_q_w_g_x_total += log_q_w_g_x_weight.sum() + log_q_w_g_x_bias.sum() - log_loss_i = log_loss_i + log_p_w_total.repeat(len(log_loss_i)) - log_q_w_g_x_total.repeat(len(log_loss_i)) + log_loss_i = log_loss_i + (log_p_w_total.repeat(len(log_loss_i)) - log_q_w_g_x_total.repeat(len(log_loss_i)))*temperature log_losses.append(log_loss_i) loss = - (torch.logsumexp(torch.stack(log_losses), dim=0) - np.log(self.n)).mean() # loss = -log_losses[0].mean() diff --git a/scripts/generic_train.py b/scripts/generic_train.py index 78bcc52..359c252 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -153,15 +153,19 @@ def main(): model = model_factory.create(config["factory"]["model"]["name"], **config["factory"]["model"]["params"]) torch.manual_seed(config['dist_init']['seed']) - prior_prior = from_zeros(model=model, - rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), - distribution=GaussianVariable, - requires_grad=False) + # prior_prior = from_zeros(model=model, + # rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), + # distribution=GaussianVariable, + # requires_grad=False) + prior_prior = from_random(model=model, + rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), + distribution=GaussianVariable, + requires_grad=False) # prior = from_random(model=model, # rho=torch.log(torch.exp(torch.Tensor([config['sigma']])) - 1), # distribution=GaussianVariable, # requires_grad=True) - prior = from_copy(dist=prior_prior, distribution=GaussianVariable) + prior = from_copy(dist=prior_prior, distribution=GaussianVariable, requires_grad=True) dnn_to_probnn(model, prior, prior_prior) model.to(device) From e077567265bbe3c3ca81934eb6f1b405a5f25efc Mon Sep 17 00:00:00 2001 From: Misipuk Date: Sat, 24 Aug 2024 15:28:57 +0200 Subject: [PATCH 5/9] working config --- core/objective/IWAEObjective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 1037df1..5ec675a 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -27,7 +27,7 @@ def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float # log_loss_i = torch.sum(p_x_g_w, dim=1) # log_loss_i = self.criterion(p_x_g_w, target) - temperature = 0.0001 + temperature = 0.00001 log_loss_i = dists.Categorical(logits=p_x_g_w).log_prob(target) log_p_w_total = 0 From a1c9e59f97d79d8ad096f0cc588f424594e4fa50 Mon Sep 17 00:00:00 2001 From: yauhenii Date: Sat, 24 Aug 2024 17:02:20 +0200 Subject: [PATCH 6/9] add basic logging for iwae --- core/objective/IWAEObjective.py | 28 ++++++++++++++++++++++------ core/training.py | 13 ++++++++++++- scripts/generic_train.py | 5 +++-- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 5ec675a..f9a6ecf 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -1,6 +1,8 @@ +import logging import torch +import wandb from torch import Tensor, nn -from typing import List +from typing import List, Dict import numpy as np import torch.distributions as dists @@ -9,12 +11,20 @@ class IWAEObjective: - def __init__(self, kl_penalty: float, n: int) -> None: + def __init__(self, kl_penalty: float, n: int, temperature: int) -> None: self._kl_penalty = kl_penalty self.n: int = n self.criterion = torch.nn.NLLLoss() + self._temperature = temperature - def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float = None) -> Tensor: + def calculate(self, + model: nn.Module, + data: Tensor, + target: Tensor, + epoch: int, + batch: int, + pmin: float = None, + wandb_params: Dict = None) -> Tensor: log_losses = [] @@ -27,8 +37,7 @@ def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float # log_loss_i = torch.sum(p_x_g_w, dim=1) # log_loss_i = self.criterion(p_x_g_w, target) - temperature = 0.00001 - log_loss_i = dists.Categorical(logits=p_x_g_w).log_prob(target) + log_p_x_g_w = dists.Categorical(logits=p_x_g_w).log_prob(target) log_p_w_total = 0 log_q_w_g_x_total = 0 @@ -57,7 +66,14 @@ def calculate(self, model: nn.Module, data: Tensor, target: Tensor, pmin: float else: log_q_w_g_x_total += log_q_w_g_x_weight.sum() + log_q_w_g_x_bias.sum() - log_loss_i = log_loss_i + (log_p_w_total.repeat(len(log_loss_i)) - log_q_w_g_x_total.repeat(len(log_loss_i)))*temperature + temperature_term = self._temperature * (log_p_w_total.repeat(len(log_p_x_g_w)) - log_q_w_g_x_total.repeat(len(log_p_x_g_w))) + log_loss_i = log_p_x_g_w + temperature_term + if i == self.n-1 and batch in [132,]: + logging.info( + f"Sample: {i}, Epoch: {epoch}, Batch: {batch}, Mean likelihood: {log_p_x_g_w.mean()}, Mean temperature term: {temperature_term.mean()}, Temperature: {temperature}") + if wandb_params is not None and wandb_params["log_wandb"]: + wandb.log({wandb_params["name_wandb"] + '/Mean likelihood': log_p_x_g_w.mean(), + wandb_params["name_wandb"] + '/Mean temperature term': temperature_term.mean()}) log_losses.append(log_loss_i) loss = - (torch.logsumexp(torch.stack(log_losses), dim=0) - np.log(self.n)).mean() # loss = -log_losses[0].mean() diff --git a/core/training.py b/core/training.py index 1311f0e..3c24d04 100644 --- a/core/training.py +++ b/core/training.py @@ -30,6 +30,11 @@ def train(model: nn.Module, if 'seed' in parameters: torch.manual_seed(parameters['seed']) + + loss = None + kl = None + objective_value = None + for epoch in range(parameters['epochs']): for i, (data, target) in tqdm(enumerate(train_loader)): data, target = data.to(device), target.to(device) @@ -54,7 +59,13 @@ def train(model: nn.Module, objective_value = objective.calculate(losses, kl, parameters['num_samples']) loss = sum(losses) / objective.n elif isinstance(objective, IWAEObjective): - objective_value = objective.calculate(model, data, target, pmin=parameters.get('pmin', None)) + objective_value = objective.calculate(model, + data, + target, + epoch=epoch, + batch=i, + pmin=parameters.get('pmin', None), + wandb_params=wandb_params) loss = criterion(model(data), target) kl = compute_kl(posterior, prior) else: diff --git a/scripts/generic_train.py b/scripts/generic_train.py index 359c252..ea4a666 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO) config = { - 'log_wandb': False, + 'log_wandb': True, 'mcsamples': 1000, 'pmin': 1e-5, 'sigma': 0.01, @@ -65,6 +65,7 @@ 'prior_objective': {'name': 'iwae', 'params': {'kl_penalty': 0.001, 'n': 10, + 'temperature': 1e-4, } }, 'posterior_objective': {'name': 'bbb', @@ -94,7 +95,7 @@ }, 'prior': { 'training': { - 'lr': 0.001, + 'lr': 0.01, 'momentum': 0.95, 'epochs': 100, 'seed': 1135, From b96dc52ef59dbda704ee9c53fdcc9720775f5ba8 Mon Sep 17 00:00:00 2001 From: yauhenii Date: Sat, 24 Aug 2024 17:08:31 +0200 Subject: [PATCH 7/9] fix missing temperature param --- core/objective/IWAEObjective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index f9a6ecf..2a91839 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -70,7 +70,7 @@ def calculate(self, log_loss_i = log_p_x_g_w + temperature_term if i == self.n-1 and batch in [132,]: logging.info( - f"Sample: {i}, Epoch: {epoch}, Batch: {batch}, Mean likelihood: {log_p_x_g_w.mean()}, Mean temperature term: {temperature_term.mean()}, Temperature: {temperature}") + f"Sample: {i}, Epoch: {epoch}, Batch: {batch}, Mean likelihood: {log_p_x_g_w.mean()}, Mean temperature term: {temperature_term.mean()}, Temperature: {self._temperature}") if wandb_params is not None and wandb_params["log_wandb"]: wandb.log({wandb_params["name_wandb"] + '/Mean likelihood': log_p_x_g_w.mean(), wandb_params["name_wandb"] + '/Mean temperature term': temperature_term.mean()}) From 38faa1b0fb6d8d48a051cadfcdbc80261529caa3 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Thu, 22 May 2025 00:05:07 +0200 Subject: [PATCH 8/9] adapted iwae --- core/objective/IWAEObjective.py | 169 ++++++++++++++++++-------------- scripts/generic_train.py | 4 +- 2 files changed, 98 insertions(+), 75 deletions(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 2a91839..5bf63fd 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -1,80 +1,103 @@ -import logging -import torch -import wandb -from torch import Tensor, nn -from typing import List, Dict -import numpy as np -import torch.distributions as dists +import logging, math +from typing import Dict, Optional + +import torch, torch.distributions as dists, wandb +from torch import nn, Tensor from core.model import bounded_call from core.layer.utils import get_torch_layers class IWAEObjective: - def __init__(self, kl_penalty: float, n: int, temperature: int) -> None: - self._kl_penalty = kl_penalty - self.n: int = n - self.criterion = torch.nn.NLLLoss() - self._temperature = temperature - - def calculate(self, - model: nn.Module, - data: Tensor, - target: Tensor, - epoch: int, - batch: int, - pmin: float = None, - wandb_params: Dict = None) -> Tensor: - - log_losses = [] - - for i in range(self.n): - - if pmin is not None: - p_x_g_w = bounded_call(model, data, pmin) - else: - p_x_g_w = model(data) - - # log_loss_i = torch.sum(p_x_g_w, dim=1) - # log_loss_i = self.criterion(p_x_g_w, target) - log_p_x_g_w = dists.Categorical(logits=p_x_g_w).log_prob(target) - - log_p_w_total = 0 - log_q_w_g_x_total = 0 - eps = 1e-6 - norm = False - - for l_name, l in get_torch_layers(model): - sampled_weight = l._sampled_weight - sampled_bias = l._sampled_bias - - log_p_w_weight = dists.Normal(l._prior_weight_dist.mu, l._prior_weight_dist.sigma + eps).log_prob(sampled_weight) - log_p_w_bias = dists.Normal(l._prior_bias_dist.mu, l._prior_bias_dist.sigma + eps).log_prob(sampled_bias) - - if norm: - log_p_w_total += (log_p_w_weight.sum() / torch.prod(torch.tensor(log_p_w_weight.shape)) - + log_p_w_bias.sum() / torch.prod(torch.tensor(log_p_w_bias.shape))) - else: - log_p_w_total += log_p_w_weight.sum() + log_p_w_bias.sum() - - log_q_w_g_x_weight = dists.Normal(l._weight_dist.mu, l._weight_dist.sigma + eps).log_prob(sampled_weight) - log_q_w_g_x_bias = dists.Normal(l._bias_dist.mu, l._bias_dist.sigma + eps).log_prob(sampled_bias) - - if norm: - log_q_w_g_x_total += (log_q_w_g_x_weight.sum() / torch.prod(torch.tensor(log_q_w_g_x_weight.shape)) - + log_q_w_g_x_bias.sum() / torch.prod(torch.tensor(log_q_w_g_x_bias.shape))) - else: - log_q_w_g_x_total += log_q_w_g_x_weight.sum() + log_q_w_g_x_bias.sum() - - temperature_term = self._temperature * (log_p_w_total.repeat(len(log_p_x_g_w)) - log_q_w_g_x_total.repeat(len(log_p_x_g_w))) - log_loss_i = log_p_x_g_w + temperature_term - if i == self.n-1 and batch in [132,]: - logging.info( - f"Sample: {i}, Epoch: {epoch}, Batch: {batch}, Mean likelihood: {log_p_x_g_w.mean()}, Mean temperature term: {temperature_term.mean()}, Temperature: {self._temperature}") - if wandb_params is not None and wandb_params["log_wandb"]: - wandb.log({wandb_params["name_wandb"] + '/Mean likelihood': log_p_x_g_w.mean(), - wandb_params["name_wandb"] + '/Mean temperature term': temperature_term.mean()}) - log_losses.append(log_loss_i) - loss = - (torch.logsumexp(torch.stack(log_losses), dim=0) - np.log(self.n)).mean() - # loss = -log_losses[0].mean() + def __init__(self, kl_penalty: float, n: int, temperature: float = 1.0) -> None: + self.k = n + self.kl_penalty = kl_penalty # usually 1 / |D| + self.temperature = temperature + + # -------- helpers to compute log p(w) and log q(w) ------------------- + @staticmethod + def _log_prior(model: nn.Module, eps: float = 1e-6) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + s = torch.zeros(1, device=device, dtype=dtype) + + for _, l in get_torch_layers(model): + s += dists.Normal(l._prior_weight_dist.mu, + l._prior_weight_dist.sigma + eps + ).log_prob(l._sampled_weight).sum() + s += dists.Normal(l._prior_bias_dist.mu, + l._prior_bias_dist.sigma + eps + ).log_prob(l._sampled_bias).sum() + return s + + @staticmethod + def _log_post(model: nn.Module, eps: float = 1e-6) -> Tensor: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + s = torch.zeros(1, device=device, dtype=dtype) + + for _, l in get_torch_layers(model): + s += dists.Normal(l._weight_dist.mu, + l._weight_dist.sigma + eps + ).log_prob(l._sampled_weight).sum() + s += dists.Normal(l._bias_dist.mu, + l._bias_dist.sigma + eps + ).log_prob(l._sampled_bias).sum() + return s + + # -------------------------------------------------------------------- + def calculate( + self, + model: nn.Module, + data: Tensor, + target: Tensor, + epoch: int, + batch_idx: int, + dataset_size: int, + pmin: Optional[float] = None, + wandb_params: Optional[Dict] = None, + ) -> Tensor: + + batch_size = data.size(0) + scale = dataset_size / batch_size # N / |B| + log_ws = [] # list[k] of scalars + + for l in range(self.k): + # sample w and compute log p(x|w) + logits = bounded_call(model, data, pmin) if pmin is not None else model(data) + log_px = dists.Categorical(logits=logits).log_prob(target) # (batch,) + log_lik = scale * log_px.sum() # scalar + + # global KL part + kl = (self._log_prior(model) - self._log_post(model)) * self.kl_penalty + log_w = log_lik + self.temperature * kl # scalar + log_ws.append(log_w) + + # -------------------- per-sample logging -------------------- + if wandb_params and wandb_params.get("log_wandb", False): + tag = wandb_params["name_wandb"] + wandb.log({ + f"{tag}/epoch": epoch, + f"{tag}/batch": batch_idx, + f"{tag}/sample": l, + f"{tag}/log_likelihood": log_lik.detach(), + f"{tag}/kl": kl.detach(), + f"{tag}/log_weight": log_w.detach(), + }) + + # ----------- PB-IWAE loss (one scalar) --------------------------- + log_ws_tensor = torch.stack(log_ws) # (k,) + loss = -(torch.logsumexp(log_ws_tensor, dim=0) - math.log(self.k)) + + # ----------- final logging -------------------------------------- + if wandb_params and wandb_params.get("log_wandb", False): + wandb.log({f"{wandb_params['name_wandb']}/iwae_loss": loss}) + + if batch_idx == 0: + logging.info( + f"[Epoch {epoch:03d} | Batch {batch_idx:04d}] " + f"IWAE-loss {loss.item():.4f} " + f"| mean log_px {(log_px.mean()).item():.4f} " + f"| KL {kl.item():.2f}" + ) return loss diff --git a/scripts/generic_train.py b/scripts/generic_train.py index ea4a666..998fc53 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO) config = { - 'log_wandb': True, + 'log_wandb': False, 'mcsamples': 1000, 'pmin': 1e-5, 'sigma': 0.01, @@ -62,7 +62,7 @@ # # 'delta': 0.025 # } # }, - 'prior_objective': {'name': 'iwae', + 'prior_objective': {'name': 'iwae',#iwae or naive iwae from objective factory 'params': {'kl_penalty': 0.001, 'n': 10, 'temperature': 1e-4, From a6d0726a1c713ead1f8fbf2fa4adf85d454e3e86 Mon Sep 17 00:00:00 2001 From: Misipuk Date: Thu, 22 May 2025 00:46:31 +0200 Subject: [PATCH 9/9] training works --- core/objective/IWAEObjective.py | 13 +++++++++++-- core/training.py | 22 +++++++++++++++------- scripts/generic_train.py | 6 +++--- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/core/objective/IWAEObjective.py b/core/objective/IWAEObjective.py index 5bf63fd..034faa1 100644 --- a/core/objective/IWAEObjective.py +++ b/core/objective/IWAEObjective.py @@ -58,19 +58,28 @@ def calculate( wandb_params: Optional[Dict] = None, ) -> Tensor: + batch_size = data.size(0) scale = dataset_size / batch_size # N / |B| log_ws = [] # list[k] of scalars + kl_pen = 1 / dataset_size + temp = 1.0 + for l in range(self.k): # sample w and compute log p(x|w) logits = bounded_call(model, data, pmin) if pmin is not None else model(data) + + if torch.isnan(logits).any() or torch.isinf(logits).any(): + logging.warning(f"NaN/Inf in logits at epoch {epoch}, batch {batch_idx}") + logits = torch.where(torch.isfinite(logits), logits, torch.zeros_like(logits)) + log_px = dists.Categorical(logits=logits).log_prob(target) # (batch,) log_lik = scale * log_px.sum() # scalar # global KL part - kl = (self._log_prior(model) - self._log_post(model)) * self.kl_penalty - log_w = log_lik + self.temperature * kl # scalar + kl = (self._log_prior(model) - self._log_post(model)) * kl_pen + log_w = log_lik + temp * kl # scalar log_ws.append(log_w) # -------------------- per-sample logging -------------------- diff --git a/core/training.py b/core/training.py index 3c24d04..8619650 100644 --- a/core/training.py +++ b/core/training.py @@ -24,10 +24,13 @@ def train(model: nn.Module, wandb_params: Dict = None, ): criterion = torch.nn.NLLLoss() - optimizer = torch.optim.SGD(model.parameters(), - lr=parameters['lr'], - momentum=parameters['momentum']) + #optimizer = torch.optim.SGD(model.parameters(), + # lr=parameters['lr'], + # momentum=parameters['momentum']) + #just to try + optimizer = torch.optim.Adam(model.parameters(), + lr=parameters['lr']) if 'seed' in parameters: torch.manual_seed(parameters['seed']) @@ -35,6 +38,8 @@ def train(model: nn.Module, kl = None objective_value = None + dataset_size = len(train_loader.dataset) # new added + for epoch in range(parameters['epochs']): for i, (data, target) in tqdm(enumerate(train_loader)): data, target = data.to(device), target.to(device) @@ -49,7 +54,7 @@ def train(model: nn.Module, objective_value = objective.calculate(loss, kl, parameters['num_samples']) elif isinstance(objective, AbstractSamplingObjective): losses = [] - for i in range(objective.n): + for j in range(objective.n): if 'pmin' in parameters: output = bounded_call(model, data, parameters['pmin']) else: @@ -63,14 +68,17 @@ def train(model: nn.Module, data, target, epoch=epoch, - batch=i, + batch_idx=i, + dataset_size= dataset_size, pmin=parameters.get('pmin', None), wandb_params=wandb_params) - loss = criterion(model(data), target) - kl = compute_kl(posterior, prior) + with torch.no_grad(): + loss = criterion(model(data), target) + kl = compute_kl(posterior, prior) else: raise ValueError(f'Invalid objective type: {type(objective)}') objective_value.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) optimizer.step() logging.info(f"Epoch: {epoch}, Objective: {objective_value}, Loss: {loss}, KL/n: {kl/parameters['num_samples']}") if wandb_params is not None and wandb_params["log_wandb"]: diff --git a/scripts/generic_train.py b/scripts/generic_train.py index 998fc53..6005f0d 100644 --- a/scripts/generic_train.py +++ b/scripts/generic_train.py @@ -95,15 +95,15 @@ }, 'prior': { 'training': { - 'lr': 0.01, - 'momentum': 0.95, + 'lr': 0.0005, + 'momentum': 0.9, 'epochs': 100, 'seed': 1135, } }, 'posterior': { 'training': { - 'lr': 0.001, + 'lr': 0.0001, 'momentum': 0.9, 'epochs': 1, 'seed': 1135,