refactor(autojac): Avoid concat in Jac #518
Conversation
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Code review: No issues found. Checked for bugs and CLAUDE.md compliance.
PierreQuinton
left a comment
So weird that cat of a tuple containing one element doesn't return that element. This is not specified in its doc, so I think cat always allocates memory, which answers some questions we were wondering about yesterday.
Good catch anyway, thanks a lot, LGTM!
I think it's because they want their functions to always have the same level of safety: a user of concatenate may be used to modifying one of the tensors given to concatenate, and expects that not to change the tensor that was output by concatenate. If they had a special case for when a single element is given to concatenate, this behavior would differ. This is the case for concatenate but also for many other functions, I think.
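The copy-even-for-one-tensor behavior described above is easy to check directly. A minimal sketch (standalone, not from the PR's code):

```python
import torch

t = torch.ones(3)
out = torch.cat([t])  # cat of a single-element list

# cat allocates a fresh output tensor rather than returning `t` itself,
# so mutating the input afterwards does not affect the output.
print(out is t)  # False
t[0] = 42.0
print(out[0].item())  # still 1.0
```

This is consistent with the aliasing argument above: callers can always mutate inputs to `cat` without affecting its output, regardless of how many tensors were concatenated.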
Here are the performance comparisons, using the profiler:
=> Reduction of the time taken by Jac by 0.2% to 18% on those models, on CPU. Note that the time taken by Jac is generally ~99% of the time of backward.
Excellent news, the peak memory saving is almost certain. I would like to test the alternative for gramian computations: reshape the generalized matrix to a matrix, then compute the gramian. If this is efficient, we could avoid using
EDIT: Wrong PR. I think what misled me is the amount of saving you got on some models for such a tiny change; this looks very promising for the few optimizations we had in mind.
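The reshape-then-gramian idea mentioned above can be sketched as follows. The shapes and names here are illustrative assumptions, not the library's actual API:

```python
import torch

# Hypothetical "generalized matrix": a Jacobian with extra trailing
# dimensions, e.g. (num_outputs, *param_shape).
J = torch.randn(5, 3, 4)

# Flatten all trailing dimensions so that each row is one output's gradient.
M = J.reshape(J.shape[0], -1)  # shape (5, 12)

# Gramian of the Jacobian: G[i, j] = <grad_i, grad_j>.
G = M @ M.T  # shape (5, 5), symmetric positive semi-definite
```

Compared with materializing intermediate concatenations, this only allocates the flattened view's buffer (or no buffer at all when the reshape is a view) plus the small square gramian.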
Well, the point at which this concat took place was definitely a local maximum in terms of memory usage, but maybe not the global maximum of the whole backward call. So I'm not sure that we actually reduce peak memory with just this.
I'm stuck with memory profiling, let's merge this.
When running the profiler, I had an architecture for which this line gave an out-of-memory error on CUDA. I also realized that it was trying to concatenate a list of 1 tensor, so it seems that there's some huge waste of memory when the number of chunks is 1. FYI, a chunk is a subset of the rows of the Jacobian. Whenever it's possible, we compute all rows at the same time (`parallel_chunk_size=None` or large enough to compute all rows at once), so we have exactly 1 chunk, which is why it's critical to optimize for this case. On AlexNet with a batch size of 8, this saves 1.12 GiB of CUDA memory at this point of the backward pass, at least. Maybe this is not actually peak memory usage, so maybe it doesn't make a difference in the end, but I think it's still good to solve.
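The fix described above amounts to bypassing `torch.cat` in the single-chunk case. A minimal sketch of that pattern, with a hypothetical helper name (`assemble_jacobian` is not the PR's actual function):

```python
import torch

def assemble_jacobian(chunks: list[torch.Tensor]) -> torch.Tensor:
    """Stack Jacobian row-chunks into one matrix.

    When there is exactly one chunk, return it directly: torch.cat would
    otherwise allocate a full copy, transiently doubling memory usage at
    this point of the backward pass.
    """
    if len(chunks) == 1:
        return chunks[0]
    return torch.cat(chunks)
```

With `parallel_chunk_size=None` (all rows computed at once) the list always has length 1, so the early return avoids the copy entirely; the multi-chunk path is unchanged.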