
Reduced memory and debugged reverse AD #22

Merged
TonyZhou729 merged 8 commits into main from mem_quick
May 4, 2026

@cgiovanetti (Collaborator) commented May 3, 2026

Three major fixes:

  • Reduced memory usage dramatically. diffrax dense output turns out to be a memory guzzler, so it is dropped in places where we weren't really taking advantage of it. Also refactored the spectral integrals: rather than building an (Nlna, Nk) tensor for each ell, we now lax.scan over 1D (Nk,) accumulators. With both fixes, peak memory goes from ~5 GB to a few hundred MB, with no loss of speed or accuracy.
  • Added a reverse AD option (adjoint option at initialization). Unfortunately it cannot go in the specs dict and has to live as its own argument to Model. It also adds checkpointing, since even with the memory reductions above reverse AD was still pulling >20 GB per gradient.
  • Fixed a bug preventing reverse AD. There's a known JAX gotcha with nested jnp.where and reverse AD: if one branch of a jnp.where evaluates to inf/NaN, the gradient can come out NaN even when that branch is not the one selected. Fixed here.
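
For readers unfamiliar with the jnp.where gotcha in the last bullet, here is a minimal sketch using a toy log function. It is not the code changed in this PR. The function names are hypothetical, and the "double where" pattern shown is the standard workaround, which may or may not match the exact fix applied here.

```python
import jax
import jax.numpy as jnp


def unsafe_log_or_zero(x):
    # The forward value is fine at x = 0, but reverse AD still differentiates
    # the unselected log branch: a zero cotangent divided by x = 0 gives NaN.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


def safe_log_or_zero(x):
    # "Double where": first replace the bad input with a harmless value (1.0),
    # so the unselected branch has a finite value and a finite derivative.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)


grad_unsafe = jax.grad(unsafe_log_or_zero)(0.0)  # NaN
grad_safe = jax.grad(safe_log_or_zero)(0.0)      # finite (zero)
grad_regular = jax.grad(safe_log_or_zero)(2.0)   # usual d/dx log(x) = 1/x
```

The inner jnp.where changes nothing in the forward pass, since the outer jnp.where discards the masked branch anyway. It only keeps the backward pass from ever multiplying zero by an infinite derivative.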

Cara Giovanetti and others added 8 commits April 21, 2026 12:01
Replace the four (Nlna, Nk) integrand materialisations in Cl_one_ell
with a single lax.scan over lna_axis carrying four (Nk,) running sums.
Under the outer vmap over lensing_ells_indices (99 entries on the
default ell grid), the (Nell, Nlna, Nk) 3D tensors XLA was
materialising — not fusing — across the vmap are gone.
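
The idea in the paragraph above can be sketched with a toy integrand. This is a simplified illustration, not Cl_one_ell itself: integrand_row, the weights, and the grid sizes are hypothetical stand-ins, and only one running sum is carried instead of four.

```python
import jax
import jax.numpy as jnp
from jax import lax


def integrand_row(lna, k_axis):
    # Hypothetical stand-in for a source-times-Bessel product at one lna sample.
    a = jnp.exp(lna)
    return a * jnp.sin(k_axis * a)  # shape (Nk,)


def integrate_dense(lna_axis, weights, k_axis):
    # Materialises the full (Nlna, Nk) integrand before reducing over lna.
    full = jax.vmap(integrand_row, in_axes=(0, None))(lna_axis, k_axis)
    return jnp.sum(weights[:, None] * full, axis=0)


def integrate_scan(lna_axis, weights, k_axis):
    # Carries a single (Nk,) running sum; only one row exists at a time.
    def step(acc, inputs):
        lna, w = inputs
        return acc + w * integrand_row(lna, k_axis), None

    init = jnp.zeros(k_axis.shape, dtype=k_axis.dtype)
    acc, _ = lax.scan(step, init, (lna_axis, weights))
    return acc


lna_axis = jnp.linspace(-5.0, 0.0, 499)
dlna = lna_axis[1] - lna_axis[0]
# Trapezoid-rule weights on a uniform grid (an assumption for this sketch).
weights = jnp.full_like(lna_axis, dlna)
weights = weights.at[0].set(0.5 * dlna).at[-1].set(0.5 * dlna)
k_axis = jnp.geomspace(1e-3, 1.0, 64)

dense = integrate_dense(lna_axis, weights, k_axis)
scanned = integrate_scan(lna_axis, weights, k_axis)
```

Both functions compute the same quadrature. The scan version trades a single large reduction for a sequential loop whose live state is only the (Nk,) accumulator, which is what keeps an outer vmap over ell from multiplying the footprint by Nlna.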

Full-pipeline GPU peak on fiducial LCDM drops from 5.813 GiB to
0.454 GiB (-12.8x); SS.get_Cl contribution collapses from 5.39 GiB
to 35 MiB (~150x inside get_Cl). Wall-clock is unchanged (full warm
+1.1%, inside measurement noise). ClTT/TE/EE at probe ells {2, 30,
200, 1000, 2000} agree with baseline to max rel 1.46e-13 — ULP-level
drift from reordered summation in float64 at Nlna=499, far below the
CLASS accuracy-test tolerance.
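
As a rough illustration of why reordered summation only produces ULP-level drift, the sketch below sums 499 synthetic float64 rows two ways and compares them with the same max-relative-difference metric. The data are random placeholders, not the PR's spectra, so the printed number will not match the 1.46e-13 quoted above.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64, as in the comparison above

import jax.numpy as jnp
from jax import lax

key = jax.random.PRNGKey(0)
# Synthetic positive terms: 499 samples along the summed axis, 16 columns.
terms = jax.random.uniform(
    key, (499, 16), dtype=jnp.float64, minval=0.5, maxval=1.5
)

# Baseline: one reduction over the summed axis.
baseline = jnp.sum(terms, axis=0)

# Reordered: strictly sequential accumulation, one row per scan step.
init = jnp.zeros(16, dtype=terms.dtype)
sequential, _ = lax.scan(lambda acc, row: (acc + row, None), init, terms)

max_rel = jnp.max(jnp.abs(sequential - baseline) / jnp.abs(baseline))
print(float(max_rel))
```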

Bessel-table invariants (xphi{0,1,2}_tab columns, bessel_l_tab[idx],
column min/max bounds) are hoisted out of the scan body and captured
by phi{0,1,2}_local closures. chi = jnp.outer(tau0-tau, k_axis) is
no longer materialised; chi_l = (tau0 - tau[i]) * k_axis is a (Nk,)
vector per scan iter. Accumulators use dtype=sourceT0.dtype rather
than zeros_like(k_axis) to avoid silent downcast when k_axis is
float32 from geomspace.
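
A small sketch of the two details in the paragraph above, under assumed toy inputs: the per-iteration chi_l row equals the matching row of the outer product, and an accumulator built from the source dtype stays float64 even when k_axis is float32. The arrays here (tau, source, grid sizes) are hypothetical stand-ins, and k_axis is forced to float32 on purpose.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import lax

k_axis = jnp.geomspace(1e-4, 1.0, 32, dtype=jnp.float32)  # float32 on purpose
tau = jnp.linspace(100.0, 14000.0, 50)                     # float64 grid
tau0 = 14000.0
source = jnp.ones((50, 32), dtype=jnp.float64)             # stand-in for sourceT0

# Two ways to build the (Nk,) accumulator.
init_from_k = jnp.zeros_like(k_axis)                            # inherits float32
init_from_source = jnp.zeros(k_axis.shape, dtype=source.dtype)  # float64


def step(acc, inputs):
    tau_i, src_i = inputs
    chi_l = (tau0 - tau_i) * k_axis  # (Nk,) row, never the full (Nlna, Nk) chi
    return acc + src_i * jnp.cos(chi_l), None


acc, _ = lax.scan(step, init_from_source, (tau, source))

# The per-iteration row matches the corresponding row of the outer product.
chi_full = jnp.outer(tau0 - tau, k_axis)
row_matches = jnp.allclose(chi_full[3], (tau0 - tau[3]) * k_axis)
```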

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@cgiovanetti cgiovanetti requested a review from TonyZhou729 May 3, 2026 18:09
@TonyZhou729 TonyZhou729 merged commit 4b0a9f0 into main May 4, 2026
1 check passed
@cgiovanetti cgiovanetti deleted the mem_quick branch May 4, 2026 19:59