Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,36 @@ class Engine:
memory-efficient, and thus typically faster, to use the Gramian-based approach.

.. warning::
When providing a non-None ``batch_dim``, all provided modules must respect a few
conditions:
When providing a non-None ``batch_dim``, all provided modules must respect a few conditions:

* They should treat the elements of the batch independently. Most common layers respect
this, but for example `BatchNorm
<https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
computes some average and standard deviation over the elements of the batch).
* Their inputs and outputs can be anything, but each input tensor and each output tensor
must be batched on its first dimension. `Transformers
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ and `RNNs
<https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are thus not
supported yet. This is only an implementation issue, so it should be fixed soon (please
open an issue if you need extra focus on this).
must be batched on its first dimension. When available (e.g. in `Transformers
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_,
`MultiheadAttention
<https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html>`_,
etc.), the ``batch_first`` parameter has to be set to ``True``. Also, this makes `RNNs
<https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ not supported yet
because their hidden state is batched on dimension 1 even if ``batch_first`` is ``True``.
* They should not perform in-place operations on tensors (for instance you should not use
``track_running_stats=True`` in normalization layers).
* They should not have side effects during the forward pass (since their forward pass will
be called twice, the side effects could be different from what's expected).
* If they have some randomness during the forward pass, they should not have direct
trainable parameters. It is, however, perfectly fine for random modules to have child
modules that have trainable parameters, so if you have a random module with some direct
parameters, a simple fix is to wrap these parameters into a child module.
trainable parameters. For this reason,
`Transformers
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_, which use a
dropout function (rather than a `Dropout
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layer) in a
module with some trainable parameters, have to be used with
``dropout=0.0``. Note that `Dropout
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ layers are
entirely supported and should be preferred. It is also perfectly fine for random modules
to have child modules that have trainable parameters, so if you have a random module with
some direct parameters, a simple fix is to wrap these parameters into a child module.

If you're building your own architecture, respecting those criteria should be quite easy.
However, if you're using an existing architecture, you may have to modify it to make it
Expand Down
5 changes: 3 additions & 2 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._module_utils import get_used_params
from ._vjp import VJP, AutogradVJP, FunctionalVJP

# Note about import from protected _pytree module:
Expand Down Expand Up @@ -125,8 +126,8 @@ def __call__(
# require grad
return outputs

requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
self.gramian_accumulator.track_parameter_paths(requires_grad_params)
rg_params, _ = get_used_params(module)
self.gramian_accumulator.track_parameter_paths(rg_params.values())

# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
Expand Down
47 changes: 47 additions & 0 deletions src/torchjd/autogram/_module_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from torch import nn


def get_used_params(module: nn.Module) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]:
    """
    Collects the parameters that a module uses. In practice this is every direct parameter of the
    module (which may include some unused ones) merged with the indirectly used parameters that we
    know how to detect (exotic modules may have some that we miss).

    :return: A pair ``(rg_params, frozen_params)``: the first dict maps names to parameters that
        require grad, the second to parameters that do not.
    """

    rg_direct, frozen_direct = _get_direct_params(module)
    rg_indirect, frozen_indirect = _get_indirectly_used_params(module)

    # Dict union keeps the direct entries first, then adds the indirect ones.
    return rg_direct | rg_indirect, frozen_direct | frozen_indirect


def _get_direct_params(
    module: nn.Module, prefix: str = ""
) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]:
    # Split the module's direct (non-recursive) parameters into those that require grad and those
    # that are frozen, keyed by their name with ``prefix`` prepended.
    named = [(prefix + name, param) for name, param in module.named_parameters(recurse=False)]
    rg_params = {name: param for name, param in named if param.requires_grad}
    frozen_params = {name: param for name, param in named if not param.requires_grad}

    return rg_params, frozen_params


