compute_fake_perturbation_tests memory grows ~15 GB per iteration -> OOM at iter ~16/50 on real datasets #3

@adamklie

Summary

Even with the args.reference_targets fix (filed separately, see #N), compute_fake_perturbation_tests exhausts memory on real-world datasets after ~16 of 50 iterations, hitting the SLURM-allocated 256 GB limit.

Reproduction

Run the U-test calibration on a Huangfu HUES8 cNMF h5mu (~270k cells × ~36k genes, sparse; ~14k guides, 600 non-targeting) with the standard parameters:

--number_run 50 --number_guide 6 \
--components 30 50 60 80 100 200 250 300 \
--sel_thresh 2.0 \
--compute_fake_perturbation_tests

Both runs (ESC=10555425, DE=10555432) crashed at iteration 16/50 of K=30 with exit code 9. SLURM MaxRSS:

ESC: 261,868,524K ≈ 250 GB
DE:  264,975,876K ≈ 252 GB

Root cause hypothesis

In compute_fake_perturbation_tests (lines 121–164):

for k in args.components:
    mdata = mu.read(...)                # ~10 GB sparse for our dataset
    _assign_guide(mdata, mdata_guide)
    for i in range(args.number_run):    # 50 iterations
        _mdata = mdata.copy()           # ← deep-copies the full mdata each iter
        _mdata[args.prog_key].obsm[args.guide_assignment_key] = mdata[...][:, non_targeting_idx]
        ...
        for samp in unique:
            mdata_samp = _mdata[mask]
            test_stats_df = compute_perturbation_association(mdata_samp, ...)
            ...

The per-iteration full mdata.copy() retains references that Python's GC can't free fast enough between iterations: roughly 15 GB of residual growth per iteration × 16 iterations ≈ 240 GB at the time of the OOM.
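
The growth is easy to confirm by logging RSS inside the loop. A minimal sketch, mirroring the excerpt above rather than the exact PerturbNMF code, and assuming psutil is installed:

import psutil

proc = psutil.Process()
for i in range(args.number_run):
    _mdata = mdata.copy()
    # ... fake-test body as in the excerpt above ...
    rss_gb = proc.memory_info().rss / 1e9
    print(f"iter {i}: rss = {rss_gb:.1f} GB", flush=True)  # climbs by roughly 15 GB per iteration in our runs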

Workaround we used

Dropped --number_run from 50 to 10. Peak memory is ~150 GB, which fits in the 256 GB allocation. The null distribution is less well resolved, but still meaningful for QC.

Possible fixes

  1. Avoid the deep copy. The fake test only mutates _mdata[prog_key].obsm[guide_assignment_key] and two uns arrays; it never touches the rna modality, which is what makes the copy expensive (~10 GB sparse). A surgical fix would mutate mdata[prog_key] in place each iteration and restore it at the end, skipping mdata.copy() entirely (sketched after this list).

  2. Force release between iterations: add del _mdata, mdata_samp and gc.collect() at the end of each iteration. See oom_remedy.diff for a small bandaid patch (roughly the shape sketched below). This doesn't fix the underlying redundant deep copy, but it should keep peak memory bounded.

  3. Process-level isolation: run each iteration in a subprocess so the OS reclaims its memory cleanly on exit (also sketched below). Heavier, but bullet-proof.
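
A rough sketch of what (1) could look like, written as a small helper rather than the exact PerturbNMF code. temporarily_assign is a hypothetical name, and the keys follow the excerpt above:

from contextlib import contextmanager

@contextmanager
def temporarily_assign(adata, obsm_key, fake_assignment):
    """Swap in the fake guide assignment and always restore the original on exit."""
    original = adata.obsm[obsm_key]
    adata.obsm[obsm_key] = fake_assignment
    try:
        yield adata
    finally:
        adata.obsm[obsm_key] = original

# per iteration, instead of _mdata = mdata.copy():
#     with temporarily_assign(mdata[args.prog_key], args.guide_assignment_key, fake_assignment):
#         ... run the per-sample U-tests against mdata directly ...

The two uns arrays the fake test writes would need the same snapshot-and-restore treatment.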

(1) is the right structural fix. (2) is a small bandaid that would let the current 50-iteration default work in 256 GB.
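
The bandaid in (2) amounts to roughly the following at the end of each iteration (a sketch, not the exact contents of oom_remedy.diff; variable names are the ones from the excerpt above):

import gc

for i in range(args.number_run):
    _mdata = mdata.copy()
    # ... fake-test body as in the excerpt above ...

    # drop the per-iteration copies explicitly and force a collection so the
    # ~10 GB copy is reclaimed before the next iteration allocates another one
    del _mdata, mdata_samp
    gc.collect()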

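And (3) would look roughly like this; a sketch only, run_fake_test_iteration is a hypothetical worker, and each worker has to re-read the ~10 GB h5mu, which is part of why this option is heavier:

from multiprocessing import get_context

def run_fake_test_iteration(i):
    # hypothetical worker: re-read the h5mu, build the fake assignment for
    # iteration i, run the per-sample U-tests, and return only the small stats
    ...

if __name__ == "__main__":
    ctx = get_context("spawn")
    # maxtasksperchild=1 gives a fresh process per iteration, so the OS reclaims
    # everything an iteration allocated as soon as its worker exits
    with ctx.Pool(processes=1, maxtasksperchild=1) as pool:
        null_stats = pool.map(run_fake_test_iteration, range(args.number_run))
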
Why this wasn't caught earlier

Same answer as the related args.reference_targets issue: PerturbNMF is a publishable rewrite of Stanford's older cNMF_benchmarking tool. Pre-existing fake-test outputs at the Engreitz lab come from the older tool, not this new code path. We seem to be the first to drive the new PerturbNMF U-test code end-to-end on a real-sized dataset, so this leak is surfacing now rather than during the rewrite.

Environment

  • PerturbNMF main @ 8f7c9dd (with the line 160 args.reference_targets fix applied locally on a branch)
  • Python 3.10, mudata 0.4.x
  • Carter HPC (UCSD), 256 GB SLURM allocation per job
  • Real dataset: Huangfu HUES8 endoderm differentiation, 8 K values × sel_thresh=2.0
