Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ class Engine:
this, but for example `BatchNorm
<https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html>`_ does not (it
computes some average and standard deviation over the elements of the batch).
* Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of
tensors, or any nesting of those structures), but each of these tensors must be batched on
its first dimension. `Transformers
* Their inputs and outputs can be anything, but each input tensor and each output tensor
must be batched on its first dimension. `Transformers
<https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`_ and `RNNs
<https://docs.pytorch.org/docs/stable/generated/torch.nn.RNN.html>`_ are thus not
supported yet. This is only an implementation issue, so it should be fixed soon (please
Expand Down
27 changes: 18 additions & 9 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
self.gramian_accumulator = gramian_accumulator
self.has_batch_dim = has_batch_dim

def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree:
if self.gramian_accumulation_phase:
return outputs

Expand Down Expand Up @@ -129,7 +129,14 @@ def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(rg_outputs[index]))

vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs)
vjp: VJP
if self.has_batch_dim:
rg_outputs_in_dims = (0,) * len(rg_outputs)
args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
in_dims = (rg_outputs_in_dims, args_in_dims)
vjp = FunctionalVJP(module, in_dims)
else:
vjp = AutogradVJP(module, rg_outputs)

autograd_fn_rg_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
Expand Down Expand Up @@ -161,15 +168,15 @@ class JacobianAccumulator(torch.autograd.Function):
def forward(
gramian_accumulation_phase: BoolRef,
vjp: VJP,
args: PyTree,
args: tuple[PyTree, ...],
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*rg_tensors: Tensor,
) -> tuple[Tensor, ...]:
return tuple(t.detach() for t in rg_tensors)

# For Python version > 3.10, the type of `inputs` should become
# tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
# tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
@staticmethod
def setup_context(
ctx,
Expand All @@ -183,7 +190,9 @@ def setup_context(
ctx.module = inputs[4]

@staticmethod
def backward(ctx, *grad_outputs: Tensor):
def backward(ctx, *grad_outputs: Tensor) -> tuple:
# Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]]

if not ctx.gramian_accumulation_phase:
return None, None, None, None, None, *grad_outputs

Expand All @@ -203,7 +212,7 @@ class AccumulateJacobian(torch.autograd.Function):
@staticmethod
def forward(
vjp: VJP,
args: PyTree,
args: tuple[PyTree, ...],
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*grad_outputs: Tensor,
Expand All @@ -216,9 +225,9 @@ def forward(
@staticmethod
def vmap(
_,
in_dims: PyTree,
in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]]
vjp: VJP,
args: PyTree,
args: tuple[PyTree, ...],
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*jac_outputs: Tensor,
Expand All @@ -244,5 +253,5 @@ def _make_path_jacobians(
return path_jacobians

@staticmethod
def setup_context(*_):
def setup_context(*_) -> None:
pass
18 changes: 12 additions & 6 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ class VJP(ABC):
"""Represents an abstract VJP function."""

@abstractmethod
def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]:
def __call__(
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
) -> dict[str, Tensor]:
"""
Computes and returns the dictionary of parameter names to their gradients for the given
grad_outputs (cotangents) and at the given inputs.
Expand Down Expand Up @@ -52,15 +54,17 @@ class FunctionalVJP(ModuleVJP):
every module, and it requires to have an extra forward pass to create the vjp function.
"""

def __init__(self, module: nn.Module):
def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]):
super().__init__(module)
self.vmapped_vjp = torch.vmap(self._call_on_one_instance)
self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)

def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]:
def __call__(
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
) -> dict[str, Tensor]:
return self.vmapped_vjp(grad_outputs, args)

def _call_on_one_instance(
self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree
self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...]
) -> dict[str, Tensor]:
# Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
# "batch" of 1 activation (or grad_output). This is because some layers (e.g.
Expand Down Expand Up @@ -103,7 +107,9 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
self.rg_outputs = rg_outputs
self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)

def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]:
def __call__(
self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...]
) -> dict[str, Tensor]:
grads = torch.autograd.grad(
self.rg_outputs,
self.flat_trainable_params,
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
WithBuffered,
WithDropout,
WithModuleTrackingRunningStats,
WithModuleWithHybridPyTreeArg,
WithModuleWithStringArg,
WithModuleWithStringOutput,
WithNoTensorOutput,
Expand Down Expand Up @@ -109,6 +110,8 @@
(Ndim3Output, 32),
(Ndim4Output, 32),
(WithDropout, 32),
(WithModuleWithStringArg, 32),
(WithModuleWithHybridPyTreeArg, 32),
(WithModuleWithStringOutput, 32),
(FreeParam, 32),
(NoFreeParam, 32),
Expand Down Expand Up @@ -167,7 +170,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
Randomness,
WithModuleTrackingRunningStats,
param(WithRNN, marks=mark.xfail_if_cuda),
WithModuleWithStringArg,
],
)
@mark.parametrize("batch_size", [1, 3, 32])
Expand Down
50 changes: 50 additions & 0 deletions tests/utils/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,56 @@ def forward(self, input: Tensor) -> Tensor:
return self.with_string_arg("two", input)


class WithModuleWithHybridPyTreeArg(ShapedModule):
    """
    Model containing a module that has a PyTree argument containing a mix of tensor and non-tensor
    leaves (floats used as scalar coefficients, and strings that must be ignored).
    """

    INPUT_SHAPES = (10,)
    OUTPUT_SHAPES = (3,)

    class WithHybridPyTreeArg(nn.Module):
        """Inner module whose forward takes a hybrid PyTree of tensors, floats and strings."""

        def __init__(self):
            super().__init__()
            # One projection matrix per tensor leaf of the input PyTree; each maps its
            # leaf's feature dimension (3, 4, 5, 6 respectively) to the 3-dim output.
            self.m0 = nn.Parameter(torch.randn(3, 3))
            self.m1 = nn.Parameter(torch.randn(4, 3))
            self.m2 = nn.Parameter(torch.randn(5, 3))
            self.m3 = nn.Parameter(torch.randn(6, 3))

        def forward(self, input: PyTree) -> Tensor:
            # Tensor leaves, extracted from various nesting depths of the PyTree.
            t0 = input["one"][0][0]
            t1 = input["one"][0][1]
            t2 = input["one"][1]
            t3 = input["two"]

            # Float leaves, used as scalar coefficients (string leaves are ignored).
            c0 = input["one"][0][3]
            c1 = input["one"][0][4][0]
            c2 = input["one"][2]
            c3 = input["three"]

            return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 18)
        # Renamed from `with_string_arg` (copy-paste leftover from WithModuleWithStringArg):
        # this attribute holds the hybrid-PyTree module, not a string-arg module.
        self.with_hybrid_pytree_arg = self.WithHybridPyTreeArg()

    def forward(self, input: Tensor) -> Tensor:
        input = self.linear(input)

        # Split the 18 linear features into the four tensor leaves (3 + 4 + 5 + 6 columns).
        t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18]

        # Hybrid PyTree mixing tensor leaves with float and string (ignored) leaves,
        # mirroring the indices read by WithHybridPyTreeArg.forward.
        tree = {
            "zero": "unused",
            "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"],
            "two": t3,
            "three": 0.5,
        }

        return self.with_hybrid_pytree_arg(tree)


class WithModuleWithStringOutput(ShapedModule):
"""Model containing a module that has a string output."""

Expand Down
Loading