Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def hook_module(self, module: nn.Module) -> None:
self._gramian_accumulator,
self._has_batch_dim,
)
self._handles.append(module.register_forward_hook(hook))
self._handles.append(module.register_forward_hook(hook, with_kwargs=True))

@staticmethod
def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
Expand Down Expand Up @@ -101,7 +101,13 @@ def __init__(
self.gramian_accumulator = gramian_accumulator
self.has_batch_dim = has_batch_dim

def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree) -> PyTree:
def __call__(
self,
module: nn.Module,
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
outputs: PyTree,
) -> PyTree:
if self.gramian_accumulation_phase:
return outputs

Expand Down Expand Up @@ -131,9 +137,10 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree)

vjp: VJP
if self.has_batch_dim:
rg_outputs_in_dims = (0,) * len(rg_outputs)
args_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
in_dims = (rg_outputs_in_dims, args_in_dims)
rg_output_in_dims = (0,) * len(rg_outputs)
arg_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, args)
kwargs_in_dims = tree_map(lambda t: 0 if isinstance(t, Tensor) else None, kwargs)
in_dims = (rg_output_in_dims, arg_in_dims, kwargs_in_dims)
vjp = FunctionalVJP(module, in_dims)
else:
vjp = AutogradVJP(module, rg_outputs)
Expand All @@ -142,6 +149,7 @@ def __call__(self, module: nn.Module, args: tuple[PyTree, ...], outputs: PyTree)
self.gramian_accumulation_phase,
vjp,
args,
kwargs,
self.gramian_accumulator,
module,
*rg_outputs,
Expand Down Expand Up @@ -169,14 +177,15 @@ def forward(
gramian_accumulation_phase: BoolRef,
vjp: VJP,
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*rg_tensors: Tensor,
) -> tuple[Tensor, ...]:
return tuple(t.detach() for t in rg_tensors)

