
refactor(autojac): Avoid concat in Jac #518

Merged
ValerianRey merged 2 commits into main from optimize-jac-transform-memory
Jan 21, 2026

Conversation

@ValerianRey (Contributor) commented Jan 16, 2026

When running the profiler, I had an architecture for which this line gave an out-of-memory error on CUDA. I also realized that it was trying to concatenate a list of 1 tensor, so it seems there is a huge waste of memory when the number of chunks is 1. FYI, a chunk is a subset of the rows of the Jacobian. Whenever possible, we compute all rows at the same time (parallel_chunk_size=None, or large enough to compute all rows at once), so we have exactly 1 chunk, which is why it's critical to optimize for this case.

On AlexNet with a batch size of 8, this saves 1.12 GiB of CUDA memory at this point of the backward pass, at least. Maybe this is not actually peak memory usage, so maybe it doesn't make a difference in the end, but I think it's still worth fixing.
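The fix can be sketched as follows (hypothetical `combine_chunks` helper, not the actual code in `src/torchjd/autojac/_transform/_jac.py`):

```python
import torch

def combine_chunks(chunks: list[torch.Tensor]) -> torch.Tensor:
    # Sketch of the optimization: when there is a single chunk, return it
    # directly instead of calling torch.cat, which would allocate a full
    # copy of the (potentially huge) Jacobian.
    if len(chunks) == 1:
        return chunks[0]
    return torch.cat(chunks)

# With parallel_chunk_size=None there is exactly one chunk: no copy is made.
jacobian = torch.randn(8, 1_000)
assert combine_chunks([jacobian]) is jacobian
```

With several chunks the behavior is unchanged; only the single-chunk case skips the allocation.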

@ValerianRey ValerianRey added package: autojac cc: refactor Conventional commit type for any refactoring, not user-facing, and not typing or perf improvements labels Jan 16, 2026
@ValerianRey ValerianRey self-assigned this Jan 16, 2026
@codecov codecov bot commented Jan 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/autojac/_transform/_jac.py 100.00% <100.00%> (ø)

@claude claude bot commented Jan 16, 2026

Code review

No issues found. Checked for bugs and CLAUDE.md compliance.

@PierreQuinton (Contributor) left a comment

So weird that cat of a tuple containing one element doesn't return this element. This is not specified in its docs, so I think cat always allocates memory, which answers some questions we were wondering about yesterday.

Good catch anyway, thanks a lot, LGTM!

@ValerianRey (Contributor, Author)

> So weird that cat of a tuple containing one element doesn't return this element. This is not specified in its docs, so I think cat always allocates memory, which answers some questions we were wondering about yesterday.
>
> Good catch anyway, thanks a lot, LGTM!

I think it's because they want their functions to always provide the same guarantees, in the sense that a user of concatenate may be used to modifying one of the tensors given to concatenate, and expects that not to change the tensor that was output by concatenate. If they had a special case for when a single element is given to concatenate, this behavior would differ. This is the case for concatenate, but also for many other functions, I think.
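This behavior is easy to check directly: even with a single input tensor, torch.cat returns a new tensor backed by its own storage, so writes to the output never show through to the input.

```python
import torch

t = torch.zeros(3, 4)
out = torch.cat([t])  # "concatenating" a single tensor

assert out is not t                    # a new tensor is returned...
assert out.data_ptr() != t.data_ptr()  # ...backed by freshly allocated memory

out[0, 0] = 1.0                        # mutating the output
assert t[0, 0].item() == 0.0           # ...does not affect the input
```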

@ValerianRey ValerianRey changed the title refactor: Avoid cat in Jac when not needed refactor(autojac): Avoid concat in Jac Jan 17, 2026
@ValerianRey (Contributor, Author)

Here are the performance comparisons, using the profiler:

  • AlexNet - BS=4: 70-90 ms gain, out of a 470 ms call to Jac.
  • Cifar10Model - BS=64: 3-4 ms gain, out of an 1800 ms call to Jac.
  • GroupNormMobileNetV3Small - BS=8: 9-10 ms gain, out of a 675 ms call to Jac.
  • InstanceNormMobileNetV2 - BS=2: 4 ms gain, out of a 275 ms call to Jac.
  • InstanceNormResNet18 - BS=4: 19 ms gain, out of a 700 ms call to Jac.
  • SqueezeNet - BS=4: 2-3 ms gain, out of a 375 ms call to Jac.
  • WithTransformerLarge - BS=4: 74 ms gain, out of a 415 ms call to Jac.

=> Reduction of the time taken by Jac by 0.2% to 18% on those models, on CPU.

Note that the time taken by Jac is generally ~99% of the time of backward.

@PierreQuinton (Contributor) commented Jan 20, 2026

Excellent news; the peak memory saving is almost certain. I would like to test the alternative for gramian computations: reshape the generalized matrix to a matrix, then compute the gramian.

If this is efficient, we could avoid using sum and use addmm in a for loop (compute the gramian of the first Jacobian, then use addmm on the rest). This would yield a minor memory saving (number of tensors × gramian size), and addmm is probably faster than mm followed by add.

EDIT: Wrong PR. I think what misled me is the amount of saving you got on some models for such a tiny change; this looks very promising for the few optimizations we had in mind.
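The addmm idea could be sketched like this (hypothetical `gramian_sum` helper, not part of this PR):

```python
import torch

def gramian_sum(jacobians: list[torch.Tensor]) -> torch.Tensor:
    # Sketch: sum of the gramians J_i @ J_i^T, accumulated with addmm so that
    # only one gramian-sized buffer is live at a time, instead of
    # materializing every J_i @ J_i^T before summing.
    mats = [j.reshape(j.shape[0], -1) for j in jacobians]  # generalized matrix -> matrix
    gram = mats[0] @ mats[0].T            # gramian of the first Jacobian
    for m in mats[1:]:
        gram = torch.addmm(gram, m, m.T)  # fused gram + m @ m.T
    return gram
```

torch.addmm(input, mat1, mat2) computes input + mat1 @ mat2 in one kernel, which is the "faster than mm then add" point above.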

@ValerianRey (Contributor, Author)

> the peak memory saving is almost certain.

Well, the point at which this concat took place was definitely a local maximum in terms of memory usage but maybe not the global maximum of the whole backward call. So I'm not sure that we actually reduce peak memory with just this.

@ValerianRey (Contributor, Author)

I'm stuck with memory profiling, let's merge this.

@ValerianRey ValerianRey merged commit 4f1016e into main Jan 21, 2026
19 checks passed
@ValerianRey ValerianRey deleted the optimize-jac-transform-memory branch January 28, 2026 15:56