diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py
index d4cac1ec..302e0137 100644
--- a/tests/unit/autogram/test_engine.py
+++ b/tests/unit/autogram/test_engine.py
@@ -565,3 +565,39 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
     gramian2 = engine2.compute_gramian(output)
 
     assert_close(gramian1, gramian2)
+
+
+@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
+def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int):
+    """
+    Same as test_batched_non_batched_equivalence, but on real architectures, and thus only
+    between batch_dim=0 and batch_dim=None.
+    """
+
+    input_shapes = architecture.INPUT_SHAPES
+    output_shapes = architecture.OUTPUT_SHAPES
+
+    torch.manual_seed(0)
+    model_0 = architecture().to(device=DEVICE)
+    torch.manual_seed(0)
+    model_none = architecture().to(device=DEVICE)
+
+    engine_0 = Engine(model_0.modules(), batch_dim=0)
+    engine_none = Engine(model_none.modules(), batch_dim=None)
+
+    inputs = make_tensors(batch_size, input_shapes)
+    targets = make_tensors(batch_size, output_shapes)
+    loss_fn = make_mse_loss_fn(targets)
+
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_0(inputs)
+    losses_0 = reduce_to_vector(loss_fn(output))
+
+    torch.random.manual_seed(0)  # Fix randomness for random models
+    output = model_none(inputs)
+    losses_none = reduce_to_vector(loss_fn(output))
+
+    gramian_0 = engine_0.compute_gramian(losses_0)
+    gramian_none = engine_none.compute_gramian(losses_none)
+
+    assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)
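
For reference, a minimal sketch of the quantity the new test compares across the two engines, assuming compute_gramian returns the Gramian G = J @ J.T of the Jacobian J of the loss vector with respect to all model parameters. The naive_gramian helper below is hypothetical (it is not part of this diff or the test suite) and only illustrates the value that the batch_dim=0 and batch_dim=None paths are both expected to produce, up to the stated tolerances.

import torch

def naive_gramian(losses: torch.Tensor, params: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical reference helper: build the Jacobian of `losses` w.r.t.
    # `params` one row at a time with plain autograd, then return J @ J.T.
    rows = []
    for loss in losses:
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        flat = torch.cat(
            [(g if g is not None else torch.zeros_like(p)).reshape(-1)
             for g, p in zip(grads, params)]
        )
        rows.append(flat)
    jacobian = torch.stack(rows)  # shape: [num_losses, num_params]
    return jacobian @ jacobian.T

# Stand-in model and per-output MSE losses, mirroring the test's
# loss_fn + reduce_to_vector step on a much smaller scale.
model = torch.nn.Linear(3, 2)
x = torch.randn(4, 3)
target = torch.randn(4, 2)
losses = ((model(x) - target) ** 2).mean(dim=0)  # one scalar loss per output dim
gram = naive_gramian(losses, list(model.parameters()))
assert gram.shape == (2, 2)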