From 5e5e9208a9c2fcb982584e2d3f93ae5801b4cd65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 20:59:30 +0200 Subject: [PATCH 1/2] Add test_batched_non_batched_equivalence_2 --- tests/unit/autogram/test_engine.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 7c88ae9d..6d530a80 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -562,3 +562,33 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): gramian2 = engine2.compute_gramian(output) assert_close(gramian1, gramian2) + + +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) +def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int): + """ + Same as test_batched_non_batched_equivalence but on real architectures, and thus only between + batch_dim=0 and batch_dim=None. + """ + + input_shapes = architecture.INPUT_SHAPES + output_shapes = architecture.OUTPUT_SHAPES + + torch.manual_seed(0) + model = architecture().to(device=DEVICE) + + engine_0 = Engine(model.modules(), batch_dim=0) + engine_none = Engine(model.modules(), batch_dim=None) + + inputs = make_tensors(batch_size, input_shapes) + targets = make_tensors(batch_size, output_shapes) + loss_fn = make_mse_loss_fn(targets) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model(inputs) + losses = reduce_to_vector(loss_fn(output)) + + gramian_0 = engine_0.compute_gramian(losses) + gramian_none = engine_none.compute_gramian(losses) + + assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5) From e2e11a20928f11ca677a9ec72f99977a19ed82cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 2 Oct 2025 21:56:53 +0200 Subject: [PATCH 2/2] Fix test_batched_non_batched_equivalence_2 to not have two engines on the same model --- tests/unit/autogram/test_engine.py 
| 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 6d530a80..a233ef0d 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -575,20 +575,26 @@ def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_siz output_shapes = architecture.OUTPUT_SHAPES torch.manual_seed(0) - model = architecture().to(device=DEVICE) + model_0 = architecture().to(device=DEVICE) + torch.manual_seed(0) + model_none = architecture().to(device=DEVICE) - engine_0 = Engine(model.modules(), batch_dim=0) - engine_none = Engine(model.modules(), batch_dim=None) + engine_0 = Engine(model_0.modules(), batch_dim=0) + engine_none = Engine(model_none.modules(), batch_dim=None) inputs = make_tensors(batch_size, input_shapes) targets = make_tensors(batch_size, output_shapes) loss_fn = make_mse_loss_fn(targets) torch.random.manual_seed(0) # Fix randomness for random models - output = model(inputs) - losses = reduce_to_vector(loss_fn(output)) + output = model_0(inputs) + losses_0 = reduce_to_vector(loss_fn(output)) + + torch.random.manual_seed(0) # Fix randomness for random models + output = model_none(inputs) + losses_none = reduce_to_vector(loss_fn(output)) - gramian_0 = engine_0.compute_gramian(losses) - gramian_none = engine_none.compute_gramian(losses) + gramian_0 = engine_0.compute_gramian(losses_0) + gramian_none = engine_none.compute_gramian(losses_none) assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)