diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index e5cfef44..87c7d75c 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -17,6 +17,11 @@ Abstract base classes :undoc-members: :exclude-members: forward +.. autoclass:: torchjd.aggregation.GeneralizedWeighting + :members: + :undoc-members: + :exclude-members: forward + .. toctree:: :hidden: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 8d1a9432..fdd4c2a1 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -55,7 +55,7 @@ from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) -from ._weighting_bases import Weighting +from ._weighting_bases import GeneralizedWeighting, Weighting try: from ._cagrad import CAGrad, CAGradWeighting diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index daa1bdca..154c7a30 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -52,3 +52,28 @@ def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutpu def forward(self, stat: _T) -> Tensor: return self.weighting(self.fn(stat)) + + +class GeneralizedWeighting(nn.Module, ABC): + r""" + Abstract base class for all weightings that operate on generalized Gramians. It has the role of + extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a + generalized Gramian of dimension + :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. + """ + + def __init__(self): + super().__init__() + + @abstractmethod + def forward(self, generalized_gramian: Tensor) -> Tensor: + """Computes the vector of weights from the input generalized Gramian.""" + + # Override to make type hints and documentation more specific + def __call__(self, generalized_gramian: Tensor) -> Tensor: + """ + Computes the tensor of weights from the input generalized Gramian and applies all registered + hooks. + """ + + return super().__call__(generalized_gramian) diff --git a/src/torchjd/aggregation/hierachical_weighting.py b/src/torchjd/aggregation/hierachical_weighting.py new file mode 100644 index 00000000..088d5dde --- /dev/null +++ b/src/torchjd/aggregation/hierachical_weighting.py @@ -0,0 +1,105 @@ +import torch +from torch import Tensor + +from ..autogram._gramian_utils import reshape_gramian +from ._weighting_bases import GeneralizedWeighting, Weighting + + +class HierarchicalWeighting(GeneralizedWeighting): + """ + Hierarchically reduces a generalized Gramian using a sequence of weighting functions. + + Applies multiple weightings in sequence to a generalized Gramian ``G`` of shape + ``[n₁, ..., nₖ, nₖ, ..., n₁]``. It first applies the initial weighting to the innermost diagonal + Gramians, contracts those dimensions to form a smaller generalized Gramian, and repeats the + process with subsequent weightings. The final returned weights are chosen so that contracting + the original Gramian directly with these weights produces the same quadratic form as applying + the reductions step by step. + + :param weightings: A list of weighting callables, one for each hierarchical reduction step. + """ + + def __init__(self, weightings: list[Weighting]): + super().__init__() + self.weightings = weightings + self.n_dims = len(weightings) + + def forward(self, generalized_gramian: Tensor) -> Tensor: + + assert len(self.weightings) * 2 == len(generalized_gramian.shape) # temporary + + weighting = self.weightings[0] + dim_size = generalized_gramian.shape[0] + reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) + weights = _compute_weights(weighting, reshaped_gramian) + generalized_gramian = _contract_gramian(reshaped_gramian, weights) + + for i in range(self.n_dim): + weighting = self.weightings[i] + dim_size = generalized_gramian.shape[i] + reshaped_gramian = reshape_gramian(generalized_gramian, [-1, dim_size]) + temp_weights = _compute_weights(weighting, reshaped_gramian) + generalized_gramian = _contract_gramian(reshaped_gramian, temp_weights) + weights = _scale_weights(weights, temp_weights) + + return weights + + +def _compute_weights(weighting: Weighting, generalized_gramian: Tensor) -> Tensor: + """ + Apply a weighting to each diagonal Gramian in a generalized Gramian. + + For a generalized Gramian ``G`` of shape ``[m, n, n, m]``, this extracts each diagonal Gramian + ``G[j, :, :, j]`` of shape ``[n, n]`` for ``j`` in ``[m]`` and applies the provided weighting. + The resulting weights are stacked into a tensor of shape ``[m, n]``. + + :param weighting: Callable that maps a Gramian of shape ``[n, n]`` to weights of shape ``[n]``. + :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` containing the generalized Gramian. + :returns: Tensor of shape ``[m, n]`` containing the computed weights for each diagonal Gramian. + """ + + weights = torch.zeros( + generalized_gramian[:2], device=generalized_gramian.device, dtype=generalized_gramian.dtype + ) + for i in range(generalized_gramian.shape[0]): + weights[i] = weighting(generalized_gramian[i, :, :, i]) + return weights + + +def _contract_gramian(generalized_gramian: Tensor, weights: Tensor) -> Tensor: + r""" + Compute a partial quadratic form by contracting a generalized Gramian with weight vectors on + both sides. + + Given a generalized Gramian ``G`` of shape ``[m, n, n, m]`` and weights ``w`` of shape + ``[m, n]``, this function computes a Gramian ``G'`` of shape ``[m, m]`` where + + .. math:: + + G'[i, j] = \sum_{k, l=1}^n w[i, k] G[i, k, l, j] w[j, l]. + + This can be viewed as forming a quadratic form with respect to the two innermost dimensions of + ``G``. + + :param generalized_gramian: Tensor of shape ``[m, n, n, m]`` representing the generalized + Gramian. + :param weights: Tensor of shape ``[m, n]`` containing weight vectors to contract with the + Gramian. + :returns: Tensor of shape ``[m, m]`` containing the contracted Gramian, i.e. the partial + quadratic form. + """ + left_product = torch.einsum("ij,ijkl->ikl", weights, generalized_gramian) + return torch.einsum("ij,ijl->il", weights, left_product) + + +def _scale_weights(weights: Tensor, scalings: Tensor) -> Tensor: + """ + Scale a tensor along its leading dimensions by broadcasting scaling factors. + + :param weights: Tensor of shape [n₁, ..., nₖ, nₖ₊₁, ..., nₚ]. + :param scalings: Tensor of shape [n₁, ..., nₖ] providing scaling factors for the leading + dimensions of ``weights``. + :returns: Tensor of the same shape as ``weights``, where each slice + ``weights[i₁, ..., iₖ, :, ..., :]`` is multiplied by ``scalings[i₁, ..., iₖ]``. + """ + return weights * scalings[(...,) + (None,) * (weights.ndim - scalings.ndim)] diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 5aabcdc9..9254aa86 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -2,11 +2,12 @@ from typing import cast import torch -from torch import Tensor, nn +from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator +from ._gramian_utils import movedim_gramian, reshape_gramian from ._module_hook_manager import ModuleHookManager _INCOMPATIBLE_MODULE_TYPES = ( @@ -57,6 +58,10 @@ class Engine: :param modules: A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian. + :param batched_dim: If the modules work with batches and process each batch element + independently, then many intermediary jacobians are sparse (block-diagonal), which allows + for a substancial memory optimization by backpropagating a squashed Jacobian instead. This + parameter indicates the batch dimension, if any. Defaults to None. .. admonition:: Example @@ -79,7 +84,7 @@ class Engine: >>> >>> criterion = MSELoss(reduction="none") >>> weighting = UPGradWeighting() - >>> engine = Engine(model.modules()) + >>> engine = Engine(model.modules(), batched_dim=0) >>> >>> for input, target in zip(inputs, targets): >>> output = model(input).squeeze(dim=1) # shape: [16] @@ -127,10 +132,17 @@ class Engine: `_ layers. """ - def __init__(self, modules: Iterable[nn.Module]): + def __init__( + self, + modules: Iterable[nn.Module], + batched_dim: int | None = None, + ): self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() - self._module_hook_manager = ModuleHookManager(self._target_edges, self._gramian_accumulator) + self._batched_dim = batched_dim + self._module_hook_manager = ModuleHookManager( + self._target_edges, self._gramian_accumulator, batched_dim is not None + ) self._hook_modules(modules) @@ -143,13 +155,15 @@ def _hook_modules(self, modules: Iterable[nn.Module]) -> None: self._check_module_is_compatible(module) self._module_hook_manager.hook_module(module) - @staticmethod - def _check_module_is_compatible(module: nn.Module) -> None: - if isinstance(module, _INCOMPATIBLE_MODULE_TYPES): + def _check_module_is_compatible(self, module: nn.Module) -> None: + if self._batched_dim is not None and isinstance(module, _INCOMPATIBLE_MODULE_TYPES): raise ValueError( f"Found a module of type {type(module)}, which is incompatible with the autogram " - f"engine. The incompatible module types are {_INCOMPATIBLE_MODULE_TYPES} (and their" - " subclasses)." + f"engine when `batched_dim` is not `None`. The incompatible module types are " + f"{_INCOMPATIBLE_MODULE_TYPES} (and their subclasses). The recommended fix is to " + f"replace incompatible layers by something else (e.g. BatchNorm by InstanceNorm), " + f"but if you really can't and performance not a priority, you may also just set" + f"`batch_dim=None` when creating the engine." ) if isinstance(module, _TRACK_RUNNING_STATS_MODULE_TYPES) and module.track_running_stats: @@ -161,17 +175,68 @@ def _check_module_is_compatible(module: nn.Module) -> None: ) def compute_gramian(self, output: Tensor) -> Tensor: - """ - Compute the Gramian of the Jacobian of ``output`` with respect the direct parameters of all - ``modules``. + r""" + Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of + all ``modules``. + + .. note:: + This function doesn't require ``output`` to be a vector. For example, if ``output`` is + a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the + parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of + parameters in the model. This is what we call a generalized Jacobian. The + corresponding Gramian :math:`G = J J^\top` will be of shape + :math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number + of dimensions of the returned generalized Gramian will always be twice that of the + ``output``. + + A few examples: + - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the + squared norm of the gradient of ``output``). + - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian + descent). + - 2D (matrix) ``output``: 4D Gramian (this can happen when combining IWRM and + multi-task learning, as each sample in the batch has one loss per task). + - etc. - :param output: The vector to differentiate. Must be a 1-D tensor. + :param output: The tensor of arbitrary shape to differentiate. The shape of the returned + Gramian depends on the shape of this output, as explained in the note above. """ - reshaped_output = output.reshape([-1]) - return self._compute_square_gramian(reshaped_output) + if self._batched_dim is not None: + # move batched dim to the end + ordered_output = output.movedim(self._batched_dim, -1) + ordered_shape = list(ordered_output.shape) + batch_size = ordered_shape[-1] + has_non_batched_dim = len(ordered_shape) > 1 + target_shape = [batch_size] + else: + ordered_output = output + ordered_shape = list(ordered_output.shape) + has_non_batched_dim = len(ordered_shape) > 0 + target_shape = [] + + if has_non_batched_dim: + target_shape = [-1] + target_shape + + reshaped_output = ordered_output.reshape(target_shape) + # There are four different cases for the shape of reshaped_output: + # - Not batched and not non-batched: scalar of shape [] + # - Batched only: vector of shape [batch_size] + # - Non-batched only: vector of shape [dim] + # - Batched and non-batched: matrix of shape [dim, batch_size] + + square_gramian = self._compute_square_gramian(reshaped_output, has_non_batched_dim) - def _compute_square_gramian(self, output: Tensor) -> Tensor: + unordered_gramian = reshape_gramian(square_gramian, ordered_shape) + + if self._batched_dim is not None: + gramian = movedim_gramian(unordered_gramian, [-1], [self._batched_dim]) + else: + gramian = unordered_gramian + + return gramian + + def _compute_square_gramian(self, output: Tensor, has_non_batched_dim: bool) -> Tensor: self._module_hook_manager.gramian_accumulation_phase = True leaf_targets = list(self._target_edges.get_leaf_edges({get_gradient_edge(output)})) @@ -184,7 +249,20 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: retain_graph=True, ) - _ = differentiation(torch.ones_like(output)) + if has_non_batched_dim: + # There is one non-batched dimension, it is the first one + non_batched_dim_len = output.shape[0] + jac_output_shape = [output.shape[0]] + list(output.shape) + + # Need to batch `grad_output` over the first dimension + jac_output = torch.zeros(jac_output_shape, device=output.device, dtype=output.dtype) + for i in range(non_batched_dim_len): + jac_output[i, i, ...] = 1 + + _ = vmap(differentiation)(jac_output) + else: + grad_output = torch.ones_like(output) + _ = differentiation(grad_output) # If the gramian were None, then leaf_targets would be empty, so autograd.grad would # have failed. So gramian is necessarily a valid Tensor here. diff --git a/src/torchjd/autogram/_gramian_utils.py b/src/torchjd/autogram/_gramian_utils.py new file mode 100644 index 00000000..4379b7e0 --- /dev/null +++ b/src/torchjd/autogram/_gramian_utils.py @@ -0,0 +1,69 @@ +from math import prod + +from torch import Tensor + + +def reshape_gramian(gramian: Tensor, shape: list[int]) -> Tensor: + """ + Reshapes a Gramian to a provided shape. As a Gramian is quadratic form, the reshape of the first + half of the target dimensions must be done from the left, while the reshape of the second half + must be done from the right. + :param gramian: Gramian to reshape + :param shape: First half of the target shape, the shape of the output is therefore + `shape + shape[::-1]`. + """ + + # Example: `gramian` of shape [24, 24] and `shape` of [4, 3, 2]: + # - The `unordered_gramian` will be of shape [4, 3, 2, 4, 3, 2] + # - The `last_dims` will be [3, 4, 5] and `last_dims[::-1]` will be [5, 4, 3] + # - The `reordered_gramian` will be of shape [4, 3, 2, 2, 3, 4] + + automatic_dimensions = [i for i in range(len(shape)) if shape[i] == -1] + if len(automatic_dimensions) == 1: + index = automatic_dimensions[0] + current_shape = gramian.shape[: len(gramian.shape) // 2] + numel = prod(current_shape) + specified_numel = -prod(shape) # shape[index] == -1, this is the product of all other dims + shape[index] = numel // specified_numel + + unordered_intput_gramian = _revert_last_dims(gramian) + unordered_output_gramian = unordered_intput_gramian.reshape(shape + shape) + reordered_output_gramian = _revert_last_dims(unordered_output_gramian) + return reordered_output_gramian + + +def movedim_gramian(gramian: Tensor, source: list[int], destination: list[int]) -> Tensor: + """ + Moves the dimensions of a Gramian from some source dimensions to destination dimensions. As a + Gramian is quadratic form, moving dimension must be done simultaneously on the first half of the + dimensions and on the second half of the dimensions reversed. + :param gramian: Gramian to reshape. + :param source: Source dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. Its elements should be unique. + :param destination: Destination dimensions, that should be in the range + [-gramian.ndim//2, gramian.ndim//2[. It should have the same size as `source`, and its + elements should be unique. + """ + + # Example: `gramian` of shape [4, 3, 2, 2, 3, 4], `source` of [-2, 2] and destination of [0, 1]: + # - `source_` will be [1, 2] and `destination_` will be [0, 1] + # - `mirrored_source` will be [1, 2, 4, 3] and `mirrored_destination` will be [0, 1, 5, 4] + # - The `moved_gramian` will be of shape [3, 2, 4, 4, 2, 3] + + # Map everything to the range [0, gramian.ndim//2[ + length = gramian.ndim // 2 + source_ = [i if 0 <= i else i + length for i in source] + destination_ = [i if 0 <= i else i + length for i in destination] + + # Mirror the source and destination and use the result to move the dimensions of the gramian + last_dim = gramian.ndim - 1 + mirrored_source = source_ + [last_dim - i for i in source_] + mirrored_destination = destination_ + [last_dim - i for i in destination_] + moved_gramian = gramian.movedim(mirrored_source, mirrored_destination) + return moved_gramian + + +def _revert_last_dims(generalized_gramian: Tensor) -> Tensor: + input_ndim = len(generalized_gramian.shape) // 2 + last_dims = [input_ndim + i for i in range(input_ndim)] + return generalized_gramian.movedim(last_dims, last_dims[::-1]) diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 07381184..7660bc04 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -3,12 +3,12 @@ import torch from torch import Tensor, nn from torch.autograd.graph import get_gradient_edge -from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_unflatten +from torch.utils._pytree import PyTree, TreeSpec, tree_flatten, tree_map, tree_unflatten from torch.utils.hooks import RemovableHandle as TorchRemovableHandle from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator -from ._vjp import get_functional_vjp +from ._vjp import AutogradVJP, FunctionalVJP # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -32,9 +32,11 @@ def __init__( self, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, + has_batch_dim: bool, ): self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator + self._has_batch_dim = has_batch_dim self.gramian_accumulation_phase = False self._handles: list[TorchRemovableHandle] = [] @@ -50,7 +52,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: if self.gramian_accumulation_phase: return output - flat_outputs, tree_spec = tree_flatten(output) + flat_outputs, output_spec = tree_flatten(output) if not any(isinstance(t, Tensor) for t in flat_outputs): # This can happen only if a module returns no Tensor, for instance some niche usage @@ -68,7 +70,7 @@ def module_hook(_: nn.Module, args: PyTree, output: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self._target_edges.register(get_gradient_edge(flat_outputs[index])) - return self._apply_jacobian_accumulator(module, args, tree_spec, flat_outputs) + return self._apply_jacobian_accumulator(module, args, output_spec, flat_outputs) handle = module.register_forward_hook(module_hook) self._handles.append(handle) @@ -77,23 +79,46 @@ def _apply_jacobian_accumulator( self, module: nn.Module, args: PyTree, - tree_spec: TreeSpec, + output_spec: TreeSpec, flat_outputs: list[Tensor], ) -> PyTree: - vjp = torch.vmap(get_functional_vjp(module)) + + if self._has_batch_dim: + vjp = torch.vmap(FunctionalVJP(module)) + else: + vjp = AutogradVJP(module, flat_outputs) class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward(*flat_grad_outputs: Tensor) -> None: - grad_outputs = tree_unflatten(flat_grad_outputs, tree_spec) - jacobians = vjp(grad_outputs, args) - self._gramian_accumulator.accumulate_path_jacobians( - { - module.get_parameter(param_name): jacobian - for param_name, jacobian in jacobians.items() - } - ) + # There is no non-batched dimension + grad_outputs = tree_unflatten(flat_grad_outputs, output_spec) + generalized_jacobians = vjp(grad_outputs, args) + path_jacobians = AccumulateJacobian._make_path_jacobians(generalized_jacobians) + self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) + + @staticmethod + def vmap(_, in_dims: PyTree, *flat_jac_outputs: Tensor) -> tuple[None, None]: + # There is a non-batched dimension + jac_outputs = tree_unflatten(flat_jac_outputs, output_spec) + # We do not vmap over the args for the non-batched dimension + in_dims = (tree_unflatten(in_dims, output_spec), tree_map(lambda _: None, args)) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + path_jacobians = AccumulateJacobian._make_path_jacobians(generalized_jacobians) + self._gramian_accumulator.accumulate_path_jacobians(path_jacobians) + return None, None + + @staticmethod + def _make_path_jacobians( + generalized_jacobians: dict[str, Tensor], + ) -> dict[Tensor, Tensor]: + path_jacobians: dict[Tensor, Tensor] = {} + for param_name, generalized_jacobian in generalized_jacobians.items(): + key = module.get_parameter(param_name) + jacobian = generalized_jacobian.reshape([-1] + list(key.shape)) + path_jacobians[key] = jacobian + return path_jacobians @staticmethod def setup_context(*_): @@ -127,4 +152,4 @@ def backward(ctx, *flat_grad_outputs: Tensor): return flat_grad_outputs - return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), tree_spec) + return tree_unflatten(JacobianAccumulator.apply(*flat_outputs), output_spec) diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index b4bea046..4a93f69a 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -1,9 +1,10 @@ -from collections.abc import Callable +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence import torch from torch import Tensor, nn from torch.nn import Parameter -from torch.utils._pytree import PyTree, tree_map_only +from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten # Note about import from protected _pytree module: # PyTorch maintainers plan to make pytree public (see @@ -14,19 +15,42 @@ # still support older versions of PyTorch where pytree is protected). -def get_functional_vjp(module: nn.Module) -> Callable[[PyTree, PyTree], dict[str, Tensor]]: +class VJP(ABC): """ - Create a VJP function for a module's forward pass with respect to its parameters. The returned - function takes both the input and the cotangents that can be vmaped jointly in both terms to - avoid providing to block diagonal jacobians. + Represents an abstract VJP function for a module's forward pass with respect to its parameters. :params module: The module to differentiate. - :returns: VJP function that takes cotangents and inputs and returns dictionary of names of + """ + + def __init__(self, module: nn.Module): + self.module = module + named_parameters = dict(module.named_parameters(recurse=False)) + self.trainable_params = {k: v for k, v in named_parameters.items() if v.requires_grad} + self.frozen_params = {k: v for k, v in named_parameters.items() if not v.requires_grad} + + @abstractmethod + def __call__(self, grad_outputs: PyTree, inputs: PyTree) -> dict[str, Tensor]: + """ + VJP function that takes cotangents and inputs and returns dictionary of names of parameters (as given by `module.named_parameters.keys()`) to gradients of the parameters for the given cotangents at the given inputs. + """ + + +class FunctionalVJP(VJP): """ + Represents a VJP function for a module's forward pass with respect to its parameters using the + func api. The __call__ function takes both the inputs and the cotangents that can be vmaped + jointly in both terms to avoid providing to block diagonal jacobians. The disadvantage of using + this method is that it computes the forward phase. - def get_vjp(grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: + :params module: The module to differentiate. + """ + + def __init__(self, module: nn.Module): + super().__init__(module) + + def __call__(self, grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. # nn.Flatten) do not work equivalently if they're provided with a batch or with @@ -39,30 +63,52 @@ def get_vjp(grad_outputs_j: PyTree, inputs_j: PyTree) -> dict[str, Tensor]: # primals (tuple), here the functional has a single primal which is # dict(module.named_parameters()). We therefore take the 0'th element to obtain # the dict of gradients w.r.t. the module's named_parameters. - return _vjp_from_module(module, inputs_j)(grad_outputs_j)[0] + return self._vjp_from_module(inputs_j)(grad_outputs_j)[0] - return get_vjp + def _vjp_from_module(self, inputs: PyTree) -> Callable[[PyTree], tuple[dict[str, Tensor]]]: + """ + Create a VJP function for a module's forward pass with respect to its parameters. + Returns a function that computes vector-Jacobian products for the module's parameters given + fixed inputs. Only parameters with requires_grad=True are included in the differentiation. -def _vjp_from_module( - module: nn.Module, inputs: PyTree -) -> Callable[[PyTree], tuple[dict[str, Tensor]]]: - """ - Create a VJP function for a module's forward pass with respect to its parameters. + :param inputs: Fixed inputs to the module for the VJP computation. + :returns: VJP function that takes cotangents and returns parameter gradients. + """ - Returns a function that computes vector-Jacobian products for the module's parameters given - fixed inputs. Only parameters with requires_grad=True are included in the differentiation. + def functional_model_call(primals: dict[str, Parameter]) -> Tensor: + all_state = { + **primals, + **dict(self.module.named_buffers()), + **self.frozen_params, + } + return torch.func.functional_call(self.module, all_state, inputs) - :param module: The module to differentiate. - :param inputs: Fixed inputs to the module for the VJP computation. - :returns: VJP function that takes cotangents and returns parameter gradients. + return torch.func.vjp(functional_model_call, self.trainable_params)[1] + + +class AutogradVJP(VJP): + """ + Represents a VJP function for a module's forward pass with respect to its parameters using the + autograd engine. The __call__ function takes both the inputs and the cotangents but ignores the + inputs. The main advantage of using this method is that it doesn't require computing the forward + phase. """ - named_params = dict(module.named_parameters(recurse=False)) - requires_grad_named_params = {k: v for k, v in named_params.items() if v.requires_grad} - no_requires_grad_named_params = {k: v for k, v in named_params.items() if not v.requires_grad} - def functional_model_call(primals: dict[str, Parameter]) -> Tensor: - all_state = {**primals, **dict(module.named_buffers()), **no_requires_grad_named_params} - return torch.func.functional_call(module, all_state, inputs) + def __init__(self, module: nn.Module, outputs: Sequence[Tensor]): + super().__init__(module) + self.outputs = outputs + self.mask = [output.requires_grad for output in self.outputs] + self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - return torch.func.vjp(functional_model_call, requires_grad_named_params)[1] + def __call__(self, grad_outputs: PyTree, _: PyTree) -> dict[str, Tensor]: + flat_grad_outputs = tree_flatten(grad_outputs)[0] + grads = torch.autograd.grad( + [t for t, requires_grad in zip(self.outputs, self.mask) if requires_grad], + self.flat_trainable_params, + [t for t, requires_grad in zip(flat_grad_outputs, self.mask) if requires_grad], + retain_graph=True, + allow_unused=True, + materialize_grads=True, + ) + return tree_unflatten(grads, self.param_spec) diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index be55a1ae..e0e3117f 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -18,7 +18,7 @@ def test_engine(): criterion = MSELoss(reduction="none") weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), batched_dim=0) for input, target in zip(inputs, targets): output = model(input).squeeze(dim=1) # shape: [16] diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 4f2b9351..d7be637e 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -136,7 +136,7 @@ def test_autogram(): params = model.parameters() optimizer = SGD(params, lr=0.1) weighting = UPGradWeighting() - engine = Engine(model.modules()) + engine = Engine(model.modules(), batched_dim=0) for x, y in zip(X, Y): y_hat = model(x).squeeze(dim=1) # shape: [16] @@ -325,7 +325,7 @@ def test_partial_jd(): # Create the autogram engine that will compute the Gramian of the # Jacobian with respect to the two last Linear layers' parameters. - engine = Engine(model[2:].modules()) + engine = Engine(model[2:].modules(), batched_dim=0) params = model.parameters() optimizer = SGD(params, lr=0.1) diff --git a/tests/speed/autogram/grad_vs_jac_vs_gram.py b/tests/speed/autogram/grad_vs_jac_vs_gram.py index 881d0321..7080373c 100644 --- a/tests/speed/autogram/grad_vs_jac_vs_gram.py +++ b/tests/speed/autogram/grad_vs_jac_vs_gram.py @@ -96,7 +96,7 @@ def post_fn(): print(autojac_times) print() - engine = Engine(model.modules()) + engine = Engine(model.modules(), batched_dim=0) autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs)) print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}") print(autogram_times) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index ba4ec06a..83093a22 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -1,10 +1,13 @@ from itertools import combinations +from math import prod import pytest import torch from pytest import mark, param -from torch import nn +from torch import Tensor, nn +from torch.nn import Linear from torch.optim import SGD +from torch.testing import assert_close from unit.conftest import DEVICE from utils.architectures import ( AlexNet, @@ -55,10 +58,11 @@ autojac_forward_backward, make_mse_loss_fn, ) -from utils.tensors import make_tensors +from utils.tensors import make_tensors, ones_, randn_, zeros_ -from torchjd.aggregation import UPGrad, UPGradWeighting +from torchjd.aggregation import GeneralizedWeighting, UPGrad, UPGradWeighting from torchjd.autogram._engine import Engine +from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet from torchjd.autojac._transform._aggregate import _Matrixify @@ -107,9 +111,11 @@ @mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +@mark.parametrize("batched_engine", [False, True]) def test_equivalence_autojac_autogram( architecture: type[ShapedModule], batch_size: int, + batched_engine: bool, ): """ Tests that the autogram engine gives the same results as the autojac engine on IWRM for several @@ -129,7 +135,7 @@ def test_equivalence_autojac_autogram( torch.manual_seed(0) model_autogram = architecture().to(device=DEVICE) - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), batched_dim=0 if batched_engine else None) optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7) optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7) @@ -195,7 +201,7 @@ def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], bat model_autogram = architecture().to(device=DEVICE) # Hook modules and verify that we're equivalent to autojac when using the engine - engine = Engine(model_autogram.modules()) + engine = Engine(model_autogram.modules(), batched_dim=0) torch.manual_seed(0) # Fix randomness for random models autogram_forward_backward(model_autogram, engine, W, input, loss_fn) grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None} @@ -270,7 +276,7 @@ def test_partial_autogram(gramian_module_names: set[str]): expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None} model.zero_grad() - engine = Engine(gramian_modules) + engine = Engine(gramian_modules, batched_dim=0) output = model(input) losses = loss_fn(output) @@ -289,4 +295,234 @@ def test_incompatible_modules(architecture: type[nn.Module]): model = architecture().to(device=DEVICE) with pytest.raises(ValueError): - _ = Engine(model.modules()) + _ = Engine(model.modules(), batched_dim=0) + + +@mark.parametrize("shape", [(1, 3), (7, 15), (27, 15)]) +@mark.parametrize("batch_size", [None, 3, 16, 32]) +@mark.parametrize("reduce_output", [True, False]) +def test_gramian_is_correct(shape: tuple[int, int], batch_size: int, reduce_output: bool): + """ + Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected + Gramian. + """ + + is_batched = batch_size is not None + + if is_batched: + batched_dim = 0 + input_dim = [batch_size, shape[0]] + else: + batched_dim = None + input_dim = [shape[0]] + + model = Linear(shape[0], shape[1]).to(device=DEVICE) + engine = Engine([model], batched_dim=batched_dim) + + input = randn_(input_dim) + output = model(input) + if reduce_output: + output = torch.sum(output, dim=-1) + + assert output.ndim == int(not reduce_output) + int(is_batched) + + gramian = engine.compute_gramian(output) + + # compute the expected gramian + output_shape = list(output.shape) + initial_jacobian = torch.diag(ones_(output.numel())).reshape(output_shape + output_shape) + + if reduce_output: + initial_jacobian = initial_jacobian.unsqueeze(-1).repeat( + ([1] * initial_jacobian.ndim) + [shape[1]] + ) + if not is_batched: + initial_jacobian = initial_jacobian.unsqueeze(-2) + input = input.unsqueeze(0) + + assert initial_jacobian.shape[-2] == (1 if batch_size is None else batch_size) + assert initial_jacobian.shape[-1] == shape[1] + assert initial_jacobian.shape[:-2] == output.shape + + assert input.shape[0] == (1 if batch_size is None else batch_size) + assert input.shape[1] == shape[0] + + # If k is the batch_size (1 if None) and n the input size and m the output size, then + # - input has shape `[k, n]` + # - initial_jacobian has shape `output.shape + `[k, m]` + + # The partial (batched) jacobian of outputs w.r.t. weights is of shape `[k, m, m, n]`, whe + # multiplied (along 2 dims) by initial_jacobian this yields the jacobian of the weights of shape + # `output.shape + [m, n]`. The partial jacobian itself is block diagonal with diagonal defined + # by `partial_weight_jacobian[i, j, j] = input[i]` (other elements are 0). + + partial_weight_jacobian = zeros_([input.shape[0], shape[1], shape[1], shape[0]]) + for j in range(shape[1]): + partial_weight_jacobian[:, j, j, :] = input + weight_jacobian = torch.tensordot( + initial_jacobian, partial_weight_jacobian, dims=([-2, -1], [0, 1]) + ) + weight_gramian = torch.tensordot(weight_jacobian, weight_jacobian, dims=([-2, -1], [-2, -1])) + if weight_gramian.ndim == 4: + weight_gramian = weight_gramian.movedim((-2), (-1)) + + # The partial (batched) jacobian of outputs w.r.t. bias is of shape `[k, m, m]`, when multiplied + # (along 2 dims) by initial_jacobian this yields the jacobian of the bias of shape + # `output.shape + [m]`. The partial jacobian itself is block diagonal with diagonal defined by + # `partial_bias_jacobian[i, j, j] = 1` (other elements are 0). + partial_bias_jacobian = zeros_([input.shape[0], shape[1], shape[1]]) + for j in range(shape[1]): + partial_bias_jacobian[:, j, j] = 1.0 + bias_jacobian = torch.tensordot( + initial_jacobian, partial_bias_jacobian, dims=([-2, -1], [0, 1]) + ) + bias_gramian = torch.tensordot(bias_jacobian, bias_jacobian, dims=([-1], [-1])) + if bias_gramian.ndim == 4: + bias_gramian = bias_gramian.movedim(-2, -1) + + expected_gramian = weight_gramian + bias_gramian + + assert_close(gramian, expected_gramian) + + +@mark.parametrize( + "shape", + [ + [1, 2, 2, 3], + [7, 3, 2, 5], + [27, 6, 7], + [3, 2, 1, 1], + [3, 2, 1], + [3, 2], + [3], + [1, 1, 1, 1], + [1, 1, 1], + [1, 1], + [1], + ], +) +def test_reshape_equivariance(shape: list[int]): + """ + Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape + the `output` to some `shape`, then the result is the same as reshaping the Gramian to the + corresponding shape. + """ + + input_size = shape[0] + output_size = prod(shape[1:]) + + model = Linear(input_size, output_size).to(device=DEVICE) + engine1 = Engine([model]) + engine2 = Engine([model]) + + input = randn_([input_size]) + output = model(input) + + reshaped_output = output.reshape(shape[1:]) + + gramian = engine1.compute_gramian(output) + reshaped_gramian = engine2.compute_gramian(reshaped_output) + + expected_reshaped_gramian = reshape_gramian(gramian, shape[1:]) + + assert_close(reshaped_gramian, expected_reshaped_gramian) + + +@mark.parametrize( + ["shape", "source", "destination"], + [ + ([50, 2, 2, 3], [0, 2], [1, 0]), + ([60, 3, 2, 5], [1], [2]), + ([30, 6, 7], [0, 1], [1, 0]), + ([3, 2], [0], [0]), + ([3], [], []), + ([3, 2, 1], [1, 0], [0, 1]), + ([4, 3, 2], [], []), + ([1, 1, 1], [1, 0], [0, 1]), + ], +) +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): + """ + Test equivariance of `compute_gramian` under movedim operation. More precisely, if we movedim + the `output` on some dimensions, then the result is the same as movedim on the Gramian with the + corresponding dimensions. + """ + + input_size = shape[0] + output_size = prod(shape[1:]) + + model = Linear(input_size, output_size).to(device=DEVICE) + engine1 = Engine([model]) + engine2 = Engine([model]) + + input = randn_([input_size]) + output = model(input).reshape(shape[1:]) + + moved_output = output.movedim(source, destination) + + gramian = engine1.compute_gramian(output) + moved_gramian = engine2.compute_gramian(moved_output) + + expected_moved_gramian = movedim_gramian(gramian, source, destination) + + assert_close(moved_gramian, expected_moved_gramian) + + +@mark.parametrize( + ["shape", "batched_dim"], + [ + ([2, 5, 3, 2], 2), + ([3, 2, 5], 1), + ([6, 3], 0), + ([4, 3, 2], 1), + ([1, 1, 1], 0), + ([1, 1, 1], 1), + ([1, 1, 1], 2), + ([1, 1], 0), + ([1], 0), + ([4, 3, 1], 2), + ], +) +def test_batched_non_batched_equivalence(shape: list[int], batched_dim: int): + """ + Tests that for a vector with some batched dimensions, the gramian is the same if we use the + appropriate `batched_dims` or if we don't use any. + """ + + non_batched_shape = [shape[i] for i in range(len(shape)) if i != batched_dim] + input_size = prod(non_batched_shape) + batch_size = shape[batched_dim] + output_size = input_size + + model = Linear(input_size, output_size).to(device=DEVICE) + engine1 = Engine([model], batched_dim=batched_dim) + engine2 = Engine([model]) + + input = randn_([batch_size, input_size]) + output = model(input) + output = output.reshape([batch_size] + non_batched_shape) + output = output.movedim(0, batched_dim) + + gramian1 = engine1.compute_gramian(output) + gramian2 = engine2.compute_gramian(output) + + assert_close(gramian1, gramian2) + + +class FakeGeneralizedWeighting(GeneralizedWeighting): + """ + Fake GeneralizedWeighting flattening the Gramian and using UPGradWeighting on it. Could be + removed when we implement a proper FlatteningGeneralizedWeighting.""" + + def __init__(self): + super().__init__() + self.weighting = UPGradWeighting() + + def forward(self, generalized_gramian: Tensor) -> Tensor: + k = generalized_gramian.ndim // 2 + shape = generalized_gramian.shape[:k] + m = prod(shape) + square_gramian = reshape_gramian(generalized_gramian, [m]) + weights_vector = self.weighting(square_gramian) + weights = weights_vector.reshape(shape) + return weights diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py new file mode 100644 index 00000000..3ea8aa3d --- /dev/null +++ b/tests/unit/autogram/test_gramian_utils.py @@ -0,0 +1,99 @@ +from math import prod + +import torch +from pytest import mark +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import rand_ + +from torchjd.autogram._gramian_utils import movedim_gramian, reshape_gramian + + +def compute_quadratic_form(generalized_gramian: Tensor, x: Tensor) -> Tensor: + """ + Compute the quadratic form x^T G x when the provided generalized Gramian and x may have multiple + dimensions. + """ + indices = list(range(x.ndim)) + linear_form = torch.tensordot(x, generalized_gramian, dims=(indices, indices)) + return torch.tensordot(linear_form, x, dims=(indices[::-1], indices)) + + +@mark.parametrize( + "shape", + [ + [50, 2, 2, 3], + [60, 3, 2, 5], + [30, 6, 7], + [4, 3, 1], + [4, 1, 1], + [1, 1, 1], + [4, 1], + [4], + [1, 1], + [1], + ], +) +def test_quadratic_form_invariance_to_reshape(shape: list[int]): + """ + When reshaping a Gramian, we expect it to represent the same quadratic form that now applies to + reshaped inputs. So the mapping x -> x^T G x commutes with reshaping x, G and then computing the + corresponding quadratic form. + """ + + flat_dim = prod(shape[1:]) + iterations = 20 + + matrix = rand_([flat_dim, shape[0]]) + gramian = matrix @ matrix.T + reshaped_gramian = reshape_gramian(gramian, shape[1:]) + + for _ in range(iterations): + vector = rand_([flat_dim]) + reshaped_vector = vector.reshape(shape[1:]) + + quadratic_form = vector @ gramian @ vector + reshaped_quadratic_form = compute_quadratic_form(reshaped_gramian, reshaped_vector) + + assert_close(reshaped_quadratic_form, quadratic_form) + + +@mark.parametrize( + ["shape", "source", "destination"], + [ + ([50, 2, 2, 3], [0, 2], [1, 0]), + ([60, 3, 2, 5], [1], [2]), + ([30, 6, 7], [0, 1], [1, 0]), + ([4, 3, 1], [0, 1], [1, 0]), + ([4, 1, 1], [1], [0]), + ([1, 1, 1], [], []), + ([4, 1], [0], [0]), + ([4], [], []), + ([1, 1], [], []), + ([1], [], []), + ], +) +def test_quadratic_form_invariance_to_movedim( + shape: list[int], source: list[int], destination: list[int] +): + """ + When moving dims on a Gramian, we expect it to represent the same quadratic form that now + applies to inputs with moved dims. So the mapping x -> x^T G x commutes with moving dims x, G + and then computing the quadratic form with those. + """ + + flat_dim = prod(shape[1:]) + iterations = 20 + + matrix = rand_([flat_dim, shape[0]]) + gramian = reshape_gramian(matrix @ matrix.T, shape[1:]) + moved_gramian = movedim_gramian(gramian, source, destination) + + for _ in range(iterations): + vector = rand_(shape[1:]) + moved_vector = vector.movedim(source, destination) + + quadratic_form = compute_quadratic_form(gramian, vector) + moved_quadratic_form = compute_quadratic_form(moved_gramian, moved_vector) + + assert_close(moved_quadratic_form, quadratic_form)