Skip to content

test: Add ModuleFactory#459

Merged
ValerianRey merged 1 commit intomainfrom
add-module-factory
Oct 16, 2025
Merged

test: Add ModuleFactory#459
ValerianRey merged 1 commit intomainfrom
add-module-factory

Conversation

@ValerianRey
Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey commented Oct 16, 2025

  • Add ModuleFactory and use it to instantiate models in tests
  • Add get_in_out_shapes and use it to obtain input and output shapes in tests

The direct consequence is that we can now test model that have __init__ arguments. These args will be provided to the ModuleFactory, and then each instance created will receive these args.

Another nice consequence is that we can now test nn.Module that are not ShapedModule, simply by adding a line in get_in_out_shapes to indicate the formula to get the input and output shapes. So for instance, to test a Linear, one would have to add:

elif isinstance(module, nn.Linear):
    return (module.in_features,), (module.out_features,)

in get_in_out_shapes, and:

(ModuleFactory(nn.Linear, in_features=3, out_features=5, bias=False)

to the PARAMETRIZATIONS of test_engine.py.

Note that for example for nn.Conv2d, this would be much more complex, because it depends on all the parameters (maybe there are online code snippets to compute this).

We also don't have to worry about device anymore! The factory imports the DEVICE from conftest.py, and directly moves the created models to it (similarly as the tensor creation functions from test.utils.tensors.py).

Lastly, we don't have to worry about resetting the seed to 0 before creating a model. The factory is in charge of forking the rng and setting the seed to 0 before each model creation. Arguably, we could have a seed: int | None parameter to the __call__ method of ModuleFactory, so that we can optionally not do that, but for now this is good enough I think.

The downside is that for now, the pytest summary has the name "factory0", "factory1", etc., for the different parametrizations, but this is easy to fix. I'll make another PR to fix this + other parameter name issues with pytest.

@ValerianRey ValerianRey added the cc: test Conventional commit type for changes to tests. label Oct 16, 2025
@ValerianRey ValerianRey self-assigned this Oct 16, 2025
@codecov
Copy link
Copy Markdown

codecov bot commented Oct 16, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

* Add ModuleFactory
* Add get_in_out_shapes
* Add condition on model being a ShapedModule before doing assertion about output shape in _forward_pass
* Update module instantiation in tests to use ModuleFactories instead of type[ShapedModule]
@ValerianRey ValerianRey merged commit f2535dc into main Oct 16, 2025
17 checks passed
@ValerianRey ValerianRey deleted the add-module-factory branch October 16, 2025 21:48
ValerianRey added a commit that referenced this pull request Oct 20, 2025
* Add ModuleFactory and use it to instantiate models in tests
* Add get_in_out_shapes and use it to obtain input and output shapes in tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: test Conventional commit type for changes to tests.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant