Skip to content

Optimise Davidson Hv: precompute connected pairs, vmap element kernel, batched MGS#8

Draft
Copilot wants to merge 3 commits into
mainfrom
copilot/optimize-hamiltonian-vector-product
Draft

Optimise Davidson Hv: precompute connected pairs, vmap element kernel, batched MGS#8
Copilot wants to merge 3 commits into
mainfrom
copilot/optimize-hamiltonian-vector-product

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 3, 2026

Profiling showed hamiltonian_vector_product (three Python loops over connected pairs, called ~14×/diagonalisation) dominated runtime at 72% of total wall time. This PR eliminates those loops by precomputing the pair structure once and replacing the inner loop with a single JIT+vmap scatter-add.

Changes

determinants.py — JAX-native bitwise helpers + pair precomputation

  • jax_popcount, excitation_level_jax, first_set_bit_pos_jax, two_set_bit_pos_jax, phase_single_jax, phase_double_jax: pure-JAX implementations of the scalar kernels using only bitwise ops and jnp.where — no Python if/else, fully vmappable.
  • precompute_connections(dets_alpha, dets_beta, norb)(row_idx, col_idx): runs the existing sort/generator logic once per diagonalisation and returns all three connection types (β singles/doubles, opposite-spin doubles, α singles/doubles) as flat JAX int32 index arrays with no duplicates.

hamiltonian.py — vmappable element kernel + JIT-compiled matvec

  • _single_excitation_element_spin_jax, _double_same_spin_element_jax, _double_opposite_spin_element_jax: branch-free sub-kernels safe for jax.vmap.
  • hamiltonian_element_batch: top-level vmappable element function; selects the correct sub-kernel via nested jnp.where (all five excitation types evaluated; wrong branches produce finite values masked out, no NaN propagation).
  • precompute_h_vals(dets_alpha, dets_beta, row_idx, col_idx, norb, h_core, eri): JIT-compiled, computes all H_ij in one jax.vmap call — called once before Davidson starts.
  • scatter_add_matvec(coeffs, diag_h, h_vals, row_idx, col_idx): JIT-compiled matvec using scatter-add, zero Python loops:
    v_off = jnp.zeros(ndet).at[row_idx].add(h_vals * coeffs[col_idx])
    v_off = v_off.at[col_idx].add(h_vals * coeffs[row_idx])
    return diag_h * coeffs + v_off

cipsi.py — wire up fast path in _diagonalise_variational_space

Replaced the hamiltonian_vector_product closure with:

row_idx, col_idx = precompute_connections(dets_alpha, dets_beta, norb)
h_vals = precompute_h_vals(dets_alpha, dets_beta, row_idx, col_idx, norb, h_core, eri)
diag.davidson(lambda v: scatter_add_matvec(v, h_diag, h_vals, row_idx, col_idx))

h_vals is computed once; all ~14 Davidson matvec calls reuse it.

diagonaliser.py — batched MGS projection

Replaced the for k in range(m) Gram-Schmidt loop (O(m) Python iterations growing with subspace size) with a single batched projection, valid because V_m maintains orthonormal columns throughout:

# before
for k in range(m):
    vk = Vmat[:, k:k+1]
    new_vecs = new_vecs - vk * (vk.T @ new_vecs)

# after
new_vecs = corrections - V_m @ (V_m.T @ corrections)

@jphaupt
Copy link
Copy Markdown
Owner

jphaupt commented Apr 8, 2026

@copilot If you're going to use acronyms, please always define it at least once (when first using it). For example, I am not sure what MGS is in this context.

Copy link
Copy Markdown
Contributor Author

Copilot AI commented Apr 8, 2026

@copilot If you're going to use acronyms, please always define it at least once (when first using it). For example, I am not sure what MGS is in this context.

Done — "MGS" is now written out as "modified Gram-Schmidt (MGS)" on its first use in diagonaliser.py (commit cec1750). The PR title keeps the short form since it already appears after the definition.

Copilot AI requested a review from jphaupt April 8, 2026 12:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants