diff --git a/deepobs/pytorch/testproblems/fmnist_vae.py b/deepobs/pytorch/testproblems/fmnist_vae.py index ee656ba2..dd9bb722 100644 --- a/deepobs/pytorch/testproblems/fmnist_vae.py +++ b/deepobs/pytorch/testproblems/fmnist_vae.py @@ -79,26 +79,19 @@ def get_batch_loss_and_accuracy_func( return_forward_func (bool): If ``True``, the call also returns a function that calculates the loss on the current batch. Can be used if you need to access the forward path twice. Returns: float, float, (callable): loss and accuracy of the model on the current batch. If ``return_forward_func`` is ``True`` it also returns the function that calculates the loss on the current batch. - """ + """ inputs, _ = self._get_next_batch() inputs = inputs.to(self._device) def forward_func(): - # in evaluation phase is no gradient needed - # TODO move phase distinction to evaluate in runner? - if self.phase in ["train_eval", "test", "valid"]: - with torch.no_grad(): - outputs, means, std_devs = self.net(inputs) - loss = self.loss_function(reduction=reduction)( - outputs, inputs, means, std_devs - ) - else: + with self._get_forward_context(self.phase)(): outputs, means, std_devs = self.net(inputs) loss = self.loss_function(reduction=reduction)( outputs, inputs, means, std_devs ) accuracy = 0 + if add_regularization_if_available: regularizer_loss = self.get_regularization_loss() else: diff --git a/deepobs/pytorch/testproblems/mnist_vae.py b/deepobs/pytorch/testproblems/mnist_vae.py index 2b75d6a5..902c3f76 100644 --- a/deepobs/pytorch/testproblems/mnist_vae.py +++ b/deepobs/pytorch/testproblems/mnist_vae.py @@ -84,14 +84,7 @@ def get_batch_loss_and_accuracy_func( inputs = inputs.to(self._device) def forward_func(): - # in evaluation phase is no gradient needed - if self.phase in ["train_eval", "test", "valid"]: - with torch.no_grad(): - outputs, means, std_devs = self.net(inputs) - loss = self.loss_function(reduction=reduction)( - outputs, inputs, means, std_devs - ) - else: + with self._get_forward_context(self.phase)(): outputs, means, std_devs = self.net(inputs) loss = self.loss_function(reduction=reduction)( outputs, inputs, means, std_devs diff --git a/deepobs/pytorch/testproblems/quadratic_deep.py b/deepobs/pytorch/testproblems/quadratic_deep.py index 5a9a6dd5..c5bd7f6f 100644 --- a/deepobs/pytorch/testproblems/quadratic_deep.py +++ b/deepobs/pytorch/testproblems/quadratic_deep.py @@ -3,6 +3,7 @@ import numpy as np import torch +from torch import Tensor from ..datasets.quadratic import quadratic from .testproblem import UnregularizedTestproblem @@ -115,41 +116,15 @@ def _make_hessian(eigvals_small=90, eigvals_large=10): Hessian = np.matmul(np.transpose(R), np.matmul(D, R)) return torch.from_numpy(Hessian).to(torch.float32) - def get_batch_loss_and_accuracy_func( - self, reduction="mean", add_regularization_if_available=True - ): - """Get new batch and create forward function that calculates loss on that batch. + @staticmethod + def _compute_accuracy(outputs: Tensor, labels: Tensor) -> float: + """Return zero as model accuracy (non-existent for this regression task). Args: - reduction (str): The reduction that is used for returning the loss. - Can be 'mean', 'sum' or 'none' in which case each indivual loss - in the mini-batch is returned as a tensor. - add_regularization_if_available (bool): If true, regularization is added to the loss. + outputs: Model predictions. + labels: Ground truth. Returns: - callable: The function that calculates the loss/accuracy on the current batch. + 0 """ - inputs, labels = self._get_next_batch() - inputs = inputs.to(self._device) - labels = labels.to(self._device) - - def forward_func(): - # in evaluation phase is no gradient needed - if self.phase in ["train_eval", "test", "valid"]: - with torch.no_grad(): - outputs = self.net(inputs) - loss = self.loss_function(reduction=reduction)(outputs, labels) - else: - outputs = self.net(inputs) - loss = self.loss_function(reduction=reduction)(outputs, labels) - - accuracy = 0.0 - - if add_regularization_if_available: - regularizer_loss = self.get_regularization_loss() - else: - regularizer_loss = torch.tensor(0.0, device=torch.device(self._device)) - - return loss + regularizer_loss, accuracy - - return forward_func + return 0.0 diff --git a/deepobs/pytorch/testproblems/testproblem.py b/deepobs/pytorch/testproblems/testproblem.py index b661a50e..5ca790d2 100644 --- a/deepobs/pytorch/testproblems/testproblem.py +++ b/deepobs/pytorch/testproblems/testproblem.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- """Base class for DeepOBS test problems.""" import abc +from contextlib import nullcontext +from typing import Callable, ContextManager, Tuple, Type import torch +from torch import Tensor, no_grad from .. import config @@ -105,8 +108,8 @@ def _get_next_batch(self): return batch def get_batch_loss_and_accuracy_func( - self, reduction="mean", add_regularization_if_available=True - ): + self, reduction: str = "mean", add_regularization_if_available: bool = True + ) -> Callable[[], Tuple[Tensor, float]]: """Get new batch and create forward function. Creates the forward function that calculates loss and accuracy (if available) @@ -115,38 +118,30 @@ def get_batch_loss_and_accuracy_func( this method accordingly. Args: - reduction (str): The reduction that is used for returning the loss. + reduction: The reduction that is used for returning the loss. Can be 'mean', 'sum' or 'none' in which case each indivual loss in the mini-batch is returned as a tensor. - add_regularization_if_available (bool): If true, regularization is + add_regularization_if_available: If true, regularization is added to the loss. Returns: - callable: The function that calculates the loss/accuracy on the - current batch. + Function that calculates the loss/accuracy on the current batch. + """ inputs, labels = self._get_next_batch() inputs = inputs.to(self._device) labels = labels.to(self._device) - def forward_func(): - correct = 0.0 - total = 0.0 + def forward_func() -> Tuple[Tensor, float]: + """Evaluate the forward pass on a fixed mini-batch. - # in evaluation phase is no gradient needed - if self.phase in ["train_eval", "test", "valid"]: - with torch.no_grad(): - outputs = self.net(inputs) - loss = self.loss_function(reduction=reduction)(outputs, labels) - else: + Returns: + Mini-batch loss and model accuracy. + """ + with self._get_forward_context(self.phase)(): outputs = self.net(inputs) loss = self.loss_function(reduction=reduction)(outputs, labels) - - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels).sum().item() - - accuracy = correct / total + accuracy = self._compute_accuracy(outputs, labels) if add_regularization_if_available: regularizer_loss = self.get_regularization_loss() @@ -157,6 +152,46 @@ def forward_func(): return forward_func + @staticmethod + def _get_forward_context(phase: str) -> Type[ContextManager]: + """Get autodiff context for the forward pass (no gradients in evaluation phase). + + Args: + phase: Phase of the forward pass. + + Returns: + Context manager class for the forward pass. + + Raises: + ValueError: If ``phase`` is invalid. + """ + if phase in ["train_eval", "test", "valid"]: + context = no_grad + elif phase == "train": + context = nullcontext + else: + raise ValueError(f"Unknown phase: {phase}") + + return context + + @staticmethod + def _compute_accuracy(outputs: Tensor, labels: Tensor) -> float: + """Compute the model accuracy. + + Args: + outputs: Model predictions. + labels: Ground truth. + + Returns: + Model accuracy + """ + _, predictions = outputs.max(dim=1) + correct = (predictions == labels).sum() + total = labels.numel() + accuracy = correct / total + + return float(accuracy.item()) + def get_batch_loss_and_accuracy( self, reduction="mean", add_regularization_if_available=True ):