def _get_indirectly_used_params(
    module: nn.Module,
) -> tuple[dict[str, nn.Parameter], dict[str, nn.Parameter]]:
    # MHA uses the params of its out_proj child itself. We also verify that the MHA still has an
    # out_proj attribute, since this may change in the future (which would remove the need for any
    # MHA-specific code here). See the status of
    # https://github.com/pytorch/pytorch/pull/126568
    is_mha_with_out_proj = isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj")
    if is_mha_with_out_proj:
        return _get_direct_params(module.out_proj, prefix="out_proj.")

    # Default: no indirectly used parameters are known for this module type.
    return {}, {}
21 changes: 8 additions & 13 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from torch.nn import Parameter
from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten

from torchjd.autogram._module_utils import get_used_params

# Note about import from protected _pytree module:
# PyTorch maintainers plan to make pytree public (see
# https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).
Expand Down Expand Up @@ -37,14 +39,7 @@ class ModuleVJP(VJP, ABC):

def __init__(self, module: nn.Module):
self.module = module
self.trainable_params = dict[str, Parameter]()
self.frozen_params = dict[str, Parameter]()

for name, param in module.named_parameters(recurse=False):
if param.requires_grad:
self.trainable_params[name] = param
else:
self.frozen_params[name] = param
self.rg_params, self.frozen_params = get_used_params(module)


class FunctionalVJP(ModuleVJP):
Expand Down Expand Up @@ -78,9 +73,9 @@ def _call_on_one_instance(
kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]

def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
def functional_model_call(rg_params: dict[str, Parameter]) -> list[Tensor]:
all_state = {
**trainable_params,
**rg_params,
**dict(self.module.named_buffers()),
**self.frozen_params,
}
Expand All @@ -89,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor
rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
return rg_outputs

vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]
vjp_func = torch.func.vjp(functional_model_call, self.rg_params)[1]

# vjp_func is a function that computes the vjp w.r.t. to the primals (tuple). Here the
# functional has a single primal which is dict(module.named_parameters()). We therefore take
Expand All @@ -109,14 +104,14 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
super().__init__(module)

self.rg_outputs = rg_outputs
self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)
self.flat_rg_params, self.param_spec = tree_flatten(self.rg_params)

def __call__(
self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree]
) -> dict[str, Tensor]:
grads = torch.autograd.grad(
self.rg_outputs,
self.flat_trainable_params,
self.flat_rg_params,
grad_outputs,
retain_graph=True,
allow_unused=True,
Expand Down
2 changes: 2 additions & 0 deletions tests/speed/autogram/grad_vs_jac_vs_gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NoFreeParam,
ShapedModule,
SqueezeNet,
WithTransformerLarge,
)
from utils.forward_backwards import (
autograd_forward_backward,
Expand All @@ -27,6 +28,7 @@
from torchjd.autogram import Engine

PARAMETRIZATIONS = [
(WithTransformerLarge, 8),
(FreeParam, 64),
(NoFreeParam, 64),
(Cifar10Model, 64),
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@
WithModuleWithStringArg,
WithModuleWithStringKwarg,
WithModuleWithStringOutput,
WithMultiHeadAttention,
WithNoTensorOutput,
WithRNN,
WithSideEffect,
WithSomeFrozenModule,
WithTransformer,
WithTransformerLarge,
)
from utils.dict_assertions import assert_tensor_dicts_are_close
from utils.forward_backwards import (
Expand Down Expand Up @@ -118,6 +121,8 @@
(WithModuleWithStringOutput, 32),
(WithModuleWithStringKwarg, 32),
(WithModuleWithHybridPyTreeKwarg, 32),
(WithMultiHeadAttention, 32),
param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")),
(FreeParam, 32),
(NoFreeParam, 32),
param(Cifar10Model, 16, marks=mark.slow),
Expand All @@ -126,6 +131,11 @@
param(GroupNormMobileNetV3Small, 3, marks=mark.slow),
param(SqueezeNet, 8, marks=mark.slow),
param(InstanceNormMobileNetV2, 2, marks=mark.slow),
param(
WithTransformerLarge,
8,
marks=[mark.slow, mark.filterwarnings("ignore:There is a performance drop")],
),
]


Expand Down Expand Up @@ -565,3 +575,42 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
gramian2 = engine2.compute_gramian(output)

assert_close(gramian1, gramian2)


@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int):
    """
    Same as test_batched_non_batched_equivalence but on real architectures, and thus only between
    batch_dim=0 and batch_dim=None.

    If this test passes for some architecture while test_compute_gramian does not, a likely cause
    is that get_used_params does not work for one of the architecture's modules.
    """

    # Two identically-initialized copies of the model, one per engine configuration.
    torch.manual_seed(0)
    batched_model = architecture().to(device=DEVICE)
    torch.manual_seed(0)
    unbatched_model = architecture().to(device=DEVICE)

    batched_engine = Engine(batched_model.modules(), batch_dim=0)
    unbatched_engine = Engine(unbatched_model.modules(), batch_dim=None)

    inputs = make_tensors(batch_size, architecture.INPUT_SHAPES)
    targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES)
    loss_fn = make_mse_loss_fn(targets)

    torch.random.manual_seed(0)  # Fix randomness for random models
    batched_losses = reduce_to_vector(loss_fn(batched_model(inputs)))

    torch.random.manual_seed(0)  # Fix randomness for random models
    unbatched_losses = reduce_to_vector(loss_fn(unbatched_model(inputs)))

    batched_gramian = batched_engine.compute_gramian(batched_losses)
    unbatched_gramian = unbatched_engine.compute_gramian(unbatched_losses)

    assert_close(batched_gramian, unbatched_gramian, rtol=1e-4, atol=1e-5)
