diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index be7104f0..90cd2582 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -120,9 +120,8 @@ class Engine: this, but for example `BatchNorm `_ does not (it computes some average and standard deviation over the elements of the batch). - * Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of - tensors, or any nesting of those structures), but each of these tensors must be batched on - its first dimension. `Transformers + * Their inputs and outputs can be anything, but each input tensor and each output tensor + must be batched on its first dimension. `Transformers `_ and `RNNs `_ are thus not supported yet. This is only an implementation issue, so it should be fixed soon (please diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 3cc4cf6b..d0b78a79 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,7 +101,7 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree: + def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree: if self.gramian_accumulation_phase: return outputs @@ -129,7 +129,14 @@ def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree: index = cast(int, preference.argmin().item()) self.target_edges.register(get_gradient_edge(rg_outputs[index])) - vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs) + vjp: VJP + if self.has_batch_dim: + rg_outputs_in_dims = (0,) * len(rg_outputs) + args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + in_dims = (rg_outputs_in_dims, args_in_dims) + vjp = FunctionalVJP(module, in_dims) + else: + vjp = AutogradVJP(module, 
rg_outputs) autograd_fn_rg_outputs = JacobianAccumulator.apply( self.gramian_accumulation_phase, @@ -161,7 +168,7 @@ class JacobianAccumulator(torch.autograd.Function): def forward( gramian_accumulation_phase: BoolRef, vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *rg_tensors: Tensor, @@ -169,7 +176,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -183,7 +190,9 @@ def setup_context( ctx.module = inputs[4] @staticmethod - def backward(ctx, *grad_outputs: Tensor): + def backward(ctx, *grad_outputs: Tensor) -> tuple: + # Return type for Python > 3.10: tuple[None, None, None, None, None, *tuple[Tensor, ...]] + if not ctx.gramian_accumulation_phase: return None, None, None, None, None, *grad_outputs @@ -203,7 +212,7 @@ class AccumulateJacobian(torch.autograd.Function): @staticmethod def forward( vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *grad_outputs: Tensor, @@ -216,9 +225,9 @@ def forward( @staticmethod def vmap( _, - in_dims: PyTree, + in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]] vjp: VJP, - args: PyTree, + args: tuple[PyTree, ...], gramian_accumulator: GramianAccumulator, module: nn.Module, *jac_outputs: Tensor, @@ -244,5 +253,5 @@ def _make_path_jacobians( return path_jacobians @staticmethod - def setup_context(*_): + def setup_context(*_) -> None: pass diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 4e6c4985..442b7fe9 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -19,7 +19,9 @@ class VJP(ABC): """Represents an abstract VJP 
function.""" @abstractmethod - def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + ) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given grad_outputs (cotangents) and at the given inputs. @@ -52,15 +54,17 @@ class FunctionalVJP(ModuleVJP): every module, and it requires to have an extra forward pass to create the vjp function. """ - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): super().__init__(module) - self.vmapped_vjp = torch.vmap(self._call_on_one_instance) + self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) - def __call__(self, grad_outputs: tuple[Tensor, ...], args: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + ) -> dict[str, Tensor]: return self.vmapped_vjp(grad_outputs, args) def _call_on_one_instance( - self, grad_outputs_j: tuple[Tensor, ...], args_j: PyTree + self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...] ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. @@ -103,7 +107,9 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): self.rg_outputs = rg_outputs self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) - def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]: + def __call__( + self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...] 
+ ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, self.flat_trainable_params, diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 74084f3a..fb37863c 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -51,6 +51,7 @@ WithBuffered, WithDropout, WithModuleTrackingRunningStats, + WithModuleWithHybridPyTreeArg, WithModuleWithStringArg, WithModuleWithStringOutput, WithNoTensorOutput, @@ -109,6 +110,8 @@ (Ndim3Output, 32), (Ndim4Output, 32), (WithDropout, 32), + (WithModuleWithStringArg, 32), + (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 32), (FreeParam, 32), (NoFreeParam, 32), @@ -167,7 +170,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc Randomness, WithModuleTrackingRunningStats, param(WithRNN, marks=mark.xfail_if_cuda), - WithModuleWithStringArg, ], ) @mark.parametrize("batch_size", [1, 3, 32]) diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index 073420d2..ef4933f5 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -797,6 +797,56 @@ def forward(self, input: Tensor) -> Tensor: return self.with_string_arg("two", input) +class WithModuleWithHybridPyTreeArg(ShapedModule): + """ + Model containing a module that has a PyTree argument containing a mix of tensor and non-tensor + leaves. 
+ """ + + INPUT_SHAPES = (10,) + OUTPUT_SHAPES = (3,) + + class WithHybridPyTreeArg(nn.Module): + def __init__(self): + super().__init__() + self.m0 = nn.Parameter(torch.randn(3, 3)) + self.m1 = nn.Parameter(torch.randn(4, 3)) + self.m2 = nn.Parameter(torch.randn(5, 3)) + self.m3 = nn.Parameter(torch.randn(6, 3)) + + def forward(self, input: PyTree) -> Tensor: + t0 = input["one"][0][0] + t1 = input["one"][0][1] + t2 = input["one"][1] + t3 = input["two"] + + c0 = input["one"][0][3] + c1 = input["one"][0][4][0] + c2 = input["one"][2] + c3 = input["three"] + + return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 + + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 18) + self.with_string_arg = self.WithHybridPyTreeArg() + + def forward(self, input: Tensor) -> Tensor: + input = self.linear(input) + + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] + + tree = { + "zero": "unused", + "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"], + "two": t3, + "three": 0.5, + } + + return self.with_string_arg(tree) + + class WithModuleWithStringOutput(ShapedModule): """Model containing a module that has a string output."""