perf(aggregation): Prevent cuda sync in normalize #557

Merged
ValerianRey merged 2 commits into main from prevent-cuda-sync-normalize on Feb 5, 2026

@ValerianRey
Contributor

With the previous implementation, normalize causes a CUDA synchronization. This is not a big issue, because a CUDA synchronization is needed right afterwards anyway, during the call to project_weights, which always runs on CPU. Still, I think avoiding it is worthwhile for three reasons:

  • It keeps the number of CUDA synchronizations minimal, which is good for performance.
  • Avoiding synchronizations in unexpected places makes traces easier to analyze.
  • If we ever run UPGrad on CUDA, avoiding this sync could make a big difference, because aggregation might then need no CUDA synchronization at all.
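
To make the point about unexpected synchronizations concrete, here is a minimal sketch of one way to check whether a piece of code forces a CUDA synchronization, using PyTorch's sync debug mode. The helper name `forces_sync` is hypothetical and is not part of torchjd; it also treats any `RuntimeError` as a sync, which is a simplification.

```python
import torch


def forces_sync(fn):
    """Return True if fn() hits a synchronizing CUDA op, False if it does not.

    Returns None when no CUDA device is available. Hypothetical helper,
    not part of torchjd.
    """
    if not torch.cuda.is_available():
        return None
    previous = torch.cuda.get_sync_debug_mode()
    # In "error" mode, synchronizing CUDA operations raise instead of running.
    torch.cuda.set_sync_debug_mode("error")
    try:
        fn()
        return False
    except RuntimeError:
        # Simplification: any RuntimeError is counted as a sync here.
        return True
    finally:
        torch.cuda.set_sync_debug_mode(previous)


if torch.cuda.is_available():
    x = torch.ones(3, device="cuda")
    # Converting a CUDA tensor to a Python bool needs its value on the host.
    print(forces_sync(lambda: bool(x.sum() < 1e-5)))
    # torch.where keeps the comparison and the selection on the device.
    print(forces_sync(lambda: torch.where(x.sum() < 1e-5, torch.zeros_like(x), x)))
```

A Python `if` on a CUDA tensor behaves like the `bool(...)` call above, which is why a branch on the norm inside normalize shows up as a synchronization.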

There might be a slight performance drop from using torch.where, which is element-wise, with a scalar condition (which is therefore broadcast). But since the gramian is never huge (especially when using UPGrad), this is really fine IMO. In my profiling, this torch.where call takes 0.028 ms with a batch size of 64.

So this is extremely minor but positive IMO.
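
The pattern described above can be sketched as follows. This is an illustration under stated assumptions rather than the code in src/torchjd/_linalg/_gramian.py: the function names, the use of the Frobenius norm, and the `eps` threshold semantics are all assumed here.

```python
import torch


def normalize_with_branch(gramian, eps):
    # Assumed shape of the earlier approach: the Python `if` needs the value
    # of `norm < eps` on the host, which forces a sync when on CUDA.
    norm = gramian.norm()
    if norm < eps:
        return torch.zeros_like(gramian)
    return gramian / norm


def normalize_with_where(gramian, eps):
    # Sync-free variant: the 0-dim boolean condition stays on the device and
    # is broadcast against the whole matrix by torch.where.
    norm = gramian.norm()
    return torch.where(norm < eps, torch.zeros_like(gramian), gramian / norm)


gramian = torch.tensor([[3.0, 0.0], [0.0, 4.0]])
print(normalize_with_where(gramian, 1e-5))
print(normalize_with_where(torch.zeros(2, 2), 1e-5))
```

Both branches of torch.where are always computed, so the division also runs when the norm is below `eps`; its result is simply discarded in favor of the zeros. That extra element-wise work is the small cost mentioned above.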

@ValerianRey added the "package: aggregation" and "cc: perf" (Conventional commit type for changes mostly focused on performance improvements (memory or speed)) labels on Feb 4, 2026
@ValerianRey ValerianRey self-assigned this Feb 4, 2026
@ValerianRey ValerianRey force-pushed the prevent-cuda-sync-normalize branch from 7fb2c75 to fcd865a Compare February 4, 2026 14:39
@codecov

codecov bot commented Feb 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/torchjd/_linalg/_gramian.py 100.00% <100.00%> (ø)

Contributor

@PierreQuinton PierreQuinton left a comment

Very nice. We should also keep this in mind, because I doubt this is the only place where we do something like that (maybe in the aggregators).

@ValerianRey ValerianRey merged commit f30a835 into main Feb 5, 2026
15 checks passed
@ValerianRey ValerianRey deleted the prevent-cuda-sync-normalize branch February 5, 2026 14:30
Labels

cc: perf Conventional commit type for changes mostly focused on performance improvements (memory or speed). package: aggregation

2 participants