Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/examples/iwmtl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ The following example shows how to do that.
optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module.modules(), batch_dim=0)
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/iwrm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ batch of data. When minimizing per-instance losses (IWRM), we use either autojac
params = model.parameters()
optimizer = SGD(params, lr=0.1)
weighting = UPGradWeighting()
engine = Engine(model.modules(), batch_dim=0)
engine = Engine(model, batch_dim=0)

for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/partial_jd.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ first ``Linear`` layer, thereby reducing memory usage and computation time.

# Create the autogram engine that will compute the Gramian of the
# Jacobian with respect to the two last Linear layers' parameters.
engine = Engine(model[2:].modules(), batch_dim=0)
engine = Engine(model[2:], batch_dim=0)

params = model.parameters()
optimizer = SGD(params, lr=0.1)
Expand Down
28 changes: 14 additions & 14 deletions src/torchjd/autogram/_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections.abc import Iterable
from typing import cast

import torch
Expand Down Expand Up @@ -58,8 +57,9 @@ class Engine:
backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using
:func:`torchjd.autojac.backward`.

:param modules: A collection of modules whose direct (non-recursive) parameters will contribute
to the Gramian of the Jacobian.
:param modules: The modules whose parameters will contribute to the Gramian of the Jacobian.
    Several modules can be provided, but none of them may be a submodule of another provided
    module.
:param batch_dim: If the modules work with batches and process each batch element independently,
then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial
memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates
Expand Down Expand Up @@ -91,7 +91,7 @@ class Engine:
weighting = UPGradWeighting()

# Create the engine before the backward pass, and only once.
engine = Engine(model.modules(), batch_dim=0)
engine = Engine(model, batch_dim=0)

for input, target in zip(inputs, targets):
output = model(input).squeeze(dim=1) # shape: [16]
Expand Down Expand Up @@ -173,7 +173,7 @@ class Engine:

def __init__(
self,
modules: Iterable[nn.Module],
*modules: nn.Module,
batch_dim: int | None,
):
self._gramian_accumulator = GramianAccumulator()
Expand All @@ -183,16 +183,16 @@ def __init__(
self._target_edges, self._gramian_accumulator, batch_dim is not None
)

self._hook_modules(modules)
for module in modules:
self._hook_module_recursively(module)

def _hook_modules(self, modules: Iterable[nn.Module]) -> None:
_modules = set(modules)

# Add module forward hooks to compute jacobians
for module in _modules:
if any(p.requires_grad for p in module.parameters(recurse=False)):
self._check_module_is_compatible(module)
self._module_hook_manager.hook_module(module)
def _hook_module_recursively(self, module: nn.Module) -> None:
if any(p.requires_grad for p in module.parameters(recurse=False)):
self._check_module_is_compatible(module)
self._module_hook_manager.hook_module(module)
else:
for child in module.children():
self._hook_module_recursively(child)

def _check_module_is_compatible(self, module: nn.Module) -> None:
if self._batch_dim is not None:
Expand Down
5 changes: 2 additions & 3 deletions src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from ._edge_registry import EdgeRegistry
from ._gramian_accumulator import GramianAccumulator
from ._module_utils import get_used_params
from ._vjp import VJP, AutogradVJP, FunctionalVJP

# Note about import from protected _pytree module:
Expand Down Expand Up @@ -126,8 +125,8 @@ def __call__(
# require grad
return outputs

rg_params, _ = get_used_params(module)
self.gramian_accumulator.track_parameter_paths(rg_params.values())
rg_params = [p for p in module.parameters(recurse=True) if p.requires_grad]
self.gramian_accumulator.track_parameter_paths(rg_params)

# We only care about running the JacobianAccumulator node, so we need one of its child
# edges (the edges of the original outputs of the model) as target. For memory
Expand Down
47 changes: 0 additions & 47 deletions src/torchjd/autogram/_module_utils.py

This file was deleted.

12 changes: 9 additions & 3 deletions src/torchjd/autogram/_vjp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from torch.nn import Parameter
from torch.utils._pytree import PyTree, tree_flatten, tree_map_only, tree_unflatten

from torchjd.autogram._module_utils import get_used_params

# Note about import from protected _pytree module:
# PyTorch maintainers plan to make pytree public (see
# https://github.com/pytorch/pytorch/issues/65761, https://github.com/pytorch/pytorch/pull/137400).
Expand Down Expand Up @@ -39,7 +37,15 @@ class ModuleVJP(VJP, ABC):

def __init__(self, module: nn.Module):
    """Store ``module`` and partition its parameters by trainability.

    Every parameter reachable through ``module.named_parameters(recurse=True)``
    is placed, keyed by its qualified name, into ``self.rg_params`` when it
    requires grad and into ``self.frozen_params`` otherwise.
    """
    self.module = module

    self.rg_params: dict[str, Parameter] = {}
    self.frozen_params: dict[str, Parameter] = {}

    for name, param in module.named_parameters(recurse=True):
        destination = self.rg_params if param.requires_grad else self.frozen_params
        destination[name] = param


class FunctionalVJP(ModuleVJP):
Expand Down
2 changes: 1 addition & 1 deletion tests/doc/test_autogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_engine():
weighting = UPGradWeighting()

# Create the engine before the backward pass, and only once.
engine = Engine(model.modules(), batch_dim=0)
engine = Engine(model, batch_dim=0)

for input, target in zip(inputs, targets):
output = model(input).squeeze(dim=1) # shape: [16]
Expand Down
6 changes: 3 additions & 3 deletions tests/doc/test_rst.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_iwmtl():
optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module.modules(), batch_dim=0)
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_autogram():
params = model.parameters()
optimizer = SGD(params, lr=0.1)
weighting = UPGradWeighting()
engine = Engine(model.modules(), batch_dim=0)
engine = Engine(model, batch_dim=0)

for x, y in zip(X, Y):
y_hat = model(x).squeeze(dim=1) # shape: [16]
Expand Down Expand Up @@ -374,7 +374,7 @@ def test_partial_jd():

# Create the autogram engine that will compute the Gramian of the
# Jacobian with respect to the two last Linear layers' parameters.
engine = Engine(model[2:].modules(), batch_dim=0)
engine = Engine(model[2:], batch_dim=0)

params = model.parameters()
optimizer = SGD(params, lr=0.1)
Expand Down
2 changes: 1 addition & 1 deletion tests/speed/autogram/grad_vs_jac_vs_gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def post_fn():
print(autojac_times)
print()

engine = Engine(model.modules(), batch_dim=0)
engine = Engine(model, batch_dim=0)
autogram_times = torch.tensor(time_call(fn_autogram, init_fn_autogram, pre_fn, post_fn, n_runs))
print(f"autogram times (avg = {autogram_times.mean():.5f}, std = {autogram_times.std():.5f}")
print(autogram_times)
Expand Down
30 changes: 15 additions & 15 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _assert_gramian_is_equivalent_to_autograd(
torch.manual_seed(0)
model_autogram = architecture().to(device=DEVICE)

engine = Engine(model_autogram.modules(), batch_dim=batch_dim)
engine = Engine(model_autogram, batch_dim=batch_dim)

inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_compute_gramian_various_output_shapes(
torch.manual_seed(0)
model_autogram = architecture().to(device=DEVICE)

engine = Engine(model_autogram.modules(), batch_dim=batch_dim)
engine = Engine(model_autogram, batch_dim=batch_dim)

inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
Expand Down Expand Up @@ -324,7 +324,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int
autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
torch.manual_seed(0)

engine = Engine(gramian_modules, batch_dim=batch_dim)
engine = Engine(*gramian_modules, batch_dim=batch_dim)

output = model(input)
losses = reduce_to_vector(loss_fn(output))
Expand All @@ -349,7 +349,7 @@ def test_iwrm_steps_with_autogram(

model = architecture().to(device=DEVICE)

engine = Engine(model.modules(), batch_dim=batch_dim)
engine = Engine(model, batch_dim=batch_dim)
optimizer = SGD(model.parameters(), lr=1e-7)

for i in range(n_iter):
Expand Down Expand Up @@ -388,7 +388,7 @@ def test_autograd_while_modules_are_hooked(
autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}

# Hook modules and optionally compute the Gramian
engine = Engine(model_autogram.modules(), batch_dim=batch_dim)
engine = Engine(model_autogram, batch_dim=batch_dim)
if use_engine:
torch.manual_seed(0) # Fix randomness for random models
output = model_autogram(input)
Expand Down Expand Up @@ -420,7 +420,7 @@ def test_incompatible_modules(architecture: type[nn.Module], batch_dim: int | No
model = architecture().to(device=DEVICE)

with pytest.raises(ValueError):
_ = Engine(model.modules(), batch_dim=batch_dim)
_ = Engine(model, batch_dim=batch_dim)


def test_compute_gramian_manual():
Expand All @@ -434,7 +434,7 @@ def test_compute_gramian_manual():

torch.manual_seed(0)
model = Linear(in_dims, out_dims).to(device=DEVICE)
engine = Engine(model.modules(), batch_dim=None)
engine = Engine(model, batch_dim=None)

input = randn_(in_dims)
output = model(input)
Expand Down Expand Up @@ -480,8 +480,8 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None):
output_size = prod(shape[1:])

model = Linear(input_size, output_size).to(device=DEVICE)
engine1 = Engine([model], batch_dim=None)
engine2 = Engine([model], batch_dim=None)
engine1 = Engine(model, batch_dim=None)
engine2 = Engine(model, batch_dim=None)

input = randn_([input_size])
output = model(input)
Expand Down Expand Up @@ -520,8 +520,8 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination:
output_size = prod(shape[1:])

model = Linear(input_size, output_size).to(device=DEVICE)
engine1 = Engine([model], batch_dim=None)
engine2 = Engine([model], batch_dim=None)
engine1 = Engine(model, batch_dim=None)
engine2 = Engine(model, batch_dim=None)

input = randn_([input_size])
output = model(input).reshape(shape[1:])
Expand Down Expand Up @@ -563,8 +563,8 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
output_size = input_size

model = Linear(input_size, output_size).to(device=DEVICE)
engine1 = Engine([model], batch_dim=batch_dim)
engine2 = Engine([model], batch_dim=None)
engine1 = Engine(model, batch_dim=batch_dim)
engine2 = Engine(model, batch_dim=None)

input = randn_([batch_size, input_size])
output = model(input)
Expand Down Expand Up @@ -595,8 +595,8 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz
torch.manual_seed(0)
model_none = architecture().to(device=DEVICE)

engine_0 = Engine(model_0.modules(), batch_dim=0)
engine_none = Engine(model_none.modules(), batch_dim=None)
engine_0 = Engine(model_0, batch_dim=0)
engine_none = Engine(model_none, batch_dim=None)

inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
Expand Down
Loading