feat(autogram): Add Transformer support #447

Merged
ValerianRey merged 16 commits into main from add-transformer
Oct 10, 2025

@ValerianRey ValerianRey commented Oct 3, 2025

  • Refactor how used params are deduced from the module: we now combine the direct params and the indirectly used params, and we reuse the same code between both usages of this function
  • Add special case for indirectly used params of MultiheadAttention
  • Add WithMultiheadAttention, WithTransformer, and WithTransformerLarge tests
  • Add WithTransformerLarge speed test
  • Rename trainable to rg
  • Add test_batched_non_batched_equivalence_2
  • Update warnings about Transformers in the docstring of Engine

Otherwise we'd have to wait for pytorch/pytorch#126568 to be merged.
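
As a rough illustration of the refactor described above, here is a minimal sketch of how direct and indirectly used params could be combined. The helper name `get_used_rg_params` and its return type are assumptions made for this example and are not taken from the torchjd code. The special case relies on the fact that `MultiheadAttention` reads the weight and bias of `out_proj` itself rather than calling the submodule.

```python
from torch import nn


def get_used_rg_params(module: nn.Module) -> dict[str, nn.Parameter]:
    """Hypothetical helper: params used by `module` itself that require grad."""
    # Direct params: registered on the module itself, not on its children.
    used = {
        name: param
        for name, param in module.named_parameters(recurse=False)
        if param.requires_grad
    }
    # Indirectly used params: MultiheadAttention reads out_proj's weight and
    # bias without calling out_proj, so they are attributed to the parent.
    # The hasattr check keeps this working if out_proj ever disappears.
    if isinstance(module, nn.MultiheadAttention) and hasattr(module, "out_proj"):
        for name, param in module.out_proj.named_parameters(recurse=False):
            if param.requires_grad:
                used[f"out_proj.{name}"] = param
    return used


mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
used = get_used_rg_params(mha)
print(sorted(used))
```

Keying the indirect params by their dotted name (`out_proj.weight`) means the same name can later be resolved from the parent module.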

Tests seem to fail because of a user warning, need to filter it out.

TODO:

  • Isolate dirty code out of the rest
  • Filter out user warnings
  • Prevent hook creation for out_proj => More generally, we would like to avoid hooking modules that do not use any rg param. It's very hard to know in general which params are used, so we would need to do that on a case-by-case basis. However, in the case of mha.out_proj, I don't think it's worth the maintenance effort, because out_proj isn't even called anyway, so its hook is never fired.
  • Make sure this will be future-compatible: if no out_proj module exists anymore (i.e. if the pytorch PR is merged), it should still work => I tried modifying pytorch code to simulate a potential future and it still works, so we're good. This is because we check hasattr("out_proj") now.
  • Add MultiheadAttention test
  • Adapt warning about Transformers: we need to tell people to disable dropout and maybe some other parameters, but that's basically it
  • Add usage example with Transformers => Not sure we should do that yet, actually, because it's not fast enough yet.
  • Test more Transformer parametrizations
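
For the "Filter out user warnings" item, one possible approach is sketched below using only the standard `warnings` module. This is not necessarily how the PR does it (a pytest `filterwarnings` mark would be another option), and `noisy_forward` is a hypothetical stand-in for whatever call emits the warning.

```python
import warnings


def noisy_forward() -> int:
    """Hypothetical stand-in for a call that emits a UserWarning."""
    warnings.warn("hypothetical warning emitted by a forward pass", UserWarning)
    return 42


def run_without_user_warnings() -> int:
    with warnings.catch_warnings():
        # Simulate a strict test suite that turns every warning into an error...
        warnings.simplefilter("error")
        # ...then ignore UserWarning only; other categories still raise.
        warnings.filterwarnings("ignore", category=UserWarning)
        return noisy_forward()


result = run_without_user_warnings()
```

Scoping the filter with `catch_warnings` restores the previous filters on exit, so the rest of the suite stays strict.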

@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: autogram labels Oct 3, 2025
@ValerianRey ValerianRey self-assigned this Oct 3, 2025

PierreQuinton commented Oct 3, 2025

Even if PyTorch were to merge it quickly, we would still only be compatible with very recent versions, which is a bit problematic. I think this suggests finding a solution independently.

PierreQuinton

This comment was marked as resolved.

@ValerianRey

This comment was marked as resolved.


codecov bot commented Oct 9, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

| Files with missing lines | Coverage Δ |
|---|---|
| src/torchjd/autogram/_engine.py | 100.00% <ø> (ø) |
| src/torchjd/autogram/_module_hook_manager.py | 100.00% <100.00%> (ø) |
| src/torchjd/autogram/_module_utils.py | 100.00% <100.00%> (ø) |
| src/torchjd/autogram/_vjp.py | 100.00% <100.00%> (ø) |
* It's actually not needed. We can either do mha.out_proj.get_parameter("weight") or mha.get_parameter("out_proj.weight"). The latter is much easier to handle since "out_proj.weight" is already the param name that we store.
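
The equivalence mentioned in this commit message can be checked with a short sketch. It only uses `nn.Module.get_parameter`; the sizes passed to `MultiheadAttention` are arbitrary values chosen for the example.

```python
from torch import nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# Option 1: resolve the submodule first, then the parameter.
via_submodule = mha.out_proj.get_parameter("weight")
# Option 2: resolve the dotted name directly from the parent module.
via_dotted_name = mha.get_parameter("out_proj.weight")

# Both lookups should return the very same Parameter object.
same = via_submodule is via_dotted_name
```

With the second form, the stored name `"out_proj.weight"` can be passed as-is, without splitting it into a submodule path and a param name.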
@ValerianRey ValerianRey marked this pull request as ready for review October 9, 2025 23:32
@ValerianRey

@PierreQuinton this needs a review, and then we can merge!

Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
@ValerianRey ValerianRey merged commit a211197 into main Oct 10, 2025
17 checks passed
@ValerianRey ValerianRey deleted the add-transformer branch October 10, 2025 14:50