Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions deepobs/pytorch/testproblems/fmnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,19 @@ def get_batch_loss_and_accuracy_func(
return_forward_func (bool): If ``True``, the call also returns a function that calculates the loss on the current batch. Can be used if you need to access the forward path twice.
Returns:
float, float, (callable): loss and accuracy of the model on the current batch. If ``return_forward_func`` is ``True`` it also returns the function that calculates the loss on the current batch.
"""
"""
inputs, _ = self._get_next_batch()
inputs = inputs.to(self._device)

def forward_func():
# in evaluation phase is no gradient needed
# TODO move phase distinction to evaluate in runner?
if self.phase in ["train_eval", "test", "valid"]:
with torch.no_grad():
outputs, means, std_devs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(
outputs, inputs, means, std_devs
)
else:
with self._get_forward_context(self.phase)():
outputs, means, std_devs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(
outputs, inputs, means, std_devs
)

accuracy = 0

if add_regularization_if_available:
regularizer_loss = self.get_regularization_loss()
else:
Expand Down
9 changes: 1 addition & 8 deletions deepobs/pytorch/testproblems/mnist_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,7 @@ def get_batch_loss_and_accuracy_func(
inputs = inputs.to(self._device)

def forward_func():
# in evaluation phase is no gradient needed
if self.phase in ["train_eval", "test", "valid"]:
with torch.no_grad():
outputs, means, std_devs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(
outputs, inputs, means, std_devs
)
else:
with self._get_forward_context(self.phase)():
outputs, means, std_devs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(
outputs, inputs, means, std_devs
Expand Down
41 changes: 8 additions & 33 deletions deepobs/pytorch/testproblems/quadratic_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
from torch import Tensor

from ..datasets.quadratic import quadratic
from .testproblem import UnregularizedTestproblem
Expand Down Expand Up @@ -115,41 +116,15 @@ def _make_hessian(eigvals_small=90, eigvals_large=10):
Hessian = np.matmul(np.transpose(R), np.matmul(D, R))
return torch.from_numpy(Hessian).to(torch.float32)

def get_batch_loss_and_accuracy_func(
self, reduction="mean", add_regularization_if_available=True
):
"""Get new batch and create forward function that calculates loss on that batch.
@staticmethod
def _compute_accuracy(outputs: Tensor, labels: Tensor) -> float:
"""Return zero as model accuracy (non-existent for this regression task).

Args:
reduction (str): The reduction that is used for returning the loss.
Can be 'mean', 'sum' or 'none' in which case each individual loss
in the mini-batch is returned as a tensor.
add_regularization_if_available (bool): If true, regularization is added to the loss.
outputs: Model predictions.
labels: Ground truth.

Returns:
callable: The function that calculates the loss/accuracy on the current batch.
0
"""
inputs, labels = self._get_next_batch()
inputs = inputs.to(self._device)
labels = labels.to(self._device)

def forward_func():
# in evaluation phase is no gradient needed
if self.phase in ["train_eval", "test", "valid"]:
with torch.no_grad():
outputs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(outputs, labels)
else:
outputs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(outputs, labels)

accuracy = 0.0

if add_regularization_if_available:
regularizer_loss = self.get_regularization_loss()
else:
regularizer_loss = torch.tensor(0.0, device=torch.device(self._device))

return loss + regularizer_loss, accuracy

return forward_func
return 0.0
77 changes: 56 additions & 21 deletions deepobs/pytorch/testproblems/testproblem.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
"""Base class for DeepOBS test problems."""
import abc
from contextlib import nullcontext
from typing import Callable, ContextManager, Tuple, Type

import torch
from torch import Tensor, no_grad

from .. import config

Expand Down Expand Up @@ -105,8 +108,8 @@ def _get_next_batch(self):
return batch

def get_batch_loss_and_accuracy_func(
self, reduction="mean", add_regularization_if_available=True
):
self, reduction: str = "mean", add_regularization_if_available: bool = True
) -> Callable[[], Tuple[Tensor, float]]:
"""Get new batch and create forward function.

Creates the forward function that calculates loss and accuracy (if available)
Expand All @@ -115,38 +118,30 @@ def get_batch_loss_and_accuracy_func(
this method accordingly.

Args:
reduction (str): The reduction that is used for returning the loss.
reduction: The reduction that is used for returning the loss.
Can be 'mean', 'sum' or 'none' in which case each individual loss
in the mini-batch is returned as a tensor.
add_regularization_if_available (bool): If true, regularization is
add_regularization_if_available: If true, regularization is
added to the loss.

Returns:
callable: The function that calculates the loss/accuracy on the
current batch.
Function that calculates the loss/accuracy on the current batch.

"""
inputs, labels = self._get_next_batch()
inputs = inputs.to(self._device)
labels = labels.to(self._device)

def forward_func():
correct = 0.0
total = 0.0
def forward_func() -> Tuple[Tensor, float]:
"""Evaluate the forward pass on a fixed mini-batch.

# in evaluation phase is no gradient needed
if self.phase in ["train_eval", "test", "valid"]:
with torch.no_grad():
outputs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(outputs, labels)
else:
Returns:
Mini-batch loss and model accuracy.
"""
with self._get_forward_context(self.phase)():
outputs = self.net(inputs)
loss = self.loss_function(reduction=reduction)(outputs, labels)

_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

accuracy = correct / total
accuracy = self._compute_accuracy(outputs, labels)

if add_regularization_if_available:
regularizer_loss = self.get_regularization_loss()
Expand All @@ -157,6 +152,46 @@ def forward_func():

return forward_func

@staticmethod
def _get_forward_context(phase: str) -> Type[ContextManager]:
"""Get autodiff context for the forward pass (no gradients in evaluation phase).

Args:
phase: Phase of the forward pass.

Returns:
Context manager class for the forward pass.

Raises:
ValueError: If ``phase`` is invalid.
"""
if phase in ["train_eval", "test", "valid"]:
context = no_grad
elif phase == "train":
context = nullcontext
else:
raise ValueError(f"Unknown phase: {phase}")

return context

@staticmethod
def _compute_accuracy(outputs: Tensor, labels: Tensor) -> float:
"""Compute the model accuracy.

Args:
outputs: Model predictions.
labels: Ground truth.

Returns:
Model accuracy
"""
_, predictions = outputs.max(dim=1)
correct = (predictions == labels).sum()
total = labels.numel()
accuracy = correct / total

return float(accuracy.item())

def get_batch_loss_and_accuracy(
self, reduction="mean", add_regularization_if_available=True
):
Expand Down