64 changes: 64 additions & 0 deletions tests/utils/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,70 @@ def forward(self, input: Tensor) -> Tensor:
return output


class WithMultiHeadAttention(ShapedModule):
    """Module whose forward pass goes through a single MultiheadAttention layer."""

    INPUT_SHAPES = ((20, 8), (10, 9), (10, 11))
    OUTPUT_SHAPES = (20, 8)

    def __init__(self):
        super().__init__()
        # Distinct key/value embedding dims exercise the kdim/vdim code path of MHA.
        self.mha = nn.MultiheadAttention(
            embed_dim=8, num_heads=2, dropout=0.0, batch_first=True, kdim=9, vdim=11
        )

    def forward(self, input: tuple[Tensor, Tensor, Tensor]) -> Tensor:
        # input is (query, key, value); MHA returns (attn_output, attn_weights) and only the
        # output tensor is kept.
        return self.mha(*input)[0]


class WithTransformer(ShapedModule):
    """Module containing a small Transformer."""

    INPUT_SHAPES = ((10, 8), (20, 8))
    OUTPUT_SHAPES = (20, 8)

    def __init__(self):
        super().__init__()
        # NOTE(review): dropout is set to 0.0 — presumably because randomness in modules with
        # direct trainable parameters is unsupported by the engine; confirm against the docs.
        self.transformer = nn.Transformer(
            d_model=8,
            nhead=2,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=32,
            batch_first=True,
            dropout=0.0,
        )

    def forward(self, input: tuple[Tensor, Tensor]) -> Tensor:
        # input is (src, tgt).
        return self.transformer(*input)


class WithTransformerLarge(ShapedModule):
    """Module containing a default-sized (large) Transformer."""

    INPUT_SHAPES = ((10, 512), (20, 512))
    OUTPUT_SHAPES = (20, 512)

    def __init__(self):
        super().__init__()
        # All size hyper-parameters are left at their defaults (d_model=512, etc.); only
        # batching convention and dropout are overridden.
        self.transformer = nn.Transformer(batch_first=True, dropout=0.0)

    def forward(self, input: tuple[Tensor, Tensor]) -> Tensor:
        # input is (src, tgt).
        return self.transformer(*input)


class FreeParam(ShapedModule):
"""
Model that contains a free (i.e. not contained in a submodule) parameter, that is used at the
Expand Down
Loading