diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 56b3c5d0..50135796 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -309,7 +309,7 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int the model parameters is specified. """ - model = SimpleBranched() + model = ModuleFactory(SimpleBranched)() batch_size = 64 inputs, targets = make_inputs_and_targets(model, batch_size) loss_fn = make_mse_loss_fn(targets)