From b5f879e81f5aeafc270749d226e289d3cf5c87f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 13 Oct 2025 03:21:03 +0200 Subject: [PATCH] test(autogram): Stop creating two engines for one model --- tests/unit/autogram/test_engine.py | 51 +++++++++++++++++------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 31f63559..c98831dd 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -479,18 +479,20 @@ def test_reshape_equivariance(shape: list[int], batch_dim: int | None): input_size = shape[0] output_size = prod(shape[1:]) - model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine(model, batch_dim=None) - engine2 = Engine(model, batch_dim=None) + torch.manual_seed(0) + model1 = Linear(input_size, output_size).to(device=DEVICE) + torch.manual_seed(0) + model2 = Linear(input_size, output_size).to(device=DEVICE) - input = randn_([input_size]) - output = model(input) + engine1 = Engine(model1, batch_dim=None) + engine2 = Engine(model2, batch_dim=None) - reshaped_output = output.reshape(shape[1:]) + input = randn_([input_size]) + output = model1(input) + reshaped_output = model2(input).reshape(shape[1:]) gramian = engine1.compute_gramian(output) reshaped_gramian = engine2.compute_gramian(reshaped_output) - expected_reshaped_gramian = reshape_gramian(gramian, shape[1:]) assert_close(reshaped_gramian, expected_reshaped_gramian) @@ -519,18 +521,20 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: input_size = shape[0] output_size = prod(shape[1:]) - model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine(model, batch_dim=None) - engine2 = Engine(model, batch_dim=None) + torch.manual_seed(0) + model1 = Linear(input_size, output_size).to(device=DEVICE) + torch.manual_seed(0) + model2 = Linear(input_size, output_size).to(device=DEVICE) - input = randn_([input_size]) - output = model(input).reshape(shape[1:]) + engine1 = Engine(model1, batch_dim=None) + engine2 = Engine(model2, batch_dim=None) - moved_output = output.movedim(source, destination) + input = randn_([input_size]) + output = model1(input).reshape(shape[1:]) + moved_output = model2(input).reshape(shape[1:]).movedim(source, destination) gramian = engine1.compute_gramian(output) moved_gramian = engine2.compute_gramian(moved_output) - expected_moved_gramian = movedim_gramian(gramian, source, destination) assert_close(moved_gramian, expected_moved_gramian) @@ -562,17 +566,20 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): batch_size = shape[batch_dim] output_size = input_size - model = Linear(input_size, output_size).to(device=DEVICE) - engine1 = Engine(model, batch_dim=batch_dim) - engine2 = Engine(model, batch_dim=None) + torch.manual_seed(0) + model1 = Linear(input_size, output_size).to(device=DEVICE) + torch.manual_seed(0) + model2 = Linear(input_size, output_size).to(device=DEVICE) + + engine1 = Engine(model1, batch_dim=batch_dim) + engine2 = Engine(model2, batch_dim=None) input = randn_([batch_size, input_size]) - output = model(input) - output = output.reshape([batch_size] + non_batched_shape) - output = output.movedim(0, batch_dim) + output1 = model1(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) + output2 = model2(input).reshape([batch_size] + non_batched_shape).movedim(0, batch_dim) - gramian1 = engine1.compute_gramian(output) - gramian2 = engine2.compute_gramian(output) + gramian1 = engine1.compute_gramian(output1) + gramian2 = engine2.compute_gramian(output2) assert_close(gramian1, gramian2)