diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py
index 803027bd..f5e1028d 100644
--- a/src/torchjd/autogram/_module_hook_manager.py
+++ b/src/torchjd/autogram/_module_hook_manager.py
@@ -20,16 +20,6 @@
 # still support older versions of PyTorch where pytree is protected).
 
 
-class BoolRef:
-    """Class wrapping a boolean value, acting as a reference to this boolean value."""
-
-    def __init__(self, value: bool):
-        self.value = value
-
-    def __bool__(self) -> bool:
-        return self.value
-
-
 class ModuleHookManager:
     """
     Class responsible for handling hooks and Nodes that computes the Gramian reverse accumulation.
@@ -88,58 +78,64 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
         handle.remove()
 
 
-class AccumulateJacobian(torch.autograd.Function):
+class BoolRef:
+    """Class wrapping a boolean value, acting as a reference to this boolean value."""
 
-    @staticmethod
-    def forward(
-        output_spec: TreeSpec,
-        vjp: VJP,
-        args: PyTree,
-        gramian_accumulator: GramianAccumulator,
-        module: nn.Module,
-        *flat_grad_outputs: Tensor,
-    ) -> None:
-        # There is no non-batched dimension
-        grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
-        generalized_jacobians = vjp(grad_outputs, args)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
+    def __init__(self, value: bool):
+        self.value = value
 
-    @staticmethod
-    def vmap(
-        _,
-        in_dims: PyTree,
-        output_spec: TreeSpec,
-        vjp: VJP,
-        args: PyTree,
+    def __bool__(self) -> bool:
+        return self.value
+
+
+class Hook:
+    def __init__(
+        self,
+        gramian_accumulation_phase: BoolRef,
+        target_edges: EdgeRegistry,
         gramian_accumulator: GramianAccumulator,
-        module: nn.Module,
-        *flat_jac_outputs: Tensor,
-    ) -> tuple[None, None]:
-        # There is a non-batched dimension
-        jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
-        # We do not vmap over the args for the non-batched dimension
-        in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
-        generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
-        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
-        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
-        return None, None
+        has_batch_dim: bool,
+    ):
+        self.gramian_accumulation_phase = gramian_accumulation_phase
+        self.target_edges = target_edges
+        self.gramian_accumulator = gramian_accumulator
+        self.has_batch_dim = has_batch_dim
 
-    @staticmethod
-    def _make_path_jacobians(
-        module: nn.Module,
-        generalized_jacobians: dict[str, Tensor],
-    ) -> dict[Tensor, Tensor]:
-        path_jacobians: dict[Tensor, Tensor] = {}
-        for param_name, generalized_jacobian in generalized_jacobians.items():
-            key = module.get_parameter(param_name)
-            jacobian = generalized_jacobian.reshape([-1] + list(key.shape))
-            path_jacobians[key] = jacobian
-        return path_jacobians
+    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
+        if self.gramian_accumulation_phase:
+            return output
 
-    @staticmethod
-    def setup_context(*_):
-        pass
+        flat_outputs, output_spec = tree_flatten(output)
+
+        if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
+            # This can only happen if a module has a trainable param but outputs no tensor
+            # that requires grad.
+            return output
+
+        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
+        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
+
+        # We only care about running the JacobianAccumulator node, so we need one of its child
+        # edges (the edges of the original outputs of the model) as target. For memory
+        # efficiency, we select the smallest one (that requires grad).
+        inf = float("inf")
+        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
+        index = cast(int, preference.argmin().item())
+        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
+
+        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
+
+        autograd_fn_outputs = JacobianAccumulator.apply(
+            self.gramian_accumulation_phase,
+            output_spec,
+            vjp,
+            args,
+            self.gramian_accumulator,
+            module,
+            *flat_outputs,
+        )
+
+        return tree_unflatten(autograd_fn_outputs, output_spec)
 
 
 class JacobianAccumulator(torch.autograd.Function):
@@ -197,51 +193,55 @@ def backward(ctx, *flat_grad_outputs: Tensor):
         return None, None, None, None, None, None, *flat_grad_outputs
 
 
-class Hook:
-    def __init__(
-        self,
-        gramian_accumulation_phase: BoolRef,
-        target_edges: EdgeRegistry,
-        gramian_accumulator: GramianAccumulator,
-        has_batch_dim: bool,
-    ):
-        self.gramian_accumulation_phase = gramian_accumulation_phase
-        self.target_edges = target_edges
-        self.gramian_accumulator = gramian_accumulator
-        self.has_batch_dim = has_batch_dim
-
-    def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
-        if self.gramian_accumulation_phase:
-            return output
-
-        flat_outputs, output_spec = tree_flatten(output)
-
-        if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
-            # This can happen only if a module has a trainable param but outputs no tensor that
-            # require grad
-            return output
-
-        requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
-        self.gramian_accumulator.track_parameter_paths(requires_grad_params)
+class AccumulateJacobian(torch.autograd.Function):
 
-        # We only care about running the JacobianAccumulator node, so we need one of its child
-        # edges (the edges of the original outputs of the model) as target. For memory
-        # efficiency, we select the smallest one (that requires grad).
-        inf = float("inf")
-        preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
-        index = cast(int, preference.argmin().item())
-        self.target_edges.register(get_gradient_edge(flat_outputs[index]))
+    @staticmethod
+    def forward(
+        output_spec: TreeSpec,
+        vjp: VJP,
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *flat_grad_outputs: Tensor,
+    ) -> None:
+        # There is no non-batched dimension: apply the vjp to the grad_outputs directly
+        grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
+        generalized_jacobians = vjp(grad_outputs, args)
+        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
+        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
 
-        vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
+    @staticmethod
+    def vmap(
+        _,
+        in_dims: PyTree,
+        output_spec: TreeSpec,
+        vjp: VJP,
+        args: PyTree,
+        gramian_accumulator: GramianAccumulator,
+        module: nn.Module,
+        *flat_jac_outputs: Tensor,
+    ) -> tuple[None, None]:
+        # There is a non-batched dimension: vmap the vjp over it
+        jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
+        # We do not vmap over the args for the non-batched dimension
+        in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
+        generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
+        path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
+        gramian_accumulator.accumulate_path_jacobians(path_jacobians)
+        return None, None
 
-        autograd_fn_outputs = JacobianAccumulator.apply(
-            self.gramian_accumulation_phase,
-            output_spec,
-            vjp,
-            args,
-            self.gramian_accumulator,
-            module,
-            *flat_outputs,
-        )
+    @staticmethod
+    def _make_path_jacobians(
+        module: nn.Module,
+        generalized_jacobians: dict[str, Tensor],
+    ) -> dict[Tensor, Tensor]:
+        path_jacobians: dict[Tensor, Tensor] = {}
+        for param_name, generalized_jacobian in generalized_jacobians.items():
+            key = module.get_parameter(param_name)
+            jacobian = generalized_jacobian.reshape([-1] + list(key.shape))
+            path_jacobians[key] = jacobian
+        return path_jacobians
 
-        return tree_unflatten(autograd_fn_outputs, output_spec)
+    @staticmethod
+    def setup_context(*_):
+        pass