diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index d0b78a79..9f4285ee 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -65,7 +65,7 @@ def hook_module(self, module: nn.Module) -> None: self._gramian_accumulator, self._has_batch_dim, ) - self._handles.append(module.register_forward_hook(hook)) + self._handles.append(module.register_forward_hook(hook, with_kwargs=True)) @staticmethod def remove_hooks(handles: list[TorchRemovableHandle]) -> None: @@ -101,7 +101,13 @@ def __init__( self.gramian_accumulator = gramian_accumulator self.has_batch_dim = has_batch_dim - def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree: + def __call__( + self, + module: nn.Module, + args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], + outputs: PyTree, + ) -> PyTree: if self.gramian_accumulation_phase: return outputs @@ -131,9 +137,10 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) vjp: VJP if self.has_batch_dim: - rg_outputs_in_dims = (0,) * len(rg_outputs) - args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) - in_dims = (rg_outputs_in_dims, args_in_dims) + rg_output_in_dims = (0,) * len(rg_outputs) + arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args) + kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs) + in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims) vjp = FunctionalVJP(module, in_dims) else: vjp = AutogradVJP(module, rg_outputs) @@ -142,6 +149,7 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) self.gramian_accumulation_phase, vjp, args, + kwargs, self.gramian_accumulator, module, *rg_outputs, @@ -169,6 +177,7 @@ def forward( gramian_accumulation_phase: BoolRef, vjp: VJP, args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, module: nn.Module, *rg_tensors: Tensor, @@ -176,7 +185,7 @@ def forward( return tuple(t.detach() for t in rg_tensors) # For Python version > 3.10, the type of `inputs` should become - # tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] + # tuple[BoolRef, VJP, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]] @staticmethod def setup_context( ctx, @@ -186,25 +195,27 @@ def setup_context( ctx.gramian_accumulation_phase = inputs[0] ctx.vjp = inputs[1] ctx.args = inputs[2] - ctx.gramian_accumulator = inputs[3] - ctx.module = inputs[4] + ctx.kwargs = inputs[3] + ctx.gramian_accumulator = inputs[4] + ctx.module = inputs[5] @staticmethod def backward(ctx, *grad_outputs: Tensor) -> tuple: - # Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]] + # For python > 3.10: -> tuple[None, None, None, None, None, None, *tuple[Tensor, ...]] if not ctx.gramian_accumulation_phase: - return None, None, None, None, None, *grad_outputs + return None, None, None, None, None, None, *grad_outputs AccumulateJacobian.apply( ctx.vjp, ctx.args, + ctx.kwargs, ctx.gramian_accumulator, ctx.module, *grad_outputs, ) - return None, None, None, None, None, *grad_outputs + return None, None, None, None, None, None, *grad_outputs class AccumulateJacobian(torch.autograd.Function): @@ -213,29 +224,31 @@ class AccumulateJacobian(torch.autograd.Function): def forward( vjp: VJP, args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, module: nn.Module, *grad_outputs: Tensor, ) -> None: # There is no non-batched dimension - generalized_jacobians = vjp(grad_outputs, args) + generalized_jacobians = vjp(grad_outputs, args, kwargs) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) @staticmethod def vmap( _, - in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]] + in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]] vjp: VJP, args: tuple[PyTree, ...], + kwargs: dict[str, PyTree], gramian_accumulator: GramianAccumulator, module: nn.Module, *jac_outputs: Tensor, ) -> tuple[None, None]: # There is a non-batched dimension # We do not vmap over the args for the non-batched dimension - in_dims = (in_dims[4:], tree_map(lambda _: None, args)) - generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args) + in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs)) + generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs) path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians) gramian_accumulator.accumulate_path_jacobians(path_jacobians) return None, None diff --git a/src/torchjd/autogram/_vjp.py b/src/torchjd/autogram/_vjp.py index 442b7fe9..3a94543c 100644 --- a/src/torchjd/autogram/_vjp.py +++ b/src/torchjd/autogram/_vjp.py @@ -20,7 +20,7 @@ class VJP(ABC): @abstractmethod def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] ) -> dict[str, Tensor]: """ Computes and returns the dictionary of parameter names to their gradients for the given @@ -59,12 +59,15 @@ def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]): self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims) def __call__( - self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...] + self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree] ) -> dict[str, Tensor]: - return self.vmapped_vjp(grad_outputs, args) + return self.vmapped_vjp(grad_outputs, args, kwargs) def _call_on_one_instance( - self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...] + self, + grad_outputs_j: tuple[Tensor, ...], + args_j: tuple[PyTree, ...], + kwargs_j: dict[str, PyTree], ) -> dict[str, Tensor]: # Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a # "batch" of 1 activation (or grad_output). This is because some layers (e.g. @@ -72,6 +75,7 @@ def _call_on_one_instance( # an element of a batch. We thus always provide them with batches, just of a # different size. args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j) + kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j) grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j] def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]: @@ -80,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor **dict(self.module.named_buffers()), **self.frozen_params, } - output = torch.func.functional_call(self.module, all_state, args_j) + output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j) flat_outputs = tree_flatten(output)[0] rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad] return rg_outputs @@ -108,7 +112,7 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]): self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params) def __call__( - self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...] + self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree] ) -> dict[str, Tensor]: grads = torch.autograd.grad( self.rg_outputs, diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index fb37863c..7c88ae9d 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -52,7 +52,9 @@ WithDropout, WithModuleTrackingRunningStats, WithModuleWithHybridPyTreeArg, + WithModuleWithHybridPyTreeKwarg, WithModuleWithStringArg, + WithModuleWithStringKwarg, WithModuleWithStringOutput, WithNoTensorOutput, WithRNN, @@ -113,6 +115,8 @@ (WithModuleWithStringArg, 32), (WithModuleWithHybridPyTreeArg, 32), (WithModuleWithStringOutput, 32), + (WithModuleWithStringKwarg, 32), + (WithModuleWithHybridPyTreeKwarg, 32), (FreeParam, 32), (NoFreeParam, 32), param(Cifar10Model, 16, marks=mark.slow), diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index ef4933f5..3dc126ad 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -772,29 +772,66 @@ def forward(self, input: Tensor) -> Tensor: return input @ self.linear.weight.T + self.linear.bias +class _WithStringArg(nn.Module): + def __init__(self): + super().__init__() + self.matrix = nn.Parameter(torch.randn(2, 3)) + + def forward(self, s: str, input: Tensor) -> Tensor: + if s == "two": + return input @ self.matrix * 2.0 + else: + return input @ self.matrix + + class WithModuleWithStringArg(ShapedModule): """Model containing a module that has a string argument.""" INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - class WithStringArg(nn.Module): - def __init__(self): - super().__init__() - self.matrix = nn.Parameter(torch.randn(2, 3)) + def __init__(self): + super().__init__() + self.with_string_arg = _WithStringArg() - def forward(self, s: str, input: Tensor) -> Tensor: - if s == "two": - return input @ self.matrix * 2.0 - else: - return input @ self.matrix + def forward(self, input: Tensor) -> Tensor: + return self.with_string_arg("two", input) + + +class WithModuleWithStringKwarg(ShapedModule): + """Model calling its submodule's forward with a string and a tensor as keyword arguments.""" + + INPUT_SHAPES = (2,) + OUTPUT_SHAPES = (3,) def __init__(self): super().__init__() - self.with_string_arg = self.WithStringArg() + self.with_string_arg = _WithStringArg() def forward(self, input: Tensor) -> Tensor: - return self.with_string_arg("two", input) + return self.with_string_arg(s="two", input=input) + + +class _WithHybridPyTreeArg(nn.Module): + def __init__(self): + super().__init__() + self.m0 = nn.Parameter(torch.randn(3, 3)) + self.m1 = nn.Parameter(torch.randn(4, 3)) + self.m2 = nn.Parameter(torch.randn(5, 3)) + self.m3 = nn.Parameter(torch.randn(6, 3)) + + def forward(self, input: PyTree) -> Tensor: + t0 = input["one"][0][0] + t1 = input["one"][0][1] + t2 = input["one"][1] + t3 = input["two"] + + c0 = input["one"][0][3] + c1 = input["one"][0][4][0] + c2 = input["one"][2] + c3 = input["three"] + + return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 class WithModuleWithHybridPyTreeArg(ShapedModule): @@ -806,31 +843,39 @@ class WithModuleWithHybridPyTreeArg(ShapedModule): INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) - class WithHybridPyTreeArg(nn.Module): - def __init__(self): - super().__init__() - self.m0 = nn.Parameter(torch.randn(3, 3)) - self.m1 = nn.Parameter(torch.randn(4, 3)) - self.m2 = nn.Parameter(torch.randn(5, 3)) - self.m3 = nn.Parameter(torch.randn(6, 3)) + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 18) + self.with_string_arg = _WithHybridPyTreeArg() - def forward(self, input: PyTree) -> Tensor: - t0 = input["one"][0][0] - t1 = input["one"][0][1] - t2 = input["one"][1] - t3 = input["two"] + def forward(self, input: Tensor) -> Tensor: + input = self.linear(input) - c0 = input["one"][0][3] - c1 = input["one"][0][4][0] - c2 = input["one"][2] - c3 = input["three"] + t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18] - return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3 + tree = { + "zero": "unused", + "one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"], + "two": t3, + "three": 0.5, + } + + return self.with_string_arg(tree) + + +class WithModuleWithHybridPyTreeKwarg(ShapedModule): + """ + Model calling its submodule's forward with a PyTree keyword argument containing a mix of tensors + and non-tensor values. + """ + + INPUT_SHAPES = (10,) + OUTPUT_SHAPES = (3,) def __init__(self): super().__init__() self.linear = nn.Linear(10, 18) - self.with_string_arg = self.WithHybridPyTreeArg() + self.with_string_arg = _WithHybridPyTreeArg() def forward(self, input: Tensor) -> Tensor: input = self.linear(input) @@ -844,7 +889,7 @@ def forward(self, input: Tensor) -> Tensor: "three": 0.5, } - return self.with_string_arg(tree) + return self.with_string_arg(input=tree) class WithModuleWithStringOutput(ShapedModule): @@ -1067,11 +1112,11 @@ def forward(self, input: Tensor) -> Tensor: # Other torchvision.models were not added for the following reasons: -# - VGG16: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok. +# - VGG16: Sometimes takes to much memory on autojac even with bs=2, but autogram seems ok. # - DenseNet: no way to easily replace the BatchNorms (no norm_layer param) # - InceptionV3: no way to easily replace the BatchNorms (no norm_layer param) # - GoogleNet: no way to easily replace the BatchNorms (no norm_layer param) # - ShuffleNetV2: no way to easily replace the BatchNorms (no norm_layer param) -# - ResNeXt: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok. -# - WideResNet50: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok. +# - ResNeXt: Sometimes takes to much memory on autojac even with bs=2, but autogram seems ok. +# - WideResNet50: Sometimes takes to much memory on autojac even with bs=2, but autogram seems ok. # - MNASNet: no way to easily replace the BatchNorms (no norm_layer param)