Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/unit/autogram/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,39 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int):
gramian2 = engine2.compute_gramian(output)

assert_close(gramian1, gramian2)


@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS)
def test_batched_non_batched_equivalence_2(architecture: ShapedModule, batch_size: int):
"""
Same as test_batched_non_batched_equivalence but on real architectures, and thus only between
batch_size=0 and batch_size=None.
"""

input_shapes = architecture.INPUT_SHAPES
output_shapes = architecture.OUTPUT_SHAPES

torch.manual_seed(0)
model_0 = architecture().to(device=DEVICE)
torch.manual_seed(0)
model_none = architecture().to(device=DEVICE)

engine_0 = Engine(model_0.modules(), batch_dim=0)
engine_none = Engine(model_none.modules(), batch_dim=None)

inputs = make_tensors(batch_size, input_shapes)
targets = make_tensors(batch_size, output_shapes)
loss_fn = make_mse_loss_fn(targets)

torch.random.manual_seed(0) # Fix randomness for random models
output = model_0(inputs)
losses_0 = reduce_to_vector(loss_fn(output))

torch.random.manual_seed(0) # Fix randomness for random models
output = model_none(inputs)
losses_none = reduce_to_vector(loss_fn(output))

gramian_0 = engine_0.compute_gramian(losses_0)
gramian_none = engine_none.compute_gramian(losses_none)

assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5)
Loading