src/torchjd/autogram/_module_hook_manager.py: 101 additions & 101 deletions
@@ -20,16 +20,6 @@
# still support older versions of PyTorch where pytree is protected).


class BoolRef:
"""Class wrapping a boolean value, acting as a reference to this boolean value."""

def __init__(self, value: bool):
self.value = value

def __bool__(self) -> bool:
return self.value


class ModuleHookManager:
"""
Class responsible for handling hooks and Nodes that compute the Gramian reverse accumulation.
@@ -88,58 +78,64 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None:
handle.remove()


class AccumulateJacobian(torch.autograd.Function):
class BoolRef:
"""Class wrapping a boolean value, acting as a reference to this boolean value."""

@staticmethod
def forward(
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_grad_outputs: Tensor,
) -> None:
# There is no non-batched dimension
grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
generalized_jacobians = vjp(grad_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
def __init__(self, value: bool):
self.value = value

@staticmethod
def vmap(
_,
in_dims: PyTree,
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
def __bool__(self) -> bool:
return self.value


class Hook:
def __init__(
self,
gramian_accumulation_phase: BoolRef,
target_edges: EdgeRegistry,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_jac_outputs: Tensor,
) -> tuple[None, None]:
# There is a non-batched dimension
jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
# We do not vmap over the args for the non-batched dimension
in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
return None, None
has_batch_dim: bool,
):
self.gramian_accumulation_phase = gramian_accumulation_phase
self.target_edges = target_edges
self.gramian_accumulator = gramian_accumulator
self.has_batch_dim = has_batch_dim

@staticmethod
def _make_path_jacobians(
module: nn.Module,
generalized_jacobians: dict[str, Tensor],
) -> dict[Tensor, Tensor]:
path_jacobians: dict[Tensor, Tensor] = {}
for param_name, generalized_jacobian in generalized_jacobians.items():
key = module.get_parameter(param_name)
jacobian = generalized_jacobian.reshape([-1] + list(key.shape))
path_jacobians[key] = jacobian
return path_jacobians
def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
if self.gramian_accumulation_phase:
return output

@staticmethod
def setup_context(*_):
pass
flat_outputs, output_spec = tree_flatten(output)

if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
# This can happen only if a module has a trainable param but outputs no tensor that
# requires grad
return output

requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
self.gramian_accumulator.track_parameter_paths(requires_grad_params)

# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(flat_outputs[index]))

vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)

autograd_fn_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
output_spec,
vjp,
args,
self.gramian_accumulator,
module,
*flat_outputs,
)

return tree_unflatten(autograd_fn_outputs, output_spec)


class JacobianAccumulator(torch.autograd.Function):
@@ -197,51 +193,55 @@ def backward(ctx, *flat_grad_outputs: Tensor):
return None, None, None, None, None, None, *flat_grad_outputs


class Hook:
def __init__(
self,
gramian_accumulation_phase: BoolRef,
target_edges: EdgeRegistry,
gramian_accumulator: GramianAccumulator,
has_batch_dim: bool,
):
self.gramian_accumulation_phase = gramian_accumulation_phase
self.target_edges = target_edges
self.gramian_accumulator = gramian_accumulator
self.has_batch_dim = has_batch_dim

def __call__(self, module: nn.Module, args: PyTree, output: PyTree) -> PyTree:
if self.gramian_accumulation_phase:
return output

flat_outputs, output_spec = tree_flatten(output)

if not any(isinstance(t, Tensor) and t.requires_grad for t in flat_outputs):
# This can happen only if a module has a trainable param but outputs no tensor that
# requires grad
return output

requires_grad_params = [p for p in module.parameters(recurse=False) if p.requires_grad]
self.gramian_accumulator.track_parameter_paths(requires_grad_params)
class AccumulateJacobian(torch.autograd.Function):

# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
# efficiency, we select the smallest one (that requires grad).
inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
index = cast(int, preference.argmin().item())
self.target_edges.register(get_gradient_edge(flat_outputs[index]))
@staticmethod
def forward(
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_grad_outputs: Tensor,
) -> None:
# There is no non-batched dimension
grad_outputs = tree_unflatten(flat_grad_outputs, output_spec)
generalized_jacobians = vjp(grad_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)

vjp = FunctionalVJP(module) if self.has_batch_dim else AutogradVJP(module, flat_outputs)
@staticmethod
def vmap(
_,
in_dims: PyTree,
output_spec: TreeSpec,
vjp: VJP,
args: PyTree,
gramian_accumulator: GramianAccumulator,
module: nn.Module,
*flat_jac_outputs: Tensor,
) -> tuple[None, None]:
# There is a non-batched dimension
jac_outputs = tree_unflatten(flat_jac_outputs, output_spec)
# We do not vmap over the args for the non-batched dimension
in_dims = (tree_unflatten(in_dims[5:], output_spec), tree_map(lambda _: None, args))
generalized_jacobians = torch.vmap(vjp, in_dims=in_dims)(jac_outputs, args)
path_jacobians = AccumulateJacobian._make_path_jacobians(module, generalized_jacobians)
gramian_accumulator.accumulate_path_jacobians(path_jacobians)
return None, None

autograd_fn_outputs = JacobianAccumulator.apply(
self.gramian_accumulation_phase,
output_spec,
vjp,
args,
self.gramian_accumulator,
module,
*flat_outputs,
)
@staticmethod
def _make_path_jacobians(
module: nn.Module,
generalized_jacobians: dict[str, Tensor],
) -> dict[Tensor, Tensor]:
path_jacobians: dict[Tensor, Tensor] = {}
for param_name, generalized_jacobian in generalized_jacobians.items():
key = module.get_parameter(param_name)
jacobian = generalized_jacobian.reshape([-1] + list(key.shape))
path_jacobians[key] = jacobian
return path_jacobians

return tree_unflatten(autograd_fn_outputs, output_spec)
@staticmethod
def setup_context(*_):
pass
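
A note on BoolRef, which this diff moves next to Hook: wrapping the boolean lets the owner of the flag (presumably ModuleHookManager) flip gramian_accumulation_phase once and have every hook that captured the wrapper observe the new value, which a plain bool copied by value would not allow. A minimal, self-contained sketch of that behaviour; the toy hook function below is hypothetical and only illustrates the capture:

# Hypothetical illustration, not part of the diff: BoolRef as a shared mutable flag.
class BoolRef:
    """Class wrapping a boolean value, acting as a reference to this boolean value."""

    def __init__(self, value: bool):
        self.value = value

    def __bool__(self) -> bool:
        return self.value


gramian_accumulation_phase = BoolRef(False)

def hook() -> str:
    # __bool__ is evaluated at call time, so the hook sees the current value.
    return "skip" if gramian_accumulation_phase else "track"

print(hook())                            # track
gramian_accumulation_phase.value = True  # the flag owner flips the shared value
print(hook())                            # skip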
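Hook.__call__ registers, as a backward target, the gradient edge of the smallest output that requires grad; the comment in the diff cites memory efficiency. The selection itself is an argmin over element counts, with non-differentiable outputs pushed to infinity. A small sketch of that heuristic with made-up tensors (the shapes are arbitrary):

# Hypothetical illustration of the edge-selection heuristic in Hook.__call__.
import torch

flat_outputs = [
    torch.randn(4, 128),                     # does not require grad -> excluded
    torch.randn(4, 10, requires_grad=True),  # 40 elements
    torch.randn(4, 3, requires_grad=True),   # 12 elements -> smallest eligible
]

inf = float("inf")
preference = torch.tensor([t.numel() if t.requires_grad else inf for t in flat_outputs])
index = int(preference.argmin().item())
print(index)  # 2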
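AccumulateJacobian.vmap maps the VJP over the extra non-batched dimension of the jac_outputs while passing in_dims=None for the module args, so the same args are reused for every slice. A reduced sketch of that broadcasting behaviour with a stand-in VJP (the real one returns per-parameter generalized Jacobians):

# Hypothetical illustration of mapping over jac_outputs but not over args.
import torch

def vjp_like(jac_output, arg):
    # Stand-in for a VJP: produces one row per jac_output slice.
    return jac_output.sum() * arg

jac_outputs = torch.randn(5, 3)  # 5 slices along the non-batched dimension
arg = torch.randn(3)             # shared input, not mapped over

rows = torch.vmap(vjp_like, in_dims=(0, None))(jac_outputs, arg)
print(rows.shape)  # torch.Size([5, 3])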