# For Python version > 3.10, the type of `inputs` should become
# tuple[BoolRef, VJP, tuple[PyTree, ...], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
# tuple[BoolRef, VJP, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
@staticmethod
def setup_context(
ctx,
Expand All @@ -186,25 +195,27 @@ def setup_context(
ctx.gramian_accumulation_phase = inputs[0]
ctx.vjp = inputs[1]
ctx.args = inputs[2]
ctx.gramian_accumulator = inputs[3]
ctx.module = inputs[4]
ctx.kwargs = inputs[3]
ctx.gramian_accumulator = inputs[4]
ctx.module = inputs[5]

@staticmethod
def backward(ctx, *grad_outputs: Tensor) -> tuple:
# Return type for python > 3.10: # tuple[None, None, None, None, None, *tuple[Tensor, ...]]
# For python > 3.10: -> tuple[None, None, None, None, None, None, *tuple[Tensor, ...]]

if not ctx.gramian_accumulation_phase:
return None, None, None, None, None, *grad_outputs
return None, None, None, None, None, None, *grad_outputs

AccumulateJacobian.apply(
ctx.vjp,
ctx.args,
ctx.kwargs,
ctx.gramian_accumulator,
ctx.module,
*grad_outputs,
)

return None, None, None, None, None, *grad_outputs
return None, None, None, None, None, None, *grad_outputs


class AccumulateJacobian(torch.autograd.Function):
Expand All @@ -213,29 +224,31 @@ class AccumulateJacobian(torch.autograd.Function):
def forward(
vjp: VJP,
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*grad_outputs: Tensor,
) -> None:
# There is no non-batched dimension
generalized_jacobians = vjp(grad_outputs, args)
generalized_jacobians = vjp(grad_outputs, args, kwargs)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)

@staticmethod
def vmap(
_,
in_dims: tuple, # tuple[None, tuple[PyTree, ...], None, None, *tuple[int | None, ...]]
in_dims: tuple, # tuple[None, tuple[PyTree, ...], dict[str, PyTree], None, None, *tuple[int | None, ...]]
vjp: VJP,
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*jac_outputs: Tensor,
) -> tuple[None, None]:
# There is a non-batched dimension
# We do not vmap over the args for the non-batched dimension
in_dims = (in_dims[4:], tree_map(lambda _: None, args))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
in_dims = (in_dims[5:], tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args, kwargs)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
return None, None
Expand Down
16 changes: 10 additions & 6 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class VJP(ABC):

@abstractmethod
def __call__(
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree]
) -> dict[str, Tensor]:
"""
Computes and returns the dictionary of parameter names to their gradients for the given
Expand Down Expand Up @@ -59,19 +59,23 @@ def __init__(self, module: nn.Module, in_dims: tuple[PyTree, ...]):
self.vmapped_vjp = torch.vmap(self._call_on_one_instance, in_dims=in_dims)

def __call__(
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...]
self, grad_outputs: tuple[Tensor, ...], args: tuple[PyTree, ...], kwargs: dict[str, PyTree]
) -> dict[str, Tensor]:
return self.vmapped_vjp(grad_outputs, args)
return self.vmapped_vjp(grad_outputs, args, kwargs)

def _call_on_one_instance(
self, grad_outputs_j: tuple[Tensor, ...], args_j: tuple[PyTree, ...]
self,
grad_outputs_j: tuple[Tensor, ...],
args_j: tuple[PyTree, ...],
kwargs_j: dict[str, PyTree],
) -> dict[str, Tensor]:
# Note: we use unsqueeze(0) to turn a single activation (or grad_output) into a
# "batch" of 1 activation (or grad_output). This is because some layers (e.g.
# nn.Flatten) do not work equivalently if they're provided with a batch or with
# an element of a batch. We thus always provide them with batches, just of a
# different size.
args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
kwargs_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), kwargs_j)
grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]

def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
Expand All @@ -80,7 +84,7 @@ def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor
**dict(self.module.named_buffers()),
**self.frozen_params,
}
output = torch.func.functional_call(self.module, all_state, args_j)
output = torch.func.functional_call(self.module, all_state, args_j, kwargs_j)
flat_outputs = tree_flatten(output)[0]
rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
return rg_outputs
Expand Down Expand Up @@ -108,7 +112,7 @@ def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)

def __call__(
self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...]
self, grad_outputs: tuple[Tensor, ...], _: tuple[PyTree, ...], __: dict[str, PyTree]
) -> dict[str, Tensor]:
grads = torch.autograd.grad(
self.rg_outputs,
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
WithDropout,
WithModuleTrackingRunningStats,
WithModuleWithHybridPyTreeArg,
WithModuleWithHybridPyTreeKwarg,
WithModuleWithStringArg,
WithModuleWithStringKwarg,
WithModuleWithStringOutput,
WithNoTensorOutput,
WithRNN,
Expand Down Expand Up @@ -113,6 +115,8 @@
(WithModuleWithStringArg, 32),
(WithModuleWithHybridPyTreeArg, 32),
(WithModuleWithStringOutput, 32),
(WithModuleWithStringKwarg, 32),
(WithModuleWithHybridPyTreeKwarg, 32),
(FreeParam, 32),
(NoFreeParam, 32),
param(Cifar10Model, 16, marks=mark.slow),
Expand Down
111 changes: 78 additions & 33 deletions tests/utils/architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,29 +772,66 @@ def forward(self, input: Tensor) -> Tensor:
return input @ self.linear.weight.T + self.linear.bias


class _WithStringArg(nn.Module):
def __init__(self):
super().__init__()
self.matrix = nn.Parameter(torch.randn(2, 3))

def forward(self, s: str, input: Tensor) -> Tensor:
if s == "two":
return input @ self.matrix * 2.0
else:
return input @ self.matrix


class WithModuleWithStringArg(ShapedModule):
"""Model containing a module that has a string argument."""

INPUT_SHAPES = (2,)
OUTPUT_SHAPES = (3,)

class WithStringArg(nn.Module):
def __init__(self):
super().__init__()
self.matrix = nn.Parameter(torch.randn(2, 3))
def __init__(self):
super().__init__()
self.with_string_arg = _WithStringArg()

def forward(self, s: str, input: Tensor) -> Tensor:
if s == "two":
return input @ self.matrix * 2.0
else:
return input @ self.matrix
def forward(self, input: Tensor) -> Tensor:
return self.with_string_arg("two", input)


class WithModuleWithStringKwarg(ShapedModule):
"""Model calling its submodule's forward with a string and a tensor as keyword arguments."""

INPUT_SHAPES = (2,)
OUTPUT_SHAPES = (3,)

def __init__(self):
super().__init__()
self.with_string_arg = self.WithStringArg()
self.with_string_arg = _WithStringArg()

def forward(self, input: Tensor) -> Tensor:
return self.with_string_arg("two", input)
return self.with_string_arg(s="two", input=input)


class _WithHybridPyTreeArg(nn.Module):
def __init__(self):
super().__init__()
self.m0 = nn.Parameter(torch.randn(3, 3))
self.m1 = nn.Parameter(torch.randn(4, 3))
self.m2 = nn.Parameter(torch.randn(5, 3))
self.m3 = nn.Parameter(torch.randn(6, 3))

def forward(self, input: PyTree) -> Tensor:
t0 = input["one"][0][0]
t1 = input["one"][0][1]
t2 = input["one"][1]
t3 = input["two"]

c0 = input["one"][0][3]
c1 = input["one"][0][4][0]
c2 = input["one"][2]
c3 = input["three"]

return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3


class WithModuleWithHybridPyTreeArg(ShapedModule):
Expand All @@ -806,31 +843,39 @@ class WithModuleWithHybridPyTreeArg(ShapedModule):
INPUT_SHAPES = (10,)
OUTPUT_SHAPES = (3,)

class WithHybridPyTreeArg(nn.Module):
def __init__(self):
super().__init__()
self.m0 = nn.Parameter(torch.randn(3, 3))
self.m1 = nn.Parameter(torch.randn(4, 3))
self.m2 = nn.Parameter(torch.randn(5, 3))
self.m3 = nn.Parameter(torch.randn(6, 3))
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 18)
self.with_string_arg = _WithHybridPyTreeArg()

def forward(self, input: PyTree) -> Tensor:
t0 = input["one"][0][0]
t1 = input["one"][0][1]
t2 = input["one"][1]
t3 = input["two"]
def forward(self, input: Tensor) -> Tensor:
input = self.linear(input)

c0 = input["one"][0][3]
c1 = input["one"][0][4][0]
c2 = input["one"][2]
c3 = input["three"]
t0, t1, t2, t3 = input[:, 0:3], input[:, 3:7], input[:, 7:12], input[:, 12:18]

return c0 * t0 @ self.m0 + c1 * t1 @ self.m1 + c2 * t2 @ self.m2 + c3 * t3 @ self.m3
tree = {
"zero": "unused",
"one": [(t0, t1, "unused", 0.2, [0.3, "unused"]), t2, 0.4, "unused"],
"two": t3,
"three": 0.5,
}

return self.with_string_arg(tree)


class WithModuleWithHybridPyTreeKwarg(ShapedModule):
"""
Model calling its submodule's forward with a PyTree keyword argument containing a mix of tensors
and non-tensor values.
"""

INPUT_SHAPES = (10,)
OUTPUT_SHAPES = (3,)

def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 18)
self.with_string_arg = self.WithHybridPyTreeArg()
self.with_string_arg = _WithHybridPyTreeArg()

def forward(self, input: Tensor) -> Tensor:
input = self.linear(input)
Expand All @@ -844,7 +889,7 @@ def forward(self, input: Tensor) -> Tensor:
"three": 0.5,
}

return self.with_string_arg(tree)
return self.with_string_arg(input=tree)


class WithModuleWithStringOutput(ShapedModule):
Expand Down Expand Up @@ -1067,11 +1112,11 @@ def forward(self, input: Tensor) -> Tensor:


# Other torchvision.models were not added for the following reasons:
# - VGG16: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
# - VGG16: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok.
# - DenseNet: no way to easily replace the BatchNorms (no norm_layer param)
# - InceptionV3: no way to easily replace the BatchNorms (no norm_layer param)
# - GoogleNet: no way to easily replace the BatchNorms (no norm_layer param)
# - ShuffleNetV2: no way to easily replace the BatchNorms (no norm_layer param)
# - ResNeXt: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
# - WideResNet50: Sometimes takes to much memory on autojac even with bs=2, nut autogram seems ok.
# - ResNeXt: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok.
# - WideResNet50: Sometimes takes too much memory on autojac even with bs=2, but autogram seems ok.
# - MNASNet: no way to easily replace the BatchNorms (no norm_layer param)
Loading