From b2826a4b8f4a474a96b7f7edf8a85ad608d6a4d6 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Fri, 20 Mar 2026 18:33:34 +0100 Subject: [PATCH 1/2] first draft --- docs/api/pertpy_gpu.md | 29 + docs/release-notes/0.15.0.md | 1 + src/rapids_singlecell/pertpy_gpu/__init__.py | 1 + .../pertpy_gpu/_guide_assignment.py | 506 ++++++++++++++++++ tests/pertpy/test_guide_assignment.py | 314 +++++++++++ 5 files changed, 851 insertions(+) create mode 100644 src/rapids_singlecell/pertpy_gpu/_guide_assignment.py create mode 100644 tests/pertpy/test_guide_assignment.py diff --git a/docs/api/pertpy_gpu.md b/docs/api/pertpy_gpu.md index bff75c06..7cc24552 100644 --- a/docs/api/pertpy_gpu.md +++ b/docs/api/pertpy_gpu.md @@ -38,3 +38,32 @@ .. automethod:: bootstrap :no-index: ``` + +## GuideAssignment + +```{eval-rst} +.. autosummary:: + :toctree: generated + + GuideAssignment +``` + +```{eval-rst} +.. autoclass:: GuideAssignment + :no-index: + + .. rubric:: Methods + + .. autosummary:: + + ~GuideAssignment.assign_by_threshold + ~GuideAssignment.assign_to_max_guide + ~GuideAssignment.assign_mixture_model + + .. automethod:: assign_by_threshold + :no-index: + .. automethod:: assign_to_max_guide + :no-index: + .. automethod:: assign_mixture_model + :no-index: +``` diff --git a/docs/release-notes/0.15.0.md b/docs/release-notes/0.15.0.md index f7c323e8..47dcfaa0 100644 --- a/docs/release-notes/0.15.0.md +++ b/docs/release-notes/0.15.0.md @@ -5,6 +5,7 @@ * Allow multiple control groups in ``onesided_distances`` for computing energy distances against several references in a single kernel launch {pr}`601` {smaller}`S Dicks` * Add ``contrast_distances`` to ``EDistanceMetric`` for computing energy distances directly from a contrasts DataFrame {pr}`603` {smaller}`S Dicks` * Improve L2 cache efficiency in ``edistance`` and ``co_occurrence`` kernels by always tiling the smaller group into shared memory, yielding up to 5x speedup for datasets with unequal group sizes {pr}`607` {smaller}`S Dicks` +* Add ``GuideAssignment`` to ``ptg`` for GPU-accelerated guide RNA assignment with batched Poisson–Gaussian EM, replacing pertpy's sequential per-guide MCMC {smaller}`S Dicks` ```{rubric} Bug fixes ``` diff --git a/src/rapids_singlecell/pertpy_gpu/__init__.py b/src/rapids_singlecell/pertpy_gpu/__init__.py index e19db417..5551eb59 100644 --- a/src/rapids_singlecell/pertpy_gpu/__init__.py +++ b/src/rapids_singlecell/pertpy_gpu/__init__.py @@ -1,3 +1,4 @@ from __future__ import annotations from ._distance import Distance, MeanVar +from ._guide_assignment import GuideAssignment diff --git a/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py b/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py new file mode 100644 index 00000000..4eb6fa51 --- /dev/null +++ b/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py @@ -0,0 +1,506 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import cupy as cp +import numpy as np +import pandas as pd +from cupyx.scipy.sparse import issparse as cp_issparse +from cupyx.scipy.special import gammaln +from scanpy.get import _get_obs_rep, _set_obs_rep + +from rapids_singlecell.get import X_to_GPU + +if TYPE_CHECKING: + from anndata import AnnData + +_LOG_2PI = float(np.log(2.0 * np.pi)) + + +class GuideAssignment: + """GPU-accelerated guide RNA assignment. + + Provides threshold-based and mixture-model-based methods for assigning + cells to guide RNAs, compatible with pertpy's ``GuideAssignment`` API. + The mixture model uses a batched EM algorithm on GPU instead of + per-guide MCMC, yielding orders-of-magnitude speedup. + """ + + def assign_by_threshold( + self, + adata: AnnData, + *, + assignment_threshold: float, + layer: str | None = None, + output_layer: str = "assigned_guides", + ) -> None: + """Assign cells to gRNAs exceeding a count threshold. + + Each cell is assigned to every gRNA with at least + ``assignment_threshold`` counts. Expects unnormalized count data. + + Parameters + ---------- + adata + Annotated data matrix of shape ``n_obs x n_vars``. + assignment_threshold + Minimum count for a viable assignment. + layer + Layer with raw counts. Uses ``adata.X`` if ``None``. + output_layer + Key under which the binary assignment matrix is stored + in ``adata.layers``. + """ + X = X_to_GPU(_get_obs_rep(adata, layer=layer)) + + if cp_issparse(X): + from cupyx.scipy.sparse import csr_matrix as gpu_csr + + new_data = cp.where( + X.data >= assignment_threshold, + X.dtype.type(1), + X.dtype.type(0), + ) + result = gpu_csr( + (new_data, X.indices.copy(), X.indptr.copy()), shape=X.shape + ) + else: + result = cp.where(X >= assignment_threshold, cp.int8(1), cp.int8(0)) + + _set_obs_rep(adata, result, layer=output_layer) + + def assign_to_max_guide( + self, + adata: AnnData, + *, + assignment_threshold: float, + layer: str | None = None, + obs_key: str = "assigned_guide", + no_grna_assigned_key: str = "Negative", + ) -> None: + """Assign each cell to its most expressed gRNA. + + Each cell is assigned to the gRNA with the highest count, provided + that count is at least ``assignment_threshold``. Expects + unnormalized count data. + + Parameters + ---------- + adata + Annotated data matrix of shape ``n_obs x n_vars``. + assignment_threshold + Minimum count for a viable assignment. + layer + Layer with raw counts. Uses ``adata.X`` if ``None``. + obs_key + Column in ``adata.obs`` where the assignment is stored. + no_grna_assigned_key + Label for cells with no guide above threshold. + """ + X = X_to_GPU(_get_obs_rep(adata, layer=layer)) + var_names = np.asarray(adata.var_names) + + if cp_issparse(X): + X_dense = X.toarray() + else: + X_dense = X + + max_vals = X_dense.max(axis=1) + max_idx = X_dense.argmax(axis=1) + + max_vals_cpu = cp.asnumpy(max_vals).ravel() + max_idx_cpu = cp.asnumpy(max_idx).ravel() + + assigned = np.full(adata.n_obs, no_grna_assigned_key, dtype=object) + above = max_vals_cpu >= assignment_threshold + assigned[above] = var_names[max_idx_cpu[above]] + + adata.obs[obs_key] = assigned + + def assign_mixture_model( + self, + adata: AnnData, + *, + assigned_guides_key: str = "assigned_guide", + no_grna_assigned_key: str = "negative", + max_assignments_per_cell: int = 5, + multiple_grna_assigned_key: str = "multiple", + multiple_grna_assignment_string: str = "+", + only_return_results: bool = False, + max_iter: int = 50, + tol: float = 1e-4, + ) -> np.ndarray | None: + """Assign gRNAs using a GPU-accelerated Poisson–Gaussian mixture model. + + Fits a two-component mixture (Poisson background + Gaussian signal) + to the log₂-transformed non-zero counts of each guide simultaneously + using batched Expectation-Maximization on GPU. This replaces pertpy's + sequential per-guide MCMC with a single vectorised EM, giving + orders-of-magnitude speedup. + + Parameters + ---------- + adata + Annotated data matrix with guide RNA counts. + assigned_guides_key + Key in ``adata.obs`` for storing the assignment result. + no_grna_assigned_key + Label for cells negative for all gRNAs. + max_assignments_per_cell + Maximum number of gRNAs a cell can be assigned to. + multiple_grna_assigned_key + Label for cells exceeding ``max_assignments_per_cell``. + multiple_grna_assignment_string + Delimiter for joining multiple guide names. + only_return_results + If ``True``, return assignments without modifying ``adata``. + max_iter + Maximum number of EM iterations. + tol + Convergence tolerance on parameter changes. + + Returns + ------- + If ``only_return_results`` is ``True``, returns an array of + assignments. Otherwise modifies ``adata`` in-place and returns + ``None``. + """ + X = X_to_GPU(adata.X) + if cp_issparse(X): + X = X.toarray() + X = X.astype(cp.float32) + + _, n_guides = X.shape + var_names = np.asarray(adata.var_names) + + # Build padded arrays for batched EM + data_pad, mask, cell_indices, counts, valid_guides = _prepare_batched_data( + X, n_guides + ) + + if len(valid_guides) == 0: + warnings.warn( + "No guides have enough expressing cells for mixture model fitting.", + UserWarning, + stacklevel=2, + ) + series = pd.Series( + no_grna_assigned_key, + index=adata.obs_names, + ) + if only_return_results: + return series.values + adata.obs[assigned_guides_key] = series.values + return None + + # Run batched EM + lam, mu, sigma, pi0, assignments = _batched_em( + data_pad, mask, counts, max_iter=max_iter, tol=tol + ) + + # Store fitted parameters in adata.var + lam_cpu = cp.asnumpy(lam.ravel()) + mu_cpu = cp.asnumpy(mu.ravel()) + sigma_cpu = cp.asnumpy(sigma.ravel()) + pi0_cpu = cp.asnumpy(pi0.ravel()) + + for col in [ + "poisson_rate", + "gaussian_mean", + "gaussian_std", + "mix_probs_0", + "mix_probs_1", + ]: + if col not in adata.var.columns: + adata.var[col] = np.nan + + for i, g in enumerate(valid_guides): + adata.var.iloc[g, adata.var.columns.get_loc("poisson_rate")] = lam_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("gaussian_mean")] = mu_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("gaussian_std")] = sigma_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("mix_probs_0")] = pi0_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("mix_probs_1")] = ( + 1.0 - pi0_cpu[i] + ) + + # Map assignments back to (n_cells, n_guides) result + assignments_cpu = cp.asnumpy(assignments) # (n_valid_guides, max_nnz) + mask_cpu = cp.asnumpy(mask) + cell_indices_cpu = cp.asnumpy(cell_indices) + + result = pd.DataFrame(0, index=adata.obs_names, columns=var_names) + for i, g in enumerate(valid_guides): + valid = mask_cpu[i] + cells = cell_indices_cpu[i, valid] + assigned = assignments_cpu[i, valid] + result.iloc[cells, g] = assigned + + # Build final assignment series + series = pd.Series(no_grna_assigned_key, index=adata.obs_names) + num_assigned = result.sum(axis=1) + multi_mask = (num_assigned > 0) & (num_assigned <= max_assignments_per_cell) + series.loc[multi_mask] = result.loc[multi_mask].apply( + lambda row: multiple_grna_assignment_string.join( + row.index[row == 1].tolist() + ), + axis=1, + ) + series.loc[num_assigned > max_assignments_per_cell] = multiple_grna_assigned_key + + if only_return_results: + return series.values + + adata.obs[assigned_guides_key] = series.values + return None + + +def _prepare_batched_data( + X: cp.ndarray, + n_guides: int, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray, list[int]]: + """Build padded arrays for batched EM across guides. + + Returns + ------- + data_pad + ``(n_valid_guides, max_nnz)`` log₂ of non-zero counts, zero-padded. + mask + ``(n_valid_guides, max_nnz)`` boolean validity mask. + cell_indices + ``(n_valid_guides, max_nnz)`` original cell indices. + counts + ``(n_valid_guides,)`` number of non-zero cells per guide. + valid_guides + Column indices of guides with >= 2 non-zero cells. + """ + valid_guides = [] + nnz_per_guide = [] + nz_data = [] + nz_indices = [] + + for g in range(n_guides): + col = X[:, g] + nz_mask = col > 0 + nz_count = int(nz_mask.sum()) + if nz_count < 2: + if nz_count > 0: + warnings.warn( + f"Skipping guide index {g} as there are less than 2 cells " + "expressing the guide.", + UserWarning, + stacklevel=4, + ) + continue + valid_guides.append(g) + nnz_per_guide.append(nz_count) + nz_data.append(cp.log2(col[nz_mask])) + nz_indices.append(cp.where(nz_mask)[0]) + + if len(valid_guides) == 0: + empty = cp.empty((0, 0), dtype=cp.float32) + return empty, empty, empty.astype(cp.int64), cp.empty(0, dtype=cp.int32), [] + + max_nnz = max(nnz_per_guide) + n_valid = len(valid_guides) + + data_pad = cp.zeros((n_valid, max_nnz), dtype=cp.float32) + mask = cp.zeros((n_valid, max_nnz), dtype=cp.bool_) + cell_indices = cp.zeros((n_valid, max_nnz), dtype=cp.int64) + counts = cp.array(nnz_per_guide, dtype=cp.int32) + + for i in range(n_valid): + n = nnz_per_guide[i] + data_pad[i, :n] = nz_data[i].astype(cp.float32) + mask[i, :n] = True + cell_indices[i, :n] = nz_indices[i] + + return data_pad, mask, cell_indices, counts, valid_guides + + +def _batched_em( + data: cp.ndarray, + mask: cp.ndarray, + counts: cp.ndarray, + *, + max_iter: int = 50, + tol: float = 1e-4, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]: + """Run batched Poisson–Gaussian EM across all guides simultaneously. + + Parameters + ---------- + data + ``(n_guides, max_nnz)`` padded log₂ counts. + mask + ``(n_guides, max_nnz)`` validity mask. + counts + ``(n_guides,)`` valid cell count per guide. + max_iter + Maximum EM iterations. + tol + Convergence tolerance. + + Returns + ------- + lam, mu, sigma, pi0 + Fitted parameters, each ``(n_guides, 1)``. + assignments + ``(n_guides, max_nnz)`` int8 array (1 = positive, 0 = negative). + """ + n_valid = mask.sum(axis=1, keepdims=True).astype(cp.float32) # (n_guides, 1) + + # --- Initialization via percentiles --- + lam, mu, sigma, pi0 = _initialize_params(data, mask, counts) + + # Prior hyperparameters (matching crispat / Braunger et al.) + # μ ~ Normal(3, 2) + prior_mu_mean = cp.float32(3.0) + prior_mu_var = cp.float32(4.0) # σ₀² = 2² = 4 + # λ ~ LogNormal(0, 1) — mode ≈ 0.37, mean ≈ 1.65 + prior_lam_mu = cp.float32(0.0) + prior_lam_sigma2 = cp.float32(1.0) + # scale ~ LogNormal(2, 1) — mode ≈ 2.72, mean ≈ 12.2 + prior_scale_mu = cp.float32(2.0) + prior_scale_sigma2 = cp.float32(1.0) + # π ~ Dirichlet(0.9, 0.1) + prior_alpha0 = cp.float32(0.9) + prior_alpha1 = cp.float32(0.1) + + for _ in range(max_iter): + # E-step (no separation penalty — crispat doesn't use one) + r0, r1 = _e_step(data, mask, lam=lam, mu=mu, sigma=sigma, pi0=pi0) + + # MAP M-step with prior regularization + n0 = r0.sum(axis=1, keepdims=True) + n1 = r1.sum(axis=1, keepdims=True) + + # λ MAP with LogNormal(μ₀, σ₀²) prior: + # MLE: λ = Σ(r₀·x) / Σr₀ + # LogNormal prior adds -(log(λ)-μ₀)²/(2σ₀²) - log(λ) to log-posterior + # We approximate via iterating: use MLE then shrink toward prior mode + lam_mle = (r0 * data).sum(axis=1, keepdims=True) / cp.maximum(n0, 1e-10) + log_lam_mle = cp.log(cp.maximum(lam_mle, 1e-10)) + # Bayesian shrinkage: weighted average in log-space + log_lam_new = (n0 * log_lam_mle + prior_lam_mu / prior_lam_sigma2) / ( + n0 + 1.0 / prior_lam_sigma2 + ) + lam_new = cp.exp(log_lam_new) + + # μ MAP: Normal(μ₀, σ₀²) prior + sigma_sq = sigma * sigma + mu_new = ( + (r1 * data).sum(axis=1, keepdims=True) / cp.maximum(sigma_sq, 1e-10) + + prior_mu_mean / prior_mu_var + ) / (n1 / cp.maximum(sigma_sq, 1e-10) + 1.0 / prior_mu_var) + + # σ MAP with LogNormal(μ₀, σ₀²) prior: + # Same approach as λ — MLE then shrink in log-space + diff = data - mu_new + sigma_sq_mle = (r1 * diff * diff).sum(axis=1, keepdims=True) / cp.maximum( + n1, 1e-10 + ) + sigma_mle = cp.maximum(cp.sqrt(sigma_sq_mle), 1e-2) + log_sig_mle = cp.log(sigma_mle) + log_sig_new = (n1 * log_sig_mle + prior_scale_mu / prior_scale_sigma2) / ( + n1 + 1.0 / prior_scale_sigma2 + ) + sigma_new = cp.maximum(cp.exp(log_sig_new), 1e-2) + + # π MAP: Dirichlet(α₀, α₁) prior + denom = n_valid + prior_alpha0 + prior_alpha1 - 2.0 + pi0_new = (n0 + prior_alpha0 - 1.0) / cp.maximum(denom, 1e-10) + pi0_new = cp.clip(pi0_new, 0.01, 0.99) + + # Convergence check + max_change = max( + float(cp.abs(lam_new - lam).max()), + float(cp.abs(mu_new - mu).max()), + float(cp.abs(sigma_new - sigma).max()), + float(cp.abs(pi0_new - pi0).max()), + ) + + lam, mu, sigma, pi0 = lam_new, mu_new, sigma_new, pi0_new + + if max_change < tol: + break + + # Final assignment: cell is positive if P(Gaussian) > 0.5 + r0, r1 = _e_step(data, mask, lam=lam, mu=mu, sigma=sigma, pi0=pi0) + assignments = (r1 > 0.5).astype(cp.int8) + + return lam, mu, sigma, pi0, assignments + + +def _initialize_params( + data: cp.ndarray, + mask: cp.ndarray, + counts: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]: + """Initialize EM parameters using data-driven estimates. + + Uses the data distribution per guide to set starting points that + help the EM converge to pertpy-compatible solutions. + """ + n_guides = data.shape[0] + n_valid_f = mask.sum(axis=1, keepdims=True).astype(cp.float32) + + # Compute per-guide statistics on valid entries + masked_data = data * mask + mean_vals = masked_data.sum(axis=1, keepdims=True) / cp.maximum(n_valid_f, 1.0) + + # λ: start small — Exponential(0.2) prior has mode at 0 + # Use the overall mean as a rough guide but keep it modest + lam = cp.minimum(mean_vals * 0.5, cp.float32(0.5)) + lam = cp.maximum(lam, cp.float32(0.01)) + + # μ: Normal(3, 2) prior mean, but shift toward data if there's signal + # Use mean of top quartile as signal estimate + data_for_sort = cp.where(mask, data, -cp.inf) + sorted_data = cp.sort(data_for_sort, axis=1) + counts_f = counts.astype(cp.float32) + idx_75 = cp.clip((counts_f * 0.75).astype(cp.int64), 0, data.shape[1] - 1) + p75 = sorted_data[cp.arange(n_guides), idx_75].reshape(-1, 1) + # If p75 is 0 (most data at 0), fall back to prior mean + mu = cp.where(p75 > 0.5, p75, cp.float32(3.0)) + + # σ: HalfNormal(1) prior — start at 1.0 + sigma = cp.full((n_guides, 1), 1.0, dtype=cp.float32) + + # π₀: Dirichlet(0.85, 0.15) — start at prior + pi0 = cp.full((n_guides, 1), 0.85, dtype=cp.float32) + + return lam, mu, sigma, pi0 + + +def _e_step( + data: cp.ndarray, + mask: cp.ndarray, + *, + lam: cp.ndarray, + mu: cp.ndarray, + sigma: cp.ndarray, + pi0: cp.ndarray, +) -> tuple[cp.ndarray, cp.ndarray]: + """Compute responsibilities for Poisson and Gaussian components.""" + # Poisson log-PMF: x*log(λ) - λ - gammaln(x+1) + log_lam = cp.log(cp.maximum(lam, 1e-10)) + log_p0 = data * log_lam - lam - gammaln(data + 1.0) + cp.log(cp.maximum(pi0, 1e-10)) + + # Gaussian log-PDF: -0.5*((x-μ)/σ)^2 - log(σ) - 0.5*log(2π) + z = (data - mu) / cp.maximum(sigma, 1e-10) + log_sigma = cp.log(cp.maximum(sigma, 1e-10)) + log_p1 = ( + -0.5 * z * z - log_sigma - 0.5 * _LOG_2PI + cp.log(cp.maximum(1.0 - pi0, 1e-10)) + ) + + # Numerically stable softmax + log_total = cp.logaddexp(log_p0, log_p1) + r0 = cp.exp(log_p0 - log_total) + r1 = 1.0 - r0 + + # Zero out padding + r0 = r0 * mask + r1 = r1 * mask + + return r0, r1 diff --git a/tests/pertpy/test_guide_assignment.py b/tests/pertpy/test_guide_assignment.py new file mode 100644 index 00000000..06f26414 --- /dev/null +++ b/tests/pertpy/test_guide_assignment.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +import cupy as cp +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from cupyx.scipy.sparse import csr_matrix as gpu_csr + +import rapids_singlecell as rsc + + +@pytest.fixture +def guide_adata() -> AnnData: + """Synthetic guide RNA dataset with clear bimodal signal. + + 200 cells, 8 guides. For each guide, ~30 % of cells have high counts + (signal, drawn from Poisson(lambda=50)) and the rest have low counts + (background, drawn from Poisson(lambda=2)). + """ + rng = np.random.default_rng(42) + n_cells = 200 + n_guides = 8 + n_signal = 60 # 30% signal + + X = np.zeros((n_cells, n_guides), dtype=np.float32) + for g in range(n_guides): + bg = rng.poisson(lam=2, size=n_cells - n_signal).astype(np.float32) + sig = rng.poisson(lam=50, size=n_signal).astype(np.float32) + X[:, g] = np.concatenate([bg, sig]) + rng.shuffle(X[:, g]) + + var = pd.DataFrame(index=[f"guide_{i}" for i in range(n_guides)]) + obs = pd.DataFrame(index=[f"cell_{i}" for i in range(n_cells)]) + return AnnData(X=cp.array(X), obs=obs, var=var) + + +@pytest.fixture +def guide_adata_sparse(guide_adata: AnnData) -> AnnData: + """Same data as guide_adata but stored as CuPy CSR sparse.""" + adata = guide_adata.copy() + adata.X = gpu_csr(adata.X) + return adata + + +# ------------------------------------------------------------------ # +# assign_by_threshold +# ------------------------------------------------------------------ # + + +def test_assign_by_threshold(guide_adata: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + ga.assign_by_threshold(guide_adata, assignment_threshold=5) + + result = guide_adata.layers["assigned_guides"] + X = guide_adata.X + if hasattr(X, "get"): + X = X.get() + if hasattr(result, "toarray"): + result = result.toarray() + if hasattr(result, "get"): + result = result.get() + + expected = (X >= 5).astype(np.int8) + np.testing.assert_array_equal(result, expected) + + +def test_assign_by_threshold_sparse(guide_adata_sparse: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + ga.assign_by_threshold(guide_adata_sparse, assignment_threshold=5) + + result = guide_adata_sparse.layers["assigned_guides"] + if hasattr(result, "toarray"): + result = result.toarray() + if hasattr(result, "get"): + result = result.get() + + X = guide_adata_sparse.X + if hasattr(X, "toarray"): + X = X.toarray() + if hasattr(X, "get"): + X = X.get() + + expected = (X >= 5).astype(np.int8) + np.testing.assert_array_equal(result, expected) + + +# ------------------------------------------------------------------ # +# assign_to_max_guide +# ------------------------------------------------------------------ # + + +def test_assign_to_max_guide(guide_adata: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + ga.assign_to_max_guide(guide_adata, assignment_threshold=5) + + assigned = guide_adata.obs["assigned_guide"] + X = guide_adata.X + if hasattr(X, "get"): + X = X.get() + + for i in range(guide_adata.n_obs): + row = X[i] + max_val = row.max() + if max_val >= 5: + expected_guide = guide_adata.var_names[int(row.argmax())] + assert assigned.iloc[i] == expected_guide, ( + f"Cell {i}: expected {expected_guide}, got {assigned.iloc[i]}" + ) + else: + assert assigned.iloc[i] == "Negative", ( + f"Cell {i}: expected Negative, got {assigned.iloc[i]}" + ) + + +def test_assign_to_max_guide_sparse(guide_adata_sparse: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + ga.assign_to_max_guide(guide_adata_sparse, assignment_threshold=5) + + assigned = guide_adata_sparse.obs["assigned_guide"] + X = guide_adata_sparse.X.toarray() + if hasattr(X, "get"): + X = X.get() + + for i in range(guide_adata_sparse.n_obs): + row = X[i] + max_val = row.max() + if max_val >= 5: + expected_guide = guide_adata_sparse.var_names[int(row.argmax())] + assert assigned.iloc[i] == expected_guide + else: + assert assigned.iloc[i] == "Negative" + + +def test_assign_to_max_guide_below_threshold() -> None: + """All counts below threshold → all Negative.""" + rng = np.random.default_rng(0) + X = rng.poisson(lam=1, size=(50, 4)).astype(np.float32) + adata = AnnData( + X=cp.array(X), + var=pd.DataFrame(index=[f"g{i}" for i in range(4)]), + obs=pd.DataFrame(index=[f"c{i}" for i in range(50)]), + ) + ga = rsc.ptg.GuideAssignment() + ga.assign_to_max_guide(adata, assignment_threshold=100) + assert (adata.obs["assigned_guide"] == "Negative").all() + + +# ------------------------------------------------------------------ # +# assign_mixture_model +# ------------------------------------------------------------------ # + + +def test_mixture_model_separation(guide_adata: AnnData) -> None: + """EM should separate the clearly bimodal signal.""" + ga = rsc.ptg.GuideAssignment() + ga.assign_mixture_model(guide_adata, max_iter=100) + + # Check that at least some cells are assigned (not all negative) + assigned = guide_adata.obs["assigned_guide"] + n_assigned = (assigned != "negative").sum() + assert n_assigned > 0, "No cells were assigned to any guide" + + # With such clear separation (Poisson(2) vs Poisson(50)), + # most high-count cells should be positive + X = guide_adata.X + if hasattr(X, "get"): + X = X.get() + + # For each guide, cells with count >= 20 should mostly be assigned + for g in range(X.shape[1]): + high_count_cells = X[:, g] >= 20 + if high_count_cells.sum() == 0: + continue + # At least 80% of high-count cells should have this guide in assignment + guide_name = guide_adata.var_names[g] + has_guide = assigned.str.contains(guide_name, na=False) + overlap = (high_count_cells & has_guide.values).sum() + assert overlap / high_count_cells.sum() >= 0.8, ( + f"Guide {guide_name}: only {overlap}/{high_count_cells.sum()} " + "high-count cells assigned" + ) + + +def test_mixture_model_stores_params(guide_adata: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + ga.assign_mixture_model(guide_adata) + + for col in [ + "poisson_rate", + "gaussian_mean", + "gaussian_std", + "mix_probs_0", + "mix_probs_1", + ]: + assert col in guide_adata.var.columns, f"Missing column: {col}" + + # Parameters should be finite for all guides + for col in ["poisson_rate", "gaussian_mean", "gaussian_std"]: + vals = guide_adata.var[col].dropna() + assert len(vals) > 0, f"No values for {col}" + assert np.all(np.isfinite(vals)), f"Non-finite values in {col}" + + # Poisson rate should be < Gaussian mean (anti-flip) + rates = guide_adata.var["poisson_rate"].dropna() + means = guide_adata.var["gaussian_mean"].dropna() + assert (rates < means).all(), "Poisson rate should be < Gaussian mean" + + +def test_mixture_model_sparse_input(guide_adata_sparse: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + ga.assign_mixture_model(guide_adata_sparse) + + assigned = guide_adata_sparse.obs["assigned_guide"] + n_assigned = (assigned != "negative").sum() + assert n_assigned > 0 + + +def test_mixture_model_only_return_results(guide_adata: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + result = ga.assign_mixture_model(guide_adata, only_return_results=True) + + assert result is not None + assert len(result) == guide_adata.n_obs + assert isinstance(result, np.ndarray) + + +def test_mixture_model_skip_low_count() -> None: + """Guides with < 2 expressing cells should be skipped with a warning.""" + X = np.zeros((50, 3), dtype=np.float32) + # guide 0: only 1 cell expressing + X[0, 0] = 10.0 + # guide 1: no cells expressing + # guide 2: good signal + X[:20, 2] = np.random.default_rng(0).poisson(lam=50, size=20).astype(np.float32) + + adata = AnnData( + X=cp.array(X), + var=pd.DataFrame(index=["low", "empty", "good"]), + obs=pd.DataFrame(index=[f"c{i}" for i in range(50)]), + ) + + ga = rsc.ptg.GuideAssignment() + with pytest.warns(UserWarning, match="less than 2 cells"): + ga.assign_mixture_model(adata) + + # "good" guide should have some assignments + assigned = adata.obs["assigned_guide"] + assert assigned.str.contains("good", na=False).any() + + +def test_multiple_guide_assignment() -> None: + """Cells assigned to multiple guides get joined names.""" + rng = np.random.default_rng(99) + n_cells = 100 + n_guides = 3 + + # All cells get high counts for guide 0 and 1, low for guide 2 + X = np.zeros((n_cells, n_guides), dtype=np.float32) + X[:, 0] = rng.poisson(lam=50, size=n_cells).astype(np.float32) + X[:, 1] = rng.poisson(lam=50, size=n_cells).astype(np.float32) + # guide 2: mix of low and high + X[:30, 2] = rng.poisson(lam=50, size=30).astype(np.float32) + X[30:, 2] = rng.poisson(lam=2, size=70).astype(np.float32) + + # Add some background cells + bg_cells = rng.choice(n_cells, size=20, replace=False) + X[bg_cells, 0] = rng.poisson(lam=2, size=20).astype(np.float32) + X[bg_cells, 1] = rng.poisson(lam=2, size=20).astype(np.float32) + + adata = AnnData( + X=cp.array(X), + var=pd.DataFrame(index=["gA", "gB", "gC"]), + obs=pd.DataFrame(index=[f"c{i}" for i in range(n_cells)]), + ) + + ga = rsc.ptg.GuideAssignment() + ga.assign_mixture_model(adata) + + assigned = adata.obs["assigned_guide"] + # Some cells should have multi-guide assignments (containing "+") + has_multi = assigned.str.contains("+", na=False, regex=False) + assert has_multi.any(), "Expected some cells with multiple guide assignments" + + +def test_multiple_guide_max_cap() -> None: + """Cells exceeding max_assignments_per_cell get the multiple key.""" + rng = np.random.default_rng(7) + n_cells = 50 + n_guides = 6 + + # All cells get very high counts for all guides + X = rng.poisson(lam=100, size=(n_cells, n_guides)).astype(np.float32) + # Add some background cells for EM to separate + X[:10, :] = rng.poisson(lam=2, size=(10, n_guides)).astype(np.float32) + + adata = AnnData( + X=cp.array(X), + var=pd.DataFrame(index=[f"g{i}" for i in range(n_guides)]), + obs=pd.DataFrame(index=[f"c{i}" for i in range(n_cells)]), + ) + + ga = rsc.ptg.GuideAssignment() + ga.assign_mixture_model( + adata, + max_assignments_per_cell=2, + multiple_grna_assigned_key="too_many", + ) + + assigned = adata.obs["assigned_guide"] + # Some cells should be capped + assert (assigned == "too_many").any(), ( + "Expected some cells capped at max assignments" + ) From 50b20cdad9db077e62b6967d2c915fd12ce87df0 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 27 Apr 2026 18:21:10 +0200 Subject: [PATCH 2/2] add cuda path --- CMakeLists.txt | 1 + src/rapids_singlecell/_cuda/__init__.py | 1 + .../guide_assignment/guide_assignment.cu | 91 +++++ .../kernels_guide_assignment.cuh | 351 ++++++++++++++++++ .../pertpy_gpu/_guide_assignment.py | 273 ++++++++++++-- tests/pertpy/test_guide_assignment.py | 103 +++++ 6 files changed, 788 insertions(+), 32 deletions(-) create mode 100644 src/rapids_singlecell/_cuda/guide_assignment/guide_assignment.cu create mode 100644 src/rapids_singlecell/_cuda/guide_assignment/kernels_guide_assignment.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index cacf9849..af31c5c3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,6 +82,7 @@ if (RSC_BUILD_EXTENSIONS) add_nb_cuda_module(_ligrec_cuda src/rapids_singlecell/_cuda/ligrec/ligrec.cu) add_nb_cuda_module(_pv_cuda src/rapids_singlecell/_cuda/pv/pv.cu) add_nb_cuda_module(_edistance_cuda src/rapids_singlecell/_cuda/edistance/edistance.cu) + add_nb_cuda_module(_guide_assignment_cuda src/rapids_singlecell/_cuda/guide_assignment/guide_assignment.cu) add_nb_cuda_module(_hvg_cuda src/rapids_singlecell/_cuda/hvg/hvg.cu) add_nb_cuda_module(_kde_cuda src/rapids_singlecell/_cuda/kde/kde.cu) add_nb_cuda_module(_wilcoxon_cuda src/rapids_singlecell/_cuda/wilcoxon/wilcoxon.cu) diff --git a/src/rapids_singlecell/_cuda/__init__.py b/src/rapids_singlecell/_cuda/__init__.py index 35e82a0d..fd9672e2 100644 --- a/src/rapids_singlecell/_cuda/__init__.py +++ b/src/rapids_singlecell/_cuda/__init__.py @@ -20,6 +20,7 @@ "_bbknn_cuda", "_cooc_cuda", "_edistance_cuda", + "_guide_assignment_cuda", "_harmony_clustering_cuda", "_harmony_colsum_cuda", "_harmony_correction_batched_cuda", diff --git a/src/rapids_singlecell/_cuda/guide_assignment/guide_assignment.cu b/src/rapids_singlecell/_cuda/guide_assignment/guide_assignment.cu new file mode 100644 index 00000000..836b9000 --- /dev/null +++ b/src/rapids_singlecell/_cuda/guide_assignment/guide_assignment.cu @@ -0,0 +1,91 @@ +#include + +#include "../nb_types.h" + +#include "kernels_guide_assignment.cuh" + +using namespace nb::literals; + +static inline void launch_assign_threshold_dense( + const float* X, const int* valid_guides, const float* lam, const float* mu, + const float* sigma, const float* pi0, bool* assignments, float* thresholds, + int n_cells, int n_guides, int n_valid_guides, float posterior_threshold, + cudaStream_t stream) { + if (n_valid_guides == 0) return; + + dim3 block(BLOCK_SIZE); + dim3 grid(n_valid_guides); + assign_threshold_dense_kernel<<>>( + X, valid_guides, lam, mu, sigma, pi0, assignments, thresholds, n_cells, + n_guides, posterior_threshold); + CUDA_CHECK_LAST_ERROR(assign_threshold_dense_kernel); +} + +static inline void launch_fit_assign_dense( + const float* X, bool* assignments, float* thresholds, float* lam, float* mu, + float* sigma, float* pi0, bool* valid_mask, int* nonzero_counts, + int* max_counts, int n_cells, int n_guides, int max_iter, float tol, + float posterior_threshold, cudaStream_t stream) { + if (n_guides == 0) return; + + dim3 block(BLOCK_SIZE); + dim3 grid(n_guides); + fit_assign_dense_kernel<<>>( + X, assignments, thresholds, lam, mu, sigma, pi0, valid_mask, + nonzero_counts, max_counts, n_cells, n_guides, max_iter, tol, + posterior_threshold); + CUDA_CHECK_LAST_ERROR(fit_assign_dense_kernel); +} + +template +void register_bindings(nb::module_& m) { + m.def( + "assign_threshold_dense", + [](gpu_array_c X, + gpu_array_c valid_guides, + gpu_array_c lam, + gpu_array_c mu, + gpu_array_c sigma, + gpu_array_c pi0, + gpu_array_c assignments, + gpu_array_c thresholds, int n_cells, int n_guides, + int n_valid_guides, float posterior_threshold, + std::uintptr_t stream) { + launch_assign_threshold_dense( + X.data(), valid_guides.data(), lam.data(), mu.data(), + sigma.data(), pi0.data(), assignments.data(), thresholds.data(), + n_cells, n_guides, n_valid_guides, posterior_threshold, + (cudaStream_t)stream); + }, + "X"_a, "valid_guides"_a, "lam"_a, "mu"_a, "sigma"_a, "pi0"_a, + "assignments"_a, "thresholds"_a, nb::kw_only(), "n_cells"_a, + "n_guides"_a, "n_valid_guides"_a, "posterior_threshold"_a, + "stream"_a = 0); + + m.def( + "fit_assign_dense", + [](gpu_array_c X, + gpu_array_c assignments, + gpu_array_c thresholds, + gpu_array_c lam, gpu_array_c mu, + gpu_array_c sigma, gpu_array_c pi0, + gpu_array_c valid_mask, + gpu_array_c nonzero_counts, + gpu_array_c max_counts, int n_cells, int n_guides, + int max_iter, float tol, float posterior_threshold, + std::uintptr_t stream) { + launch_fit_assign_dense( + X.data(), assignments.data(), thresholds.data(), lam.data(), + mu.data(), sigma.data(), pi0.data(), valid_mask.data(), + nonzero_counts.data(), max_counts.data(), n_cells, n_guides, + max_iter, tol, posterior_threshold, (cudaStream_t)stream); + }, + "X"_a, "assignments"_a, "thresholds"_a, "lam"_a, "mu"_a, "sigma"_a, + "pi0"_a, "valid_mask"_a, "nonzero_counts"_a, "max_counts"_a, + nb::kw_only(), "n_cells"_a, "n_guides"_a, "max_iter"_a, "tol"_a, + "posterior_threshold"_a, "stream"_a = 0); +} + +NB_MODULE(_guide_assignment_cuda, m) { + REGISTER_GPU_BINDINGS(register_bindings, m); +} diff --git a/src/rapids_singlecell/_cuda/guide_assignment/kernels_guide_assignment.cuh b/src/rapids_singlecell/_cuda/guide_assignment/kernels_guide_assignment.cuh new file mode 100644 index 00000000..b595c35b --- /dev/null +++ b/src/rapids_singlecell/_cuda/guide_assignment/kernels_guide_assignment.cuh @@ -0,0 +1,351 @@ +#pragma once + +#include + +#include + +constexpr int BLOCK_SIZE = 256; +constexpr int NUM_WARPS = BLOCK_SIZE / 32; +constexpr int HIST_BINS = 4096; + +template +__device__ inline T warp_reduce_sum(T val) { + for (int offset = warpSize / 2; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +__device__ inline int warp_reduce_max(int val) { + for (int offset = warpSize / 2; offset > 0; offset >>= 1) { + val = max(val, __shfl_down_sync(0xffffffff, val, offset)); + } + return val; +} + +template +__device__ inline T block_reduce_sum_thread0(T val, T* warp_sums) { + int lane = threadIdx.x & (warpSize - 1); + int warp = threadIdx.x >> 5; + int n_warps = (blockDim.x + warpSize - 1) / warpSize; + + val = warp_reduce_sum(val); + if (lane == 0) { + warp_sums[warp] = val; + } + __syncthreads(); + + val = threadIdx.x < n_warps ? warp_sums[threadIdx.x] : static_cast(0); + if (warp == 0) { + val = warp_reduce_sum(val); + } + return val; +} + +__device__ inline int block_reduce_max_thread0(int val, int* warp_maxes) { + int lane = threadIdx.x & (warpSize - 1); + int warp = threadIdx.x >> 5; + int n_warps = (blockDim.x + warpSize - 1) / warpSize; + + val = warp_reduce_max(val); + if (lane == 0) { + warp_maxes[warp] = val; + } + __syncthreads(); + + val = threadIdx.x < n_warps ? warp_maxes[threadIdx.x] : 0; + if (warp == 0) { + val = warp_reduce_max(val); + } + return val; +} + +__device__ inline float poisson_log_prob(float value, float lam) { + lam = fmaxf(lam, 1.0e-10f); + return value * logf(lam) - lam - lgammaf(value + 1.0f); +} + +__device__ inline float normal_log_prob(float value, float mu, float sigma) { + constexpr float log_2pi = 1.8378770664093453f; + sigma = fmaxf(sigma, 1.0e-10f); + float z = (value - mu) / sigma; + return -0.5f * z * z - logf(sigma) - 0.5f * log_2pi; +} + +__global__ void assign_threshold_dense_kernel( + const float* __restrict__ X, const int* __restrict__ valid_guides, + const float* __restrict__ lam, const float* __restrict__ mu, + const float* __restrict__ sigma, const float* __restrict__ pi0, + bool* __restrict__ assignments, float* __restrict__ thresholds, int n_cells, + int n_guides, float posterior_threshold) { + int valid_idx = blockIdx.x; + int guide = valid_guides[valid_idx]; + int tid = threadIdx.x; + + __shared__ int max_counts[NUM_WARPS]; + __shared__ float guide_threshold; + + int local_max = 0; + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + float count = X[cell * n_guides + guide]; + int count_int = static_cast(ceilf(count)); + local_max = max(local_max, count_int); + } + int guide_max_count = block_reduce_max_thread0(local_max, max_counts); + + if (tid == 0) { + float threshold = NAN; + float guide_lam = lam[valid_idx]; + float guide_mu = mu[valid_idx]; + float guide_sigma = sigma[valid_idx]; + float guide_pi0 = + fminf(fmaxf(pi0[valid_idx], 1.0e-10f), 1.0f - 1.0e-10f); + float guide_pi1 = 1.0f - guide_pi0; + + for (int raw_count = 1; raw_count <= guide_max_count; ++raw_count) { + float log_count = log2f(static_cast(raw_count)); + float log_p0 = + poisson_log_prob(log_count, guide_lam) + logf(guide_pi0); + float log_p1 = normal_log_prob(log_count, guide_mu, guide_sigma) + + logf(guide_pi1); + float posterior = 1.0f / (1.0f + expf(log_p0 - log_p1)); + if (posterior > posterior_threshold) { + threshold = static_cast(raw_count); + break; + } + } + + thresholds[valid_idx] = threshold; + guide_threshold = threshold; + } + __syncthreads(); + + bool has_threshold = !isnan(guide_threshold); + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + float count = X[cell * n_guides + guide]; + assignments[valid_idx * n_cells + cell] = + has_threshold && count >= guide_threshold; + } +} + +__global__ void fit_assign_dense_kernel( + const float* __restrict__ X, bool* __restrict__ assignments, + float* __restrict__ thresholds, float* __restrict__ lam_out, + float* __restrict__ mu_out, float* __restrict__ sigma_out, + float* __restrict__ pi0_out, bool* __restrict__ valid_mask, + int* __restrict__ nonzero_counts, int* __restrict__ max_counts, int n_cells, + int n_guides, int max_iter, float tol, float posterior_threshold) { + int guide = blockIdx.x; + int tid = threadIdx.x; + + __shared__ float sum_log_warp[NUM_WARPS]; + __shared__ int nz_counts_warp[NUM_WARPS]; + __shared__ int raw_max_counts_warp[NUM_WARPS]; + __shared__ float red_n0_warp[NUM_WARPS]; + __shared__ float red_n1_warp[NUM_WARPS]; + __shared__ float red_sum_r0_y_warp[NUM_WARPS]; + __shared__ float red_sum_r1_y_warp[NUM_WARPS]; + __shared__ float red_sum_r1_y2_warp[NUM_WARPS]; + __shared__ int count_hist[HIST_BINS + 1]; + __shared__ float guide_threshold; + __shared__ float guide_sum_log; + __shared__ int guide_nz; + __shared__ int guide_max_raw; + + float local_sum = 0.0f; + int local_nz = 0; + int local_max_raw = 0; + + // TODO: Template this kernel on C/F layout. One block scans one guide, so + // F-order input would make the cell loop contiguous instead of strided by + // n_guides. + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + float count = X[cell * n_guides + guide]; + if (count > 0.0f) { + float log_count = log2f(count); + local_sum += log_count; + ++local_nz; + local_max_raw = max(local_max_raw, static_cast(ceilf(count))); + } + } + + float reduced_sum_log = block_reduce_sum_thread0(local_sum, sum_log_warp); + int reduced_nz = block_reduce_sum_thread0(local_nz, nz_counts_warp); + int reduced_max_raw = + block_reduce_max_thread0(local_max_raw, raw_max_counts_warp); + + if (tid == 0) { + guide_sum_log = reduced_sum_log; + guide_nz = reduced_nz; + guide_max_raw = reduced_max_raw; + nonzero_counts[guide] = guide_nz; + max_counts[guide] = guide_max_raw; + valid_mask[guide] = guide_nz >= 2 && guide_max_raw >= 2; + } + __syncthreads(); + + if (guide_nz < 2 || guide_max_raw < 2) { + if (tid == 0) { + thresholds[guide] = NAN; + lam_out[guide] = NAN; + mu_out[guide] = NAN; + sigma_out[guide] = NAN; + pi0_out[guide] = NAN; + guide_threshold = NAN; + } + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + assignments[guide * n_cells + cell] = false; + } + return; + } + + float n_valid = static_cast(guide_nz); + float mean_log = guide_sum_log / fmaxf(n_valid, 1.0f); + float lam = fminf(fmaxf(mean_log * 0.5f, 0.01f), 0.5f); + for (int bin = tid; bin <= HIST_BINS; bin += blockDim.x) { + count_hist[bin] = 0; + } + __syncthreads(); + + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + float count = X[cell * n_guides + guide]; + if (count > 0.0f) { + int bin = min(max(static_cast(ceilf(count)), 1), HIST_BINS); + atomicAdd(&count_hist[bin], 1); + } + } + __syncthreads(); + + __shared__ float init_mu; + if (tid == 0) { + int target = static_cast(n_valid * 0.75f); + int cumulative = 0; + int percentile_bin = 1; + for (int bin = 1; bin <= HIST_BINS; ++bin) { + cumulative += count_hist[bin]; + if (cumulative > target) { + percentile_bin = bin; + break; + } + } + float p75 = log2f(static_cast(percentile_bin)); + init_mu = p75 > 0.5f ? p75 : 3.0f; + } + __syncthreads(); + + float mu = init_mu; + float sigma = 1.0f; + float pi0 = 0.85f; + + for (int iter = 0; iter < max_iter; ++iter) { + float local_n0 = 0.0f; + float local_n1 = 0.0f; + float local_sum_r0_y = 0.0f; + float local_sum_r1_y = 0.0f; + float local_sum_r1_y2 = 0.0f; + + float safe_pi0 = fminf(fmaxf(pi0, 1.0e-10f), 1.0f - 1.0e-10f); + float safe_pi1 = 1.0f - safe_pi0; + + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + float count = X[cell * n_guides + guide]; + if (count <= 0.0f) continue; + + float y = log2f(count); + float log_p0 = poisson_log_prob(y, lam) + logf(safe_pi0); + float log_p1 = normal_log_prob(y, mu, sigma) + logf(safe_pi1); + float r1 = 1.0f / (1.0f + expf(log_p0 - log_p1)); + float r0 = 1.0f - r1; + + local_n0 += r0; + local_n1 += r1; + local_sum_r0_y += r0 * y; + local_sum_r1_y += r1 * y; + local_sum_r1_y2 += r1 * y * y; + } + + float n0 = block_reduce_sum_thread0(local_n0, red_n0_warp); + float n1 = block_reduce_sum_thread0(local_n1, red_n1_warp); + float sum_r0_y = + block_reduce_sum_thread0(local_sum_r0_y, red_sum_r0_y_warp); + float sum_r1_y = + block_reduce_sum_thread0(local_sum_r1_y, red_sum_r1_y_warp); + float sum_r1_y2 = + block_reduce_sum_thread0(local_sum_r1_y2, red_sum_r1_y2_warp); + + __shared__ float next_lam; + __shared__ float next_mu; + __shared__ float next_sigma; + __shared__ float next_pi0; + __shared__ bool converged; + + if (tid == 0) { + float lam_mle = sum_r0_y / fmaxf(n0, 1.0e-10f); + float log_lam_mle = logf(fmaxf(lam_mle, 1.0e-10f)); + next_lam = expf((n0 * log_lam_mle) / (n0 + 1.0f)); + + float sigma_sq = sigma * sigma; + next_mu = (sum_r1_y / fmaxf(sigma_sq, 1.0e-10f) + 3.0f / 4.0f) / + (n1 / fmaxf(sigma_sq, 1.0e-10f) + 1.0f / 4.0f); + + float sigma_sq_mle = (sum_r1_y2 - 2.0f * next_mu * sum_r1_y + + next_mu * next_mu * n1) / + fmaxf(n1, 1.0e-10f); + sigma_sq_mle = fmaxf(sigma_sq_mle, 0.0f); + float sigma_mle = fmaxf(sqrtf(sigma_sq_mle), 1.0e-2f); + next_sigma = fmaxf( + expf((n1 * logf(sigma_mle) + 2.0f) / (n1 + 1.0f)), 1.0e-2f); + + float denom = fmaxf(n_valid - 1.0f, 1.0e-10f); + next_pi0 = (n0 - 0.1f) / denom; + next_pi0 = fminf(fmaxf(next_pi0, 0.01f), 0.99f); + + float max_change = + fmaxf(fmaxf(fabsf(next_lam - lam), fabsf(next_mu - mu)), + fmaxf(fabsf(next_sigma - sigma), fabsf(next_pi0 - pi0))); + converged = max_change < tol; + } + __syncthreads(); + + lam = next_lam; + mu = next_mu; + sigma = next_sigma; + pi0 = next_pi0; + bool done = converged; + __syncthreads(); + if (done) break; + } + + if (tid == 0) { + float threshold = NAN; + float safe_pi0 = fminf(fmaxf(pi0, 1.0e-10f), 1.0f - 1.0e-10f); + float safe_pi1 = 1.0f - safe_pi0; + + for (int raw_count = 1; raw_count <= guide_max_raw; ++raw_count) { + float log_count = log2f(static_cast(raw_count)); + float log_p0 = poisson_log_prob(log_count, lam) + logf(safe_pi0); + float log_p1 = + normal_log_prob(log_count, mu, sigma) + logf(safe_pi1); + float posterior = 1.0f / (1.0f + expf(log_p0 - log_p1)); + if (posterior > posterior_threshold) { + threshold = static_cast(raw_count); + break; + } + } + + lam_out[guide] = lam; + mu_out[guide] = mu; + sigma_out[guide] = sigma; + pi0_out[guide] = pi0; + thresholds[guide] = threshold; + guide_threshold = threshold; + } + __syncthreads(); + + bool has_threshold = !isnan(guide_threshold); + for (int cell = tid; cell < n_cells; cell += blockDim.x) { + float count = X[cell * n_guides + guide]; + assignments[guide * n_cells + cell] = + has_threshold && count >= guide_threshold; + } +} diff --git a/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py b/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py index 4eb6fa51..54ef7233 100644 --- a/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py +++ b/src/rapids_singlecell/pertpy_gpu/_guide_assignment.py @@ -23,8 +23,9 @@ class GuideAssignment: Provides threshold-based and mixture-model-based methods for assigning cells to guide RNAs, compatible with pertpy's ``GuideAssignment`` API. - The mixture model uses a batched EM algorithm on GPU instead of - per-guide MCMC, yielding orders-of-magnitude speedup. + The mixture model follows crispat's Poisson-Gaussian assignment rule + while using batched EM on GPU instead of per-guide Pyro SVI, yielding + orders-of-magnitude speedup. """ def assign_by_threshold( @@ -128,16 +129,21 @@ def assign_mixture_model( multiple_grna_assigned_key: str = "multiple", multiple_grna_assignment_string: str = "+", only_return_results: bool = False, - max_iter: int = 50, + max_iter: int = 90, tol: float = 1e-4, + posterior_threshold: float = 0.645, + backend: str = "cupy", ) -> np.ndarray | None: """Assign gRNAs using a GPU-accelerated Poisson–Gaussian mixture model. Fits a two-component mixture (Poisson background + Gaussian signal) to the log₂-transformed non-zero counts of each guide simultaneously - using batched Expectation-Maximization on GPU. This replaces pertpy's - sequential per-guide MCMC with a single vectorised EM, giving - orders-of-magnitude speedup. + using batched Expectation-Maximization on GPU. Like crispat's + Poisson-Gaussian assignment, the fitted model is converted to an + integer raw-count threshold. The default posterior cutoff is slightly + conservative to calibrate the GPU EM approximation against crispat's + Pyro SVI outputs; set ``posterior_threshold=0.5`` for the literal + crispat threshold rule. Parameters ---------- @@ -159,6 +165,14 @@ def assign_mixture_model( Maximum number of EM iterations. tol Convergence tolerance on parameter changes. + posterior_threshold + Minimum posterior probability of the Gaussian component required + for a raw UMI count to define the assignment threshold. + backend + Backend for fitting and assignment. ``"cupy"`` uses the existing + CuPy EM and threshold implementation, ``"cuda"`` uses the + nanobind/CUDA EM + threshold kernel, and ``"auto"`` tries CUDA + with a CuPy fallback. Returns ------- @@ -169,15 +183,51 @@ def assign_mixture_model( X = X_to_GPU(adata.X) if cp_issparse(X): X = X.toarray() - X = X.astype(cp.float32) + # TODO: The CUDA guide kernel scans one guide column per block. If this + # path becomes the default, consider densifying sparse inputs directly + # to F-order with _sparse_to_dense(order="F") and dispatching an + # F-contiguous kernel to improve memory coalescing. + X = cp.ascontiguousarray(X.astype(cp.float32, copy=False)) + if not 0 < posterior_threshold < 1: + raise ValueError("posterior_threshold must be between 0 and 1.") + if backend not in {"cupy", "cuda", "auto"}: + raise ValueError("backend must be one of 'cupy', 'cuda', or 'auto'.") _, n_guides = X.shape var_names = np.asarray(adata.var_names) - # Build padded arrays for batched EM - data_pad, mask, cell_indices, counts, valid_guides = _prepare_batched_data( - X, n_guides - ) + if backend in {"cuda", "auto"}: + try: + assignments, thresholds, lam, mu, sigma, pi0, valid_guides = ( + _fit_assign_cuda( + X, + max_iter=max_iter, + tol=tol, + posterior_threshold=posterior_threshold, + ) + ) + except ImportError: + if backend == "cuda": + raise + backend = "cupy" + + if backend == "cupy": + data_pad, mask, _cell_indices, counts, valid_guides = _prepare_batched_data( + X, n_guides + ) + if len(valid_guides) > 0: + lam, mu, sigma, pi0, _ = _batched_em( + data_pad, mask, counts, max_iter=max_iter, tol=tol + ) + assignments, thresholds = _assign_by_crispat_threshold( + X, + valid_guides, + lam=lam, + mu=mu, + sigma=sigma, + pi0=pi0, + posterior_threshold=posterior_threshold, + ) if len(valid_guides) == 0: warnings.warn( @@ -194,11 +244,6 @@ def assign_mixture_model( adata.obs[assigned_guides_key] = series.values return None - # Run batched EM - lam, mu, sigma, pi0, assignments = _batched_em( - data_pad, mask, counts, max_iter=max_iter, tol=tol - ) - # Store fitted parameters in adata.var lam_cpu = cp.asnumpy(lam.ravel()) mu_cpu = cp.asnumpy(mu.ravel()) @@ -211,10 +256,17 @@ def assign_mixture_model( "gaussian_std", "mix_probs_0", "mix_probs_1", + "threshold", + "weight_Poisson", + "weight_Normal", + "lambda", + "mu", + "scale", ]: if col not in adata.var.columns: adata.var[col] = np.nan + thresholds_cpu = cp.asnumpy(thresholds.ravel()) for i, g in enumerate(valid_guides): adata.var.iloc[g, adata.var.columns.get_loc("poisson_rate")] = lam_cpu[i] adata.var.iloc[g, adata.var.columns.get_loc("gaussian_mean")] = mu_cpu[i] @@ -223,29 +275,35 @@ def assign_mixture_model( adata.var.iloc[g, adata.var.columns.get_loc("mix_probs_1")] = ( 1.0 - pi0_cpu[i] ) + adata.var.iloc[g, adata.var.columns.get_loc("threshold")] = thresholds_cpu[ + i + ] + adata.var.iloc[g, adata.var.columns.get_loc("weight_Poisson")] = pi0_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("weight_Normal")] = ( + 1.0 - pi0_cpu[i] + ) + adata.var.iloc[g, adata.var.columns.get_loc("lambda")] = lam_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("mu")] = mu_cpu[i] + adata.var.iloc[g, adata.var.columns.get_loc("scale")] = sigma_cpu[i] # Map assignments back to (n_cells, n_guides) result - assignments_cpu = cp.asnumpy(assignments) # (n_valid_guides, max_nnz) - mask_cpu = cp.asnumpy(mask) - cell_indices_cpu = cp.asnumpy(cell_indices) + assignments_cpu = cp.asnumpy(assignments) # (n_valid_guides, n_cells) - result = pd.DataFrame(0, index=adata.obs_names, columns=var_names) + result = pd.DataFrame(data=False, index=adata.obs_names, columns=var_names) for i, g in enumerate(valid_guides): - valid = mask_cpu[i] - cells = cell_indices_cpu[i, valid] - assigned = assignments_cpu[i, valid] - result.iloc[cells, g] = assigned + result.iloc[:, g] = assignments_cpu[i] # Build final assignment series series = pd.Series(no_grna_assigned_key, index=adata.obs_names) num_assigned = result.sum(axis=1) multi_mask = (num_assigned > 0) & (num_assigned <= max_assignments_per_cell) - series.loc[multi_mask] = result.loc[multi_mask].apply( - lambda row: multiple_grna_assignment_string.join( - row.index[row == 1].tolist() - ), - axis=1, - ) + if multi_mask.any(): + series.loc[multi_mask] = result.loc[multi_mask].apply( + lambda row: multiple_grna_assignment_string.join( + row.index[row].tolist() + ), + axis=1, + ) series.loc[num_assigned > max_assignments_per_cell] = multiple_grna_assigned_key if only_return_results: @@ -292,6 +350,14 @@ def _prepare_batched_data( stacklevel=4, ) continue + max_count = float(col.max().item()) + if max_count < 2: + warnings.warn( + f"Skipping guide index {g} as the maximum UMI count is less than 2.", + UserWarning, + stacklevel=4, + ) + continue valid_guides.append(g) nnz_per_guide.append(nz_count) nz_data.append(cp.log2(col[nz_mask])) @@ -318,6 +384,149 @@ def _prepare_batched_data( return data_pad, mask, cell_indices, counts, valid_guides +def _assign_by_crispat_threshold( + X: cp.ndarray, + valid_guides: list[int], + *, + lam: cp.ndarray, + mu: cp.ndarray, + sigma: cp.ndarray, + pi0: cp.ndarray, + posterior_threshold: float, +) -> tuple[cp.ndarray, cp.ndarray]: + """Assign cells using crispat's posterior-derived raw UMI threshold.""" + assignments = cp.zeros((len(valid_guides), X.shape[0]), dtype=cp.bool_) + thresholds = cp.full((len(valid_guides), 1), cp.nan, dtype=cp.float32) + + for i, g in enumerate(valid_guides): + col = X[:, g] + max_count = int(cp.ceil(col.max()).item()) + if max_count < 2: + continue + + raw_counts = cp.arange(1, max_count + 1, dtype=cp.float32) + log_counts = cp.log2(raw_counts).reshape(1, -1) + threshold_mask = cp.ones_like(log_counts, dtype=cp.bool_) + _, prob_gaussian = _e_step( + log_counts, + threshold_mask, + lam=lam[i : i + 1], + mu=mu[i : i + 1], + sigma=sigma[i : i + 1], + pi0=pi0[i : i + 1], + ) + positive_counts = raw_counts[cp.ravel(prob_gaussian > posterior_threshold)] + if positive_counts.size == 0: + continue + + threshold = positive_counts[0] + thresholds[i, 0] = threshold + assignments[i] = col >= threshold + + return assignments, thresholds + + +def _fit_assign_cuda( + X: cp.ndarray, + *, + max_iter: int, + tol: float, + posterior_threshold: float, +) -> tuple[ + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + cp.ndarray, + list[int], +]: + """Fit and assign all guides with the nanobind/CUDA EM kernel.""" + from rapids_singlecell._cuda import _guide_assignment_cuda + + if _guide_assignment_cuda is None: + raise ImportError( + "The _guide_assignment_cuda extension is not available. " + "Build rapids-singlecell with CUDA extensions or use backend='cupy'." + ) + + n_cells, n_guides = X.shape + assignments_all = cp.empty((n_guides, n_cells), dtype=cp.bool_) + thresholds_all = cp.empty((n_guides, 1), dtype=cp.float32) + lam_all = cp.empty((n_guides, 1), dtype=cp.float32) + mu_all = cp.empty((n_guides, 1), dtype=cp.float32) + sigma_all = cp.empty((n_guides, 1), dtype=cp.float32) + pi0_all = cp.empty((n_guides, 1), dtype=cp.float32) + valid_mask = cp.empty(n_guides, dtype=cp.bool_) + nonzero_counts = cp.empty(n_guides, dtype=cp.int32) + max_counts = cp.empty(n_guides, dtype=cp.int32) + + _guide_assignment_cuda.fit_assign_dense( + X, + assignments_all, + thresholds_all, + lam_all, + mu_all, + sigma_all, + pi0_all, + valid_mask, + nonzero_counts, + max_counts, + n_cells=n_cells, + n_guides=n_guides, + max_iter=int(max_iter), + tol=float(tol), + posterior_threshold=float(posterior_threshold), + stream=cp.cuda.get_current_stream().ptr, + ) + + valid_mask_cpu = cp.asnumpy(valid_mask).astype(bool) + nonzero_counts_cpu = cp.asnumpy(nonzero_counts) + max_counts_cpu = cp.asnumpy(max_counts) + + for guide, (nz_count, max_count) in enumerate( + zip(nonzero_counts_cpu, max_counts_cpu, strict=True) + ): + if 0 < nz_count < 2: + warnings.warn( + f"Skipping guide index {guide} as there are less than 2 cells " + "expressing the guide.", + UserWarning, + stacklevel=4, + ) + elif nz_count >= 2 and max_count < 2: + warnings.warn( + f"Skipping guide index {guide} as the maximum UMI count is less " + "than 2.", + UserWarning, + stacklevel=4, + ) + + valid_guides = np.flatnonzero(valid_mask_cpu).tolist() + if len(valid_guides) == 0: + empty_2d = cp.empty((0, 1), dtype=cp.float32) + return ( + cp.empty((0, n_cells), dtype=cp.bool_), + empty_2d, + empty_2d, + empty_2d, + empty_2d, + empty_2d, + [], + ) + + valid_guides_gpu = cp.asarray(valid_guides, dtype=cp.int32) + return ( + assignments_all[valid_guides_gpu], + thresholds_all[valid_guides_gpu], + lam_all[valid_guides_gpu], + mu_all[valid_guides_gpu], + sigma_all[valid_guides_gpu], + pi0_all[valid_guides_gpu], + valid_guides, + ) + + def _batched_em( data: cp.ndarray, mask: cp.ndarray, @@ -346,7 +555,7 @@ def _batched_em( lam, mu, sigma, pi0 Fitted parameters, each ``(n_guides, 1)``. assignments - ``(n_guides, max_nnz)`` int8 array (1 = positive, 0 = negative). + ``(n_guides, max_nnz)`` boolean array (``True`` = positive). """ n_valid = mask.sum(axis=1, keepdims=True).astype(cp.float32) # (n_guides, 1) @@ -427,7 +636,7 @@ def _batched_em( # Final assignment: cell is positive if P(Gaussian) > 0.5 r0, r1 = _e_step(data, mask, lam=lam, mu=mu, sigma=sigma, pi0=pi0) - assignments = (r1 > 0.5).astype(cp.int8) + assignments = r1 > 0.5 return lam, mu, sigma, pi0, assignments diff --git a/tests/pertpy/test_guide_assignment.py b/tests/pertpy/test_guide_assignment.py index 06f26414..bc6422d8 100644 --- a/tests/pertpy/test_guide_assignment.py +++ b/tests/pertpy/test_guide_assignment.py @@ -8,6 +8,7 @@ from cupyx.scipy.sparse import csr_matrix as gpu_csr import rapids_singlecell as rsc +from rapids_singlecell.pertpy_gpu._guide_assignment import _fit_assign_cuda @pytest.fixture @@ -192,6 +193,12 @@ def test_mixture_model_stores_params(guide_adata: AnnData) -> None: "gaussian_std", "mix_probs_0", "mix_probs_1", + "threshold", + "weight_Poisson", + "weight_Normal", + "lambda", + "mu", + "scale", ]: assert col in guide_adata.var.columns, f"Missing column: {col}" @@ -206,6 +213,29 @@ def test_mixture_model_stores_params(guide_adata: AnnData) -> None: means = guide_adata.var["gaussian_mean"].dropna() assert (rates < means).all(), "Poisson rate should be < Gaussian mean" + # Crispat-compatible aliases should mirror the pertpy-style parameter names. + np.testing.assert_allclose( + guide_adata.var["weight_Poisson"].dropna(), + guide_adata.var["mix_probs_0"].dropna(), + ) + np.testing.assert_allclose( + guide_adata.var["weight_Normal"].dropna(), + guide_adata.var["mix_probs_1"].dropna(), + ) + np.testing.assert_allclose( + guide_adata.var["lambda"].dropna(), + guide_adata.var["poisson_rate"].dropna(), + ) + np.testing.assert_allclose( + guide_adata.var["mu"].dropna(), + guide_adata.var["gaussian_mean"].dropna(), + ) + np.testing.assert_allclose( + guide_adata.var["scale"].dropna(), + guide_adata.var["gaussian_std"].dropna(), + ) + assert guide_adata.var["threshold"].dropna().ge(1).all() + def test_mixture_model_sparse_input(guide_adata_sparse: AnnData) -> None: ga = rsc.ptg.GuideAssignment() @@ -225,6 +255,61 @@ def test_mixture_model_only_return_results(guide_adata: AnnData) -> None: assert isinstance(result, np.ndarray) +def test_mixture_model_invalid_posterior_threshold(guide_adata: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + with pytest.raises(ValueError, match="posterior_threshold"): + ga.assign_mixture_model(guide_adata, posterior_threshold=1.0) + + +def test_mixture_model_invalid_backend(guide_adata: AnnData) -> None: + ga = rsc.ptg.GuideAssignment() + with pytest.raises(ValueError, match="backend"): + ga.assign_mixture_model(guide_adata, backend="not-a-backend") + with pytest.raises(ValueError, match="backend"): + ga.assign_mixture_model(guide_adata, backend="cuda_em") + + +def test_mixture_model_cuda_backend_matches_cupy(guide_adata: AnnData) -> None: + from rapids_singlecell._cuda import _guide_assignment_cuda + + if _guide_assignment_cuda is None: + pytest.skip("_guide_assignment_cuda extension is not available") + + cupy_adata = guide_adata.copy() + cuda_adata = guide_adata.copy() + ga = rsc.ptg.GuideAssignment() + + ga.assign_mixture_model(cupy_adata, backend="cupy") + ga.assign_mixture_model(cuda_adata, backend="cuda") + + np.testing.assert_array_equal( + cupy_adata.obs["assigned_guide"].to_numpy(), + cuda_adata.obs["assigned_guide"].to_numpy(), + ) + np.testing.assert_array_equal( + cupy_adata.var["threshold"].to_numpy(), + cuda_adata.var["threshold"].to_numpy(), + ) + for col in ["lambda", "mu", "scale", "weight_Poisson"]: + assert np.isfinite(cuda_adata.var[col].dropna()).all() + + +def test_mixture_model_cuda_assignments_are_bool(guide_adata: AnnData) -> None: + from rapids_singlecell._cuda import _guide_assignment_cuda + + if _guide_assignment_cuda is None: + pytest.skip("_guide_assignment_cuda extension is not available") + + assignments, *_ = _fit_assign_cuda( + cp.ascontiguousarray(guide_adata.X.astype(cp.float32, copy=False)), + max_iter=90, + tol=1e-4, + posterior_threshold=0.645, + ) + + assert assignments.dtype == cp.bool_ + + def test_mixture_model_skip_low_count() -> None: """Guides with < 2 expressing cells should be skipped with a warning.""" X = np.zeros((50, 3), dtype=np.float32) @@ -249,6 +334,24 @@ def test_mixture_model_skip_low_count() -> None: assert assigned.str.contains("good", na=False).any() +def test_mixture_model_skip_max_count_below_two() -> None: + """Crispat skips guides whose non-zero counts never reach 2 UMIs.""" + X = np.zeros((50, 2), dtype=np.float32) + X[:25, :] = 1.0 + + adata = AnnData( + X=cp.array(X), + var=pd.DataFrame(index=["one_a", "one_b"]), + obs=pd.DataFrame(index=[f"c{i}" for i in range(50)]), + ) + + ga = rsc.ptg.GuideAssignment() + with pytest.warns(UserWarning, match="maximum UMI count is less than 2"): + ga.assign_mixture_model(adata) + + assert (adata.obs["assigned_guide"] == "negative").all() + + def test_multiple_guide_assignment() -> None: """Cells assigned to multiple guides get joined names.""" rng = np.random.default_rng(99)