
refactor(autojac): Avoid concat in jac_to_grad #515

Closed
PierreQuinton wants to merge 5 commits into main from
avoid-concatenation-for-gramian-based-aggregators

Conversation

@PierreQuinton
Contributor

The goal is to halve peak memory usage when the aggregator is Gramian-based.

Does:

  • add a _utils package to torchjd
  • put Matrix and PSDMatrix in the _utils package
  • factor the computation of the Gramian (from autogram and aggregation) into _utils.compute_gramian: Tensor -> PSDMatrix
  • add a function _utils.compute_gramian_sum (note that _utils is the only package responsible for the cast to PSDMatrix)
  • add _jacobian_based (current) and _gramian_based (new) strategies to jac_to_grad
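The memory saving comes from a block identity for the Gramian: if the full Jacobian is the column-wise concatenation J = [J₁ | J₂ | …] of per-parameter Jacobians, then J Jᵀ = Σᵢ Jᵢ Jᵢᵀ, so the Gramian can be accumulated chunk by chunk without ever materializing the concatenated matrix. A minimal NumPy sketch of that identity (torchjd's actual compute_gramian_sum and its PSDMatrix wrapper are not shown; the function name here is illustrative):

```python
import numpy as np

def gramian_sum(jacobian_chunks):
    """Accumulate J @ J.T over column-wise chunks of the Jacobian,
    never building the concatenated full Jacobian."""
    m = jacobian_chunks[0].shape[0]   # number of objectives (rows)
    gramian = np.zeros((m, m))
    for chunk in jacobian_chunks:     # each chunk has shape (m, n_i)
        gramian += chunk @ chunk.T    # partial (m, m) Gramian
    return gramian

rng = np.random.default_rng(0)
chunks = [rng.standard_normal((3, 5)), rng.standard_normal((3, 7))]
full = np.concatenate(chunks, axis=1)  # the allocation this PR avoids
assert np.allclose(gramian_sum(chunks), full @ full.T)
```

Peak memory for the Gramian path is then one (m, m) accumulator plus the largest chunk, instead of the full (m, Σ nᵢ) matrix.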

@ValerianRey This could be cut into several more atomic PRs, which is why this is a draft (I just wanted to see what problems are involved).

- moves Matrix and PSDMatrix to compute_gramian (probably not the best location, but they should be in _utils)
- changes the return type of compute_gramian to PSDMatrix
- adds compute_gramian_sum (note that the responsibility of casting to PSDMatrix now belongs to _utils)
- adds a _gramian_based version of jac_to_grad. Note that we could put the tensordot(weights, jacobian, dims=1) in _utils as a weight_generalize_matrix method.
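The tensordot mentioned above can be applied to each per-parameter Jacobian chunk separately, so the Gramian-based path also skips concatenation when turning aggregation weights into gradients. A hedged NumPy sketch of that step (the real jac_to_grad signature is simplified, and the weights here stand in for whatever a Gramian-based aggregator would produce):

```python
import numpy as np

def weighted_grads(jacobian_chunks, weights):
    """Contract the weight vector with each Jacobian chunk separately:
    equivalent to tensordot(weights, concat(chunks), dims=1) but without
    allocating the concatenated Jacobian."""
    return [np.tensordot(weights, chunk, axes=1) for chunk in jacobian_chunks]

rng = np.random.default_rng(0)
chunks = [rng.standard_normal((3, 5)), rng.standard_normal((3, 7))]
weights = np.array([0.5, 0.3, 0.2])  # illustrative aggregator output
grads = weighted_grads(chunks, weights)
full = np.tensordot(weights, np.concatenate(chunks, axis=1), axes=1)
assert np.allclose(np.concatenate(grads), full)
```

Each returned chunk already has the shape of the corresponding parameter's gradient, so the results can be written back per parameter.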
@PierreQuinton added the cc: feat (conventional commit type for new features) and package: autojac labels on Jan 15, 2026
@codecov

codecov bot commented Jan 15, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/_utils/__init__.py 100.00% <100.00%> (ø)
src/torchjd/_utils/compute_gramian.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_aggregator_bases.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_aligned_mtl.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_cagrad.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_constant.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_dualproj.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_flattening.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_imtl_g.py 100.00% <100.00%> (ø)
src/torchjd/aggregation/_krum.py 100.00% <100.00%> (ø)
... and 12 more

grad_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> Optional[Tensor]:
Contributor Author

This could be an Optional[PSDMatrix]

@ValerianRey ValerianRey changed the title feat(autojac): Avoid concatenation for gramian based aggregators refactor(autojac): Avoid concat in jac_to_grad Jan 17, 2026
@ValerianRey ValerianRey deleted the avoid-concatenation-for-gramian-based-aggregators branch February 4, 2026 03:04


2 participants