Reduced memory and debugged reverse AD #22
Merged
TonyZhou729 merged 8 commits into main on May 4, 2026
Replace the four (Nlna, Nk) integrand materialisations in Cl_one_ell
with a single lax.scan over lna_axis carrying four (Nk,) running sums.
Under the outer vmap over lensing_ells_indices (99 entries on the
default ell grid), the (Nell, Nlna, Nk) 3D tensors XLA was
materialising — not fusing — across the vmap are gone.
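The scan pattern described above can be sketched on a toy integral. This is a minimal illustration, not the code in Cl_one_ell: `toy_radial` is a hypothetical stand-in for the Bessel-table lookups, the array sizes are made up, and it carries one running sum rather than four.

```python
import jax.numpy as jnp
from jax import lax


def toy_radial(chi):
    # Hypothetical stand-in for the Bessel-table interpolation.
    return jnp.cos(chi) / (1.0 + chi**2)


def integral_materialised(source, tau, k_axis, tau0, dlna):
    # Builds the full (Nlna, Nk) integrand before reducing over lna.
    chi = jnp.outer(tau0 - tau, k_axis)            # (Nlna, Nk)
    integrand = source * toy_radial(chi)           # (Nlna, Nk)
    return jnp.sum(integrand * dlna[:, None], axis=0)


def integral_scanned(source, tau, k_axis, tau0, dlna):
    # Walks the lna axis one row at a time, carrying a (Nk,) running sum.
    def step(acc, xs):
        src_row, tau_i, dlna_i = xs
        chi_l = (tau0 - tau_i) * k_axis            # (Nk,) per iteration
        return acc + src_row * toy_radial(chi_l) * dlna_i, None

    init = jnp.zeros(k_axis.shape, dtype=source.dtype)
    total, _ = lax.scan(step, init, (source, tau, dlna))
    return total


# Made-up grids, sized only for illustration.
n_lna, n_k = 64, 16
tau = jnp.linspace(0.1, 0.9, n_lna)
k_axis = jnp.geomspace(0.1, 10.0, n_k)
source = jnp.sin(jnp.outer(tau, k_axis))
dlna = jnp.full((n_lna,), 0.1)

dense = integral_materialised(source, tau, k_axis, 1.0, dlna)
scanned = integral_scanned(source, tau, k_axis, 1.0, dlna)
```

Both functions return a (Nk,) vector and should agree up to floating-point summation order. Under a vmap over ell, the scanned version batches only the (Nk,) carry and the per-step temporaries, not a full (Nlna, Nk) integrand per ell.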
Full-pipeline GPU peak on fiducial LCDM drops from 5.813 GiB to
0.454 GiB (-12.8x); SS.get_Cl contribution collapses from 5.39 GiB
to 35 MiB (~150x inside get_Cl). Wall-clock is unchanged (full warm
+1.1%, inside measurement noise). ClTT/TE/EE at probe ells {2, 30,
200, 1000, 2000} agree with baseline to max rel 1.46e-13 — ULP-level
drift from reordered summation in float64 at Nlna=499, far below the
CLASS accuracy-test tolerance.
Bessel-table invariants (xphi{0,1,2}_tab columns, bessel_l_tab[idx],
column min/max bounds) are hoisted out of the scan body and captured
by phi{0,1,2}_local closures. chi = jnp.outer(tau0-tau, k_axis) is
no longer materialised; chi_l = (tau0 - tau[i]) * k_axis is a (Nk,)
vector per scan iter. Accumulators use dtype=sourceT0.dtype rather
than zeros_like(k_axis) to avoid silent downcast when k_axis is
float32 from geomspace.
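The accumulator dtype point can be shown in isolation. This sketch assumes a float32 k grid and a float64 source table purely for illustration; the shapes and values are made up and are not those of the actual pipeline.

```python
import jax

jax.config.update("jax_enable_x64", True)  # allow float64 arrays

import jax.numpy as jnp
from jax import lax

# Assumption for illustration: float32 k grid, float64 source table.
k_axis = jnp.geomspace(1e-4, 1.0, 8, dtype=jnp.float32)
source = jnp.full((5, 8), 0.1, dtype=jnp.float64)

# zeros_like inherits float32 from k_axis.
acc_from_k = jnp.zeros_like(k_axis)
# Taking the dtype from the source keeps the running sum in float64.
acc_from_source = jnp.zeros(k_axis.shape, dtype=source.dtype)


def step(acc, row):
    # Carry and row share a dtype, so the carry type is stable across steps.
    return acc + row, None


total, _ = lax.scan(step, acc_from_source, source)
```

Taking the accumulator dtype from the source keeps the running sum in the same precision as the integrand, whatever dtype the k grid happens to have.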
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three major fixes:
1. diffrax dense output is a memory guzzler, so it is dropped in places where we weren't really taking advantage of it. Also refactored the spectral integrals so that, rather than making an (Nlna, Nk) tensor for each ell, we now just lax.scan over 1D (Nk,) accumulators. Peak memory is down from ~5 GB to a few hundred MB with both of these fixes, with no loss of speed or accuracy.
2. [...] specs dict and has to live as its own argument to Model. Adds checkpointing, since even with the memory reductions above it was still pulling >20 GB/gradient.
3. [...] jnp.where and reverse AD, where if one branch of jnp.where is infinite/NaN the whole calculation will give NaNs. Fixed here.
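The jnp.where issue in the last item is a general JAX pitfall and can be reproduced on a scalar function. The sketch below uses a toy log, not the code changed in this PR, and the "double where" workaround is the usual remedy rather than necessarily the exact fix applied here.

```python
import jax
import jax.numpy as jnp


def unsafe_log(x):
    # The untaken log branch is still differentiated: at x = 0 the backward
    # pass forms 0 / 0, so the gradient is NaN although the value is fine.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def safe_log(x):
    # "Double where": give the risky branch a harmless input first,
    # then select the output as before.
    positive = x > 0.0
    x_safe = jnp.where(positive, x, 1.0)
    return jnp.where(positive, jnp.log(x_safe), 0.0)


grad_unsafe = jax.grad(unsafe_log)(0.0)   # NaN
grad_safe = jax.grad(safe_log)(0.0)       # 0.0
grad_regular = jax.grad(safe_log)(2.0)    # 0.5, same as the unsafe version
```

The forward values of the two functions are identical everywhere. Only the backward pass differs, which is why this kind of bug tends to surface when reverse AD is first run rather than in forward-only tests.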