Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,40 +101,49 @@ def __init__(
self.gramian_accumulator = gramian_accumulator
self.has_batch_dim = has_batch_dim

def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
def __call__(self, module: nn.Module, args: PyTree, outputs: PyTree) -> PyTree:
if self.gramian_accumulation_phase:
return output
return outputs

flat_outputs, output_spec = tree_flatten(output)
flat_outputs, output_spec = tree_flatten(outputs)

if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
rg_outputs = list[Tensor]()
rg_output_indices = list[int]()
for idx, output in enumerate(flat_outputs):
if isinstance(output, Tensor) and output.requires_grad:
rg_outputs.append(output)
rg_output_indices.append(idx)

if len(rg_outputs) == 0:
# This can happen only if a module has a trainable param but outputs no tensors that
# require grad.
return output
return outputs

requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
self.gramian_accumulator.track_parameter_paths(requires_grad_params)

# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
preference = torch.tensor([t.numel() for t in rg_outputs])
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(flat_outputs[index]))
self.target_edges.register(get_gradient_edge(rg_outputs[index]))

vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, rg_outputs)

autograd_fn_outputs = JacobianAccumulator.apply(
autograd_fn_rg_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
vjp,
args,
self.gramian_accumulator,
module,
*flat_outputs,
*rg_outputs,
)

return tree_unflatten(autograd_fn_outputs, output_spec)
for idx, output in zip(rg_output_indices, autograd_fn_rg_outputs):
flat_outputs[idx] = output

return tree_unflatten(flat_outputs, output_spec)


class JacobianAccumulator(torch.autograd.Function):
Expand All @@ -155,9 +164,9 @@ def forward(
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*xs: Tensor,
*rg_tensors: Tensor,
) -> tuple[Tensor, ...]:
return tuple(x.detach() for x in xs)
return tuple(t.detach() for t in rg_tensors)

# For Python versions > 3.10, the type of `inputs` should become
# tuple[BoolRef, VJP, PyTree, GramianAccumulator, nn.Module, *tuple[Tensor, ...]]
Expand Down
27 changes: 9 additions & 18 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,18 @@ def _call_on_one_instance(
args_j = tree_map_only(torch.Tensor, lambda x: x.unsqueeze(0), args_j)
grad_outputs_j_ = [x.unsqueeze(0) for x in grad_outputs_j]

def flat_functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
def functional_model_call(trainable_params: dict[str, Parameter]) -> list[Tensor]:
all_state = {
**trainable_params,
**dict(self.module.named_buffers()),
**self.frozen_params,
}
output = torch.func.functional_call(self.module, all_state, args_j)
return tree_flatten(output)[0]
flat_outputs = tree_flatten(output)[0]
rg_outputs = [t for t in flat_outputs if isinstance(t, Tensor) and t.requires_grad]
return rg_outputs

vjp_func = torch.func.vjp(flat_functional_model_call, self.trainable_params)[1]
vjp_func = torch.func.vjp(functional_model_call, self.trainable_params)[1]

# vjp_func is a function that computes the vjp w.r.t. the primals (tuple). Here the
# functional has a single primal, which is dict(module.named_parameters()). We therefore take
Expand All @@ -95,28 +97,17 @@ class AutogradVJP(ModuleVJP):
forward pass.
"""

def __init__(self, module: nn.Module, outputs: Sequence[Tensor]):
def __init__(self, module: nn.Module, rg_outputs: Sequence[Tensor]):
super().__init__(module)

self.outputs_that_require_grad = list[Tensor]()
self.mask = list[bool]()
for output in outputs:
requires_grad = output.requires_grad
if requires_grad:
self.outputs_that_require_grad.append(output)
self.mask.append(requires_grad)

self.rg_outputs = rg_outputs
self.flat_trainable_params, self.param_spec = tree_flatten(self.trainable_params)

def __call__(self, grad_outputs: tuple[Tensor, ...], _: PyTree) -> dict[str, Tensor]:

# Only keep the grad_outputs corresponding to outputs that require grad.
grad_outputs_ = [grad_output for grad_output, rg in zip(grad_outputs, self.mask) if rg]

grads = torch.autograd.grad(
self.outputs_that_require_grad,
self.rg_outputs,
self.flat_trainable_params,
grad_outputs_,
grad_outputs,
retain_graph=True,
allow_unused=True,
materialize_grads=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
(Ndim3Output, 32),
(Ndim4Output, 32),
(WithDropout, 32),
(WithModuleWithStringOutput, 32),
(FreeParam, 32),
(NoFreeParam, 32),
param(Cifar10Model, 16, marks=mark.slow),
Expand Down Expand Up @@ -167,7 +168,6 @@ def test_compute_gramian(architecture: type[ShapedModule], batch_size: int, batc
WithModuleTrackingRunningStats,
param(WithRNN, marks=mark.xfail_if_cuda),
WithModuleWithStringArg,
param(WithModuleWithStringOutput, marks=mark.xfail),
],
)
@mark.parametrize("batch_size", [1, 3, 32])
Expand Down
Loading