feat(autogram): Add support for non-Tensor args#442

Merged
ValerianRey merged 9 commits intomainfrom
fix-args
Oct 2, 2025
Conversation

@ValerianRey
Contributor

No description provided.

@codecov

codecov bot commented Oct 1, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/autogram/_engine.py 100.00% <ø> (ø)
src/torchjd/autogram/_module_hook_manager.py 100.00% <100.00%> (ø)
src/torchjd/autogram/_vjp.py 100.00% <100.00%> (ø)

@PierreQuinton
Contributor

PierreQuinton commented Oct 1, 2025

The current failure highlights something slightly non-trivial for the batched case (in particular in FunctionalVJP: torch.vmap(self._call_on_one_instance)). It is due to the fact that we cannot vmap over the 0th dimension of a non-Tensor input to the module, quite expectedly. However, if some input Tensor is not batched, we would have the same problem. I think this would be solved if we handled the in_dims of vmap better for the modules (we could also specify None for such Tensors). The current version works only when all inputs are Tensors batched on the first dimension. We could generalize to "all Tensor inputs are batched on the first dimension" and later solve the more general case.
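To illustrate the point about in_dims: torch.vmap accepts a per-argument in_dims specification, where None marks arguments that should not be mapped over. This is a minimal sketch (the function and argument names are hypothetical, not from the PR) of why in_dims=0 fails for non-Tensor inputs and how None fixes it:

```python
import torch

def forward(x, scale):
    # x: Tensor batched on dim 0; scale: plain Python float (non-Tensor)
    return x * scale

x = torch.randn(4, 3)

# torch.vmap(forward, in_dims=0) would fail: a non-Tensor argument
# cannot be mapped over its "0th dimension". Specifying None for the
# non-Tensor (and for any unbatched Tensor) works:
batched = torch.vmap(forward, in_dims=(0, None))
out = batched(x, 2.0)
assert out.shape == (4, 3)
```

The same mechanism would cover the unbatched-Tensor case mentioned above: such a Tensor also gets in_dims entry None.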

Side note: If we implement a mapping from Module to Gramian computer, then we could also specify the in_dims there thus making these two discussions dependent.

@ValerianRey ValerianRey changed the title Fix type hints feat(autogram): Add support for any kind of args Oct 1, 2025
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: autogram labels Oct 1, 2025
@ValerianRey
Contributor Author

I completely agree. I fixed the issue in 49e5d5f.

For now, the in_dims are computed by simply giving 0 to every Tensor (note: grad_outputs are all Tensors, so there is no need to provide None for any of them). In another PR we can customize this value (and make it depend on the module, args, and outputs, rather than always using 0 for Tensors).
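The strategy described here (0 for every Tensor leaf, None for everything else) can be sketched with a pytree map. The helper name is hypothetical, not the PR's actual code:

```python
import torch
from torch.utils._pytree import tree_map

def make_in_dims(args):
    # Hypothetical helper mirroring the strategy described above:
    # every Tensor leaf is assumed batched on dim 0, every
    # non-Tensor leaf gets None so vmap leaves it untouched.
    return tree_map(lambda a: 0 if isinstance(a, torch.Tensor) else None, args)

args = (torch.randn(4, 3), "mode", {"w": torch.randn(4, 2), "p": 1.5})
in_dims = make_in_dims(args)
# in_dims == (0, None, {"w": 0, "p": None})
```

tree_map preserves the pytree structure of args, so the resulting in_dims lines up leaf-by-leaf with the inputs, which is exactly the shape torch.vmap expects.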

@ValerianRey ValerianRey changed the title feat(autogram): Add support for any kind of args feat(autogram): Add support for non-Tensor args Oct 1, 2025
@ValerianRey
Contributor Author

With WithModuleWithHybridPyTreeArg, I'm very confident that this works for any kind of args that contain tensors batched on dim 0.

* With this extra linear module, we now also check that the gradients wrt the args that require grad are correctly backpropagated.
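A self-contained sketch of the kind of scenario being tested (the module below is a hypothetical stand-in, not the actual WithModuleWithHybridPyTreeArg from the test suite): a module whose forward receives a hybrid structure mixing a batched Tensor with a non-Tensor value, with gradients checked to flow back to the Tensor arg:

```python
import torch
from torch import nn

class HybridArgModule(nn.Module):
    # Hypothetical illustration: forward takes a dict mixing a Tensor
    # batched on dim 0 with a plain Python value.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, tree):
        x, meta = tree["x"], tree["meta"]
        assert isinstance(meta, str)  # non-Tensor arg passes through untouched
        return self.linear(x)

m = HybridArgModule()
x = torch.randn(4, 3, requires_grad=True)
out = m({"x": x, "meta": "hello"})
out.sum().backward()
assert x.grad is not None  # gradient wrt the Tensor arg is backpropagated
```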
Contributor

@PierreQuinton PierreQuinton left a comment


Except for this, LGTM

@ValerianRey ValerianRey merged commit 6aae58c into main Oct 2, 2025
17 checks passed
@ValerianRey ValerianRey deleted the fix-args branch October 2, 2025 17:58