Merged
21 commits
deb38e8
Implement `_compute_gramian_with_autograd` and add `test_gramian_equi…
PierreQuinton Sep 12, 2025
d1b28f1
Add `requires_grad = True` in `test_partial_autogram`
PierreQuinton Sep 12, 2025
db15921
Move `compute_gramian_with_autograd` to utils
PierreQuinton Sep 12, 2025
390e747
Remove usage of `autojac` in tests of `autogram`
PierreQuinton Sep 12, 2025
3fff2a5
Change `test_equivalence_autograd_autogram` to `test_IWRM_steps_with_…
PierreQuinton Sep 12, 2025
1c5f32e
Remove computation of gradients in `test_partial_autogram`
PierreQuinton Sep 12, 2025
3165ea0
Fix function name
ValerianRey Sep 13, 2025
6dd2ec0
Fix variable names
ValerianRey Sep 13, 2025
5c132be
Add docstring to compute_gramian_with_autograd
ValerianRey Sep 13, 2025
7d22521
Avoid using sum()
ValerianRey Sep 13, 2025
7c0d4b3
Revert removal of autojac_forward_backward and change in compare_auto…
ValerianRey Sep 13, 2025
e81732a
Move compute_gramian_with_autograd to forward_backwards
ValerianRey Sep 13, 2025
eefaff3
Revert breaking everything with sum change
ValerianRey Sep 13, 2025
b22dcb0
Add autograd_gramian speed test in grad_vs_jac_vs_gram
ValerianRey Sep 13, 2025
45b313c
Merge branch 'main' into remove_autojac_dependence_in_autogram_tests
ValerianRey Sep 14, 2025
eace7ad
Fix outdated comment
ValerianRey Sep 15, 2025
451bd3e
Fix tolerance of test_gramian_equivalence_autograd_autogram
ValerianRey Sep 15, 2025
fc9a1c9
Remove outdated docstring
ValerianRey Sep 15, 2025
33cf529
Fix docstring
ValerianRey Sep 15, 2025
2057e05
Simplify test_autograd_while_modules_are_hooked
ValerianRey Sep 15, 2025
e5698f9
Restructure file
ValerianRey Sep 15, 2025
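
Several of the commits above introduce and move a `compute_gramian_with_autograd` test helper. Its actual implementation lives in the test utilities (moved to `utils/forward_backwards.py` in commit e81732a) and is not shown in this diff; the sketch below is only an assumption of what such a helper may look like, inferred from how the tests call it: build the Jacobian of the per-instance losses with plain autograd, one row per loss, then return its Gramian J @ J.T.

# Hypothetical sketch only: the real helper is defined in the repo's test utils.
from collections.abc import Sequence

import torch
from torch import Tensor


def compute_gramian_with_autograd(
    losses: Tensor, params: Sequence[Tensor], retain_graph: bool = False
) -> Tensor:
    """Compute the Gramian J @ J.T of the Jacobian of losses w.r.t. params."""
    rows = []
    n = losses.numel()
    for i, loss in enumerate(losses):
        # The graph must be kept alive for every row except the last one.
        grads = torch.autograd.grad(
            loss, params, retain_graph=retain_graph or i < n - 1, allow_unused=True
        )
        rows.append(
            torch.cat(
                [
                    (g if g is not None else torch.zeros_like(p)).reshape(-1)
                    for g, p in zip(grads, params)
                ]
            )
        )
    jacobian = torch.stack(rows)  # shape: [n_losses, n_params]
    return jacobian @ jacobian.T
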
27 changes: 25 additions & 2 deletions tests/speed/autogram/grad_vs_jac_vs_gram.py
@@ -1,3 +1,4 @@
+import gc
 import time
 
 import torch
@@ -15,6 +16,7 @@
 )
 from utils.forward_backwards import (
     autograd_forward_backward,
+    autograd_gramian_forward_backward,
     autogram_forward_backward,
     autojac_forward_backward,
     make_mse_loss_fn,
@@ -31,8 +33,8 @@
     (AlexNet, 8),
     (InstanceNormResNet18, 16),
     (GroupNormMobileNetV3Small, 16),
-    (SqueezeNet, 16),
-    (InstanceNormMobileNetV2, 8),
+    (SqueezeNet, 4),
+    (InstanceNormMobileNetV2, 2),
 ]


@@ -58,20 +60,31 @@ def fn_autograd():
 
     def init_fn_autograd():
         torch.cuda.empty_cache()
+        gc.collect()
         fn_autograd()
 
+    def fn_autograd_gramian():
+        autograd_gramian_forward_backward(model, inputs, list(model.parameters()), loss_fn, W)
+
+    def init_fn_autograd_gramian():
+        torch.cuda.empty_cache()
+        gc.collect()
+        fn_autograd_gramian()
+
     def fn_autojac():
         autojac_forward_backward(model, inputs, loss_fn, A)
 
     def init_fn_autojac():
         torch.cuda.empty_cache()
+        gc.collect()
         fn_autojac()
 
     def fn_autogram():
         autogram_forward_backward(model, engine, W, inputs, loss_fn)
 
     def init_fn_autogram():
         torch.cuda.empty_cache()
+        gc.collect()
         fn_autogram()
 
     def optionally_cuda_sync():
@@ -91,6 +104,16 @@ def post_fn():
     print(autograd_times)
     print()
 
+    autograd_gramian_times = torch.tensor(
+        time_call(fn_autograd_gramian, init_fn_autograd_gramian, pre_fn, post_fn, n_runs)
+    )
+    print(
+        f"autograd gramian times (avg = {autograd_gramian_times.mean():.5f}, std = "
+        f"{autograd_gramian_times.std():.5f}"
+    )
+    print(autograd_gramian_times)
+    print()
+
     autojac_times = torch.tensor(time_call(fn_autojac, init_fn_autojac, pre_fn, post_fn, n_runs))
     print(f"autojac times (avg = {autojac_times.mean():.5f}, std = {autojac_times.std():.5f}")
     print(autojac_times)
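
The timings above rely on a `time_call(fn, init_fn, pre_fn, post_fn, n_runs)` helper from the test utilities whose body is not part of this diff. A minimal sketch under that assumed signature follows; the `gc.collect()` calls added in this file reset Python-side allocations, in addition to emptying the CUDA cache, before each method's warm-up, so one method's leftover memory does not skew the next method's measurements.

# Minimal sketch, assuming the time_call signature used above; the real helper
# is defined in the repository's test utilities and may differ.
import time


def time_call(fn, init_fn, pre_fn, post_fn, n_runs: int) -> list[float]:
    init_fn()  # reset memory state and warm up (not timed)
    times = []
    for _ in range(n_runs):
        pre_fn()
        start = time.perf_counter()
        fn()
        post_fn()  # e.g. torch.cuda.synchronize(), so pending GPU work is included
        times.append(time.perf_counter() - start)
    return times
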
212 changes: 91 additions & 121 deletions tests/unit/autogram/test_engine.py
@@ -5,6 +5,7 @@
 from pytest import mark, param
 from torch import nn
 from torch.optim import SGD
+from torch.testing import assert_close
 from unit.conftest import DEVICE
 from utils.architectures import (
     AlexNet,
@@ -52,15 +53,13 @@
 from utils.forward_backwards import (
     autograd_forward_backward,
     autogram_forward_backward,
-    autojac_forward_backward,
+    compute_gramian_with_autograd,
     make_mse_loss_fn,
 )
 from utils.tensors import make_tensors
 
-from torchjd.aggregation import UPGrad, UPGradWeighting
+from torchjd.aggregation import UPGradWeighting
 from torchjd.autogram._engine import Engine
-from torchjd.autojac._transform import Diagonalize, Init, Jac, OrderedSet
-from torchjd.autojac._transform._aggregate import _Matrixify
 
 PARAMETRIZATIONS = [
     (OverlyNested, 32),
@@ -107,110 +106,34 @@


@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
def test_equivalence_autojac_autogram(
architecture: type[ShapedModule],
batch_size: int,
):
"""
Tests that the autogram engine gives the same results as the autojac engine on IWRM for several
JD steps.
"""

n_iter = 3
def test_compute_gramian(architecture: type[ShapedModule], batch_size: int):
"""Tests that the autograd and the autogram engines compute the same gramian."""

input_shapes = architecture.INPUT_SHAPES
output_shapes = architecture.OUTPUT_SHAPES

weighting = UPGradWeighting()
aggregator = UPGrad()

torch.manual_seed(0)
model_autojac = architecture().to(device=DEVICE)
model_autograd = architecture().to(device=DEVICE)
torch.manual_seed(0)
model_autogram = architecture().to(device=DEVICE)

engine = Engine(model_autogram.modules())
optimizer_autojac = SGD(model_autojac.parameters(), lr=1e-7)
optimizer_autogram = SGD(model_autogram.parameters(), lr=1e-7)

for i in range(n_iter):
inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
loss_fn = make_mse_loss_fn(targets)

torch.random.manual_seed(0) # Fix randomness for random aggregators and random models
autojac_forward_backward(model_autojac, inputs, loss_fn, aggregator)
expected_grads = {
name: p.grad for name, p in model_autojac.named_parameters() if p.grad is not None
}

torch.random.manual_seed(0) # Fix randomness for random weightings and random models
autogram_forward_backward(model_autogram, engine, weighting, inputs, loss_fn)
grads = {
name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None
}

assert_tensor_dicts_are_close(grads, expected_grads)

optimizer_autojac.step()
model_autojac.zero_grad()

optimizer_autogram.step()
model_autogram.zero_grad()


@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
def test_autograd_while_modules_are_hooked(architecture: type[ShapedModule], batch_size: int):
"""
Tests that the hooks added when constructing the engine do not interfere with a simple autograd
call.
"""

input_shapes = architecture.INPUT_SHAPES
output_shapes = architecture.OUTPUT_SHAPES

W = UPGradWeighting()
A = UPGrad()
input = make_tensors(batch_size, input_shapes)
inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
loss_fn = make_mse_loss_fn(targets)

torch.manual_seed(0)
model = architecture().to(device=DEVICE)

torch.manual_seed(0) # Fix randomness for random models
autojac_forward_backward(model, input, loss_fn, A)
autojac_grads = {
name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None
}
model.zero_grad()

torch.manual_seed(0) # Fix randomness for random models
autograd_forward_backward(model, input, loss_fn)
autograd_grads = {
name: p.grad.clone() for name, p in model.named_parameters() if p.grad is not None
}

torch.manual_seed(0)
model_autogram = architecture().to(device=DEVICE)
torch.random.manual_seed(0) # Fix randomness for random models
output = model_autograd(inputs)
losses = loss_fn(output)
autograd_gramian = compute_gramian_with_autograd(losses, list(model_autograd.parameters()))

# Hook modules and verify that we're equivalent to autojac when using the engine
engine = Engine(model_autogram.modules())
torch.manual_seed(0) # Fix randomness for random models
autogram_forward_backward(model_autogram, engine, W, input, loss_fn)
grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
assert_tensor_dicts_are_close(grads, autojac_grads)
model_autogram.zero_grad()
torch.random.manual_seed(0) # Fix randomness for random models
output = model_autogram(inputs)
losses = loss_fn(output)
autogram_gramian = engine.compute_gramian(losses)

# Verify that even with the hooked modules, autograd works normally when not using the engine.
# Results should be the same as a normal call to autograd, and no time should be spent computing
# the gramian at all.
torch.manual_seed(0) # Fix randomness for random models
autograd_forward_backward(model_autogram, input, loss_fn)
assert engine._gramian_accumulator.gramian is None
grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}
assert_tensor_dicts_are_close(grads, autograd_grads)
model_autogram.zero_grad()
assert_close(autogram_gramian, autograd_gramian, rtol=1e-4, atol=1e-5)


def _non_empty_subsets(elements: set) -> list[set]:
@@ -221,20 +144,15 @@
 
 
 @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"}))
-def test_partial_autogram(gramian_module_names: set[str]):
+def test_compute_partial_gramian(gramian_module_names: set[str]):
     """
-    Tests that partial JD via the autogram engine works similarly as if the gramian was computed via
-    the autojac engine.
-
-    Note that this test is a bit redundant now that we have the Engine interface, because it now
-    just compares two ways of computing the Gramian, which is independant of the idea of partial JD.
+    Tests that the autograd and the autogram engines compute the same gramian when only a subset of
+    the model parameters is specified.
     """
 
     architecture = SimpleBranched
     batch_size = 64
 
-    weighting = UPGradWeighting()
-
     input_shapes = architecture.INPUT_SHAPES
     output_shapes = architecture.OUTPUT_SHAPES
 
@@ -247,39 +165,91 @@ def test_partial_autogram(gramian_module_names: set[str]):

output = model(input)
losses = loss_fn(output)
losses_ = OrderedSet(losses)

init = Init(losses_)
diag = Diagonalize(losses_)

gramian_modules = [model.get_submodule(name) for name in gramian_module_names]
gramian_params = OrderedSet({})
gramian_params = []
for m in gramian_modules:
gramian_params += OrderedSet(m.parameters())
gramian_params += list(m.parameters())

jac = Jac(losses_, OrderedSet(gramian_params), None, True)
mat = _Matrixify()
transform = mat << jac << diag << init

jacobian_matrices = transform({})
jacobian_matrix = torch.cat(list(jacobian_matrices.values()), dim=1)
gramian = jacobian_matrix @ jacobian_matrix.T
autograd_gramian = compute_gramian_with_autograd(losses, gramian_params, retain_graph=True)
torch.manual_seed(0)
losses.backward(weighting(gramian))

expected_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
model.zero_grad()

engine = Engine(gramian_modules)

output = model(input)
losses = loss_fn(output)
gramian = engine.compute_gramian(losses)

assert_close(gramian, autograd_gramian)


@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
def test_iwrm_steps_with_autogram(architecture: type[ShapedModule], batch_size: int):
"""Tests that the autogram engine doesn't raise any error during several IWRM iterations."""

n_iter = 3

input_shapes = architecture.INPUT_SHAPES
output_shapes = architecture.OUTPUT_SHAPES

weighting = UPGradWeighting()

model = architecture().to(device=DEVICE)

engine = Engine(model.modules())
optimizer = SGD(model.parameters(), lr=1e-7)

for i in range(n_iter):
inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
loss_fn = make_mse_loss_fn(targets)

autogram_forward_backward(model, engine, weighting, inputs, loss_fn)

optimizer.step()
model.zero_grad()


@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
@mark.parametrize("compute_gramian", [False, True])
def test_autograd_while_modules_are_hooked(
architecture: type[ShapedModule], batch_size: int, compute_gramian: bool
):
"""
Tests that the hooks added when constructing the engine do not interfere with a simple autograd
call.
"""

input = make_tensors(batch_size, architecture.INPUT_SHAPES)
targets = make_tensors(batch_size, architecture.OUTPUT_SHAPES)
loss_fn = make_mse_loss_fn(targets)

torch.manual_seed(0)
model = architecture().to(device=DEVICE)
torch.manual_seed(0)
losses.backward(weighting(gramian))
model_autogram = architecture().to(device=DEVICE)

torch.manual_seed(0) # Fix randomness for random models
autograd_forward_backward(model, input, loss_fn)
autograd_grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}

# Hook modules and optionally compute the Gramian
engine = Engine(model_autogram.modules())
if compute_gramian:
torch.manual_seed(0) # Fix randomness for random models
output = model_autogram(input)
losses = loss_fn(output)
_ = engine.compute_gramian(losses)

grads = {name: p.grad for name, p in model.named_parameters() if p.grad is not None}
assert_tensor_dicts_are_close(grads, expected_grads)
# Verify that even with the hooked modules, autograd works normally when not using the engine.
# Results should be the same as a normal call to autograd, and no time should be spent computing
# the gramian at all.
torch.manual_seed(0) # Fix randomness for random models
autograd_forward_backward(model_autogram, input, loss_fn)
grads = {name: p.grad for name, p in model_autogram.named_parameters() if p.grad is not None}

assert_tensor_dicts_are_close(grads, autograd_grads)
assert engine._gramian_accumulator.gramian is None


@mark.parametrize("architecture", [WithRNN, WithModuleTrackingRunningStats])
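
For context, this is the forward-backward pattern the tests above exercise, pieced together from calls that appear verbatim in this diff (`Engine`, `compute_gramian`, `UPGradWeighting`, `losses.backward(weighting(gramian))`). The `autogram_forward_backward` helper itself is defined in the test utilities; its body here is an assumption, not the repository's code.

# Sketch assembled from calls visible in this diff; the actual helper may differ.
from torchjd.aggregation import UPGradWeighting
from torchjd.autogram._engine import Engine


def autogram_forward_backward(model, engine, weighting, inputs, loss_fn):
    output = model(inputs)
    losses = loss_fn(output)  # one loss per element of the batch (IWRM)
    gramian = engine.compute_gramian(losses)  # Gramian of the Jacobian
    # Turn the Gramian into per-loss weights and do a single weighted backward.
    losses.backward(weighting(gramian))


# Typical setup, as in test_iwrm_steps_with_autogram:
# engine = Engine(model.modules())
# weighting = UPGradWeighting()
# autogram_forward_backward(model, engine, weighting, inputs, loss_fn)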