diff --git a/locat/locat_condensed.py b/locat/locat_condensed.py index b3e35df..ee881b4 100755 --- a/locat/locat_condensed.py +++ b/locat/locat_condensed.py @@ -56,7 +56,7 @@ def __init__( adata: AnnData, cell_embedding: np.ndarray, k: int, - n_bootstrap_inits: int, + n_bootstrap_inits: int = 50, show_progress=False, wgmm_dtype: str = "same", # "same" | "float32" | "float64" knn=None, # <-- NEW: precomputed adjacency/connectivities @@ -126,7 +126,13 @@ def reg_covar(self, sample_size=None): # Data access / options # ------------------------------------------------------------------ @property - def X(self): + def W_t(self): + """ + Implementation-oriented expression matrix (cells x genes). + + This is the transpose of the manuscript's W object, which is + described in gene x cell orientation. + """ if self._X is None: X = self._adata.X if not isinstance(X, np.ndarray): @@ -154,7 +160,7 @@ def background_n_components_init(self, weights_transform=None, min_points=10, n_ bic_component_cost = self.n_dims * (self.n_dims + 3) / 2 if weights_transform is not None: - Xdense = self.X.copy() + Xdense = self.W_t.copy() for i in range(self.n_genes): Xdense[:, i] = weights_transform(Xdense[:, i]) weights = np.asarray(Xdense.sum(axis=1), dtype=self._dtype) @@ -367,7 +373,6 @@ def background_pdf( self, n_comp=None, reps=10, - total_counts_weight=True, weights_transform=None, force_refresh=False, ): @@ -388,7 +393,7 @@ def background_pdf( logger.info(f"Using {n_comp} components") if weights_transform is not None: - Xdense = self.X.copy() + Xdense = self.W_t.copy() for i in range(self.n_genes): Xdense[:, i] = weights_transform(Xdense[:, i]) weights = np.asarray(Xdense.sum(axis=1), dtype=self._dtype) @@ -529,21 +534,21 @@ def estimate_null_parameters(self, fractions=None, n_reps=50): # ------------------------------------------------------------------ # Depletion-style localization scan # ------------------------------------------------------------------ - def localization_pval_dep_scan( + def depletion_pval_scan( self, gmm1, gene_prior, *, - c_values=None, + lambda_values=None, soft_bound=None, # default computed from n: max((n-1)/n, 0.99) min_p0_abs=0.10, min_expected=30, min_abs_deficit=0.02, n_trials_cap=500, - weight_mode="amount", + weight_mode="binary", p_floor=1e-12, - n_eff_scale=1.0, - rho_bb=0.2, # >0 enables Beta–Binomial tail + n_eff_scale=0.6, + rho_bb=0.02, # >0 enables Beta–Binomial tail eps_rel=0.01, debug=False, debug_store_masks=False, @@ -565,14 +570,14 @@ def localization_pval_dep_scan( else: raise ValueError("weight_mode must be 'amount' or 'binary'") - # Kish n_eff with tempering + cap + # Kish n_eff_g with tempering + cap sw = float(np.sum(w_obs)) - n_eff = 0.0 if sw <= 0 else (sw * sw) / max(float(np.sum(w_obs * w_obs)), 1e-12) - n_eff *= float(n_eff_scale) + n_eff_g = 0.0 if sw <= 0 else (sw * sw) / max(float(np.sum(w_obs * w_obs)), 1e-12) + n_eff_g *= float(n_eff_scale) if n_trials_cap is not None: - n_eff = min(n_eff, float(n_trials_cap)) - n_eff = max(1.0, n_eff) - n_trials = int(round(n_eff)) + n_eff_g = min(n_eff_g, float(n_trials_cap)) + n_eff_g = max(1.0, n_eff_g) + n_trials_eff = int(round(n_eff_g)) f0_x = np.clip(self._background_gmm.pdf(X), 1e-300, np.inf) f1_x = np.clip(gmm1.pdf(X), 1e-300, np.inf) @@ -581,25 +586,25 @@ def localization_pval_dep_scan( w_obs_alpha = w_obs / (sw if sw > 0 else 1.0) w0_alpha = w0 - if c_values is None: - c_values = np.concatenate([[1.0], np.geomspace(1.05, 3.0, 12)]) + if lambda_values is None: + lambda_values = np.concatenate([[1.0], np.geomspace(1.05, 3.0, 12)]) best_logp = None - best = {"c": None, "k_obs": None, "p0": None, "obs_prop": None} + best = {"lambda": None, "k_obs": None, "p0": None, "obs_prop": None} scanned = 0 tested = 0 - per_c = [] if debug else None + per_lambda = [] if debug else None - for c in c_values: - c = float(c) - in_R_mask = f0_x > c * f1_x + for lambda_ in lambda_values: + lambda_ = float(lambda_) + in_R_mask = f0_x > lambda_ * f1_x p0_abs = float(np.sum(w0_alpha * in_R_mask)) reason = None if p0_abs < min_p0_abs: reason = "fail:min_p0_abs" - if n_eff * p0_abs < min_expected: + if n_eff_g * p0_abs < min_expected: reason = "fail:min_expected" if p0_abs > soft_bound: reason = "fail:soft_bound" @@ -608,16 +613,16 @@ def localization_pval_dep_scan( if reason is None: if (p0_abs - obs_prop) < min_abs_deficit: reason = "fail:min_abs_deficit" - elif obs_prop > (p0_abs / c) * (1.0 - float(eps_rel)): + elif obs_prop > (p0_abs / lambda_) * (1.0 - float(eps_rel)): reason = "fail:c_bound" #this checks whether observed f1 density in the region is at an equal or lower proportion than the observed f0 density in the region * c (where c is the contrast). If the region is larger than expectation, reject the gene. if debug: rec = { - "c": c, + "lambda": lambda_, "p0_abs": p0_abs, "obs_prop": obs_prop, - "n_eff": float(n_eff), - "n_eff_expected": float(n_eff * p0_abs), + "n_eff_g": float(n_eff_g), + "n_eff_expected": float(n_eff_g * p0_abs), "reason": reason, } if debug_store_masks: @@ -626,7 +631,7 @@ def localization_pval_dep_scan( if weight_mode == "binary": expr_mask = gp > 0 rec["expr_in_R_count"] = int(np.sum(expr_mask & in_R_mask)) - per_c.append(rec) + per_lambda.append(rec) if reason is not None: continue @@ -634,24 +639,24 @@ def localization_pval_dep_scan( scanned += 1 tested += 1 - k_eff = int(np.rint(obs_prop * n_trials)) + k_eff = int(np.rint(obs_prop * n_trials_eff)) p0_clip = np.clip(p0_abs, 1e-12, 1 - 1e-12) if rho_bb and rho_bb > 0.0: ab_sum = max(1.0 / float(rho_bb) - 1.0, 2.0) alpha = float(p0_clip * ab_sum) beta = float((1.0 - p0_clip) * ab_sum) - p_raw = float(betabinom.cdf(k_eff, n_trials, alpha, beta)) + p_raw = float(betabinom.cdf(k_eff, n_trials_eff, alpha, beta)) logp_raw = np.log(max(p_raw, np.finfo(float).tiny)) else: - lFkm1 = binom.logcdf(k_eff - 1, n_trials, p0_clip) - lpk = binom.logpmf(k_eff, n_trials, p0_clip) + np.log(0.5) + lFkm1 = binom.logcdf(k_eff - 1, n_trials_eff, p0_clip) + lpk = binom.logpmf(k_eff, n_trials_eff, p0_clip) + np.log(0.5) logp_raw = logsumexp([lFkm1, lpk]) if (best_logp is None) or (logp_raw < best_logp): best_logp = float(logp_raw) best.update( - {"c": c, "k_obs": k_eff, "p0": p0_abs, "obs_prop": obs_prop} + {"lambda": lambda_, "k_obs": k_eff, "p0": p0_abs, "obs_prop": obs_prop} ) if tested == 0: @@ -662,15 +667,16 @@ def localization_pval_dep_scan( "log_p_sidak": 0.0, "neglog10_p_single": 0.0, "neglog10_p_sidak": 0.0, - "best_c": None, + "best_lambda": None, "k_obs_eff": None, "p0_abs": None, "obs_prop": None, "scanned": int(scanned), "tested": int(tested), - "sidak_penalty": 1, + "sidak_penalty": int(tested), "n": n, - "n_eff": float(n_eff), + "n_eff_g": float(n_eff_g), + "n_trials_eff": int(n_trials_eff), "guards": { "min_p0_abs": float(min_p0_abs), "min_expected": float(min_expected), @@ -682,7 +688,7 @@ def localization_pval_dep_scan( "rho_bb": float(rho_bb), } if debug: - out["per_c"] = per_c + out["per_lambda"] = per_lambda return out m_eff = tested @@ -700,15 +706,16 @@ def localization_pval_dep_scan( "log_p_sidak": float(sidak_logp), "neglog10_p_single": float(-best_logp / np.log(10)), "neglog10_p_sidak": float(-sidak_logp / np.log(10)), - "best_c": best["c"], + "best_lambda": best["lambda"], "k_obs_eff": best["k_obs"], "p0_abs": best["p0"], "obs_prop": best["obs_prop"], "scanned": int(scanned), "tested": int(tested), - "sidak_penalty": 1, + "sidak_penalty": int(m_eff), "n": n, - "n_eff": float(n_eff), + "n_eff_g": float(n_eff_g), + "n_trials_eff": int(n_trials_eff), "guards": { "min_p0_abs": float(min_p0_abs), "min_expected": float(min_expected), @@ -720,9 +727,13 @@ def localization_pval_dep_scan( "rho_bb": float(rho_bb), } if debug: - out["per_c"] = per_c + out["per_lambda"] = per_lambda return out + # Backward-compatible alias + def localization_pval_dep_scan(self, *args, **kwargs): + return self.depletion_pval_scan(*args, **kwargs) + # ------------------------------------------------------------------ # Main scan used in practice # ------------------------------------------------------------------ @@ -741,27 +752,32 @@ def gmm_scan_new( genes=None, weights_transform=None, zscore_thresh=None, - max_freq=0.5, + max_freq=0.9, verbose=False, n_bootstrap_inits=None, # Depletion-scan defaults - rc_c_values=None, # default inside method + rc_lambda_values=None, # default inside method rc_min_p0_abs=0.10, #minimum proportion of f0 density in depleted region required for the region pval to be estimated - rc_min_expected=30, #minimum expected cells in depleted region required for the region pval to be estimated - rc_min_abs_deficit=0.02, #minimum absolute difference in f1(x) - f0(x) for all x in depleted region - rc_n_trials_cap=500, #maximum effective sample size + rc_min_expected=3, #minimum expected cells in depleted region required for the region pval to be estimated + rc_min_abs_deficit=0.04, #minimum absolute difference in f1(x) - f0(x) for all x in depleted region + rc_n_trials_cap=None, #if None, defaults to sqrt(n_cells) rc_soft_bound=1.0, #this is unused/can be removed - rc_n_eff_scale=0.5, #scaling factor for effective sample sizes -- can be tweaked to stabilize pvalues across various gene sample sizes + rc_n_eff_scale=0.6, #scaling factor for effective sample sizes -- can be tweaked to stabilize pvalues across various gene sample sizes rc_p_floor=1e-12, # this is just model precision, can be ignored rc_rho_bb=0.02, #this is the strength of the beta binomial (0.0 is standard binomial, set at 0.02-0.05 for wider tails) - rc_weight_mode="amount", + rc_weight_mode="binary", rc_eps_rel=0.01, ): if verbose: - print("gmm_scan_new: using depletion scan for localization_pval (localization_pval_dep_scan)") + print("gmm_scan_new: using depletion scan for depletion_pval (depletion_pval_scan)") if n_bootstrap_inits is not None: self.n_bootstrap_inits = int(n_bootstrap_inits) + rc_n_trials_cap_eff = ( + int(max(1, np.sqrt(self.n_cells))) + if rc_n_trials_cap is None + else int(rc_n_trials_cap) + ) locally_enriched = dict() gzeros, freqzeros, zzeros = [], [], [] @@ -810,35 +826,35 @@ def gmm_scan_new( zzeros.append(self._adata.var_names[i_gene]) continue - cs_res = self.localization_pval_dep_scan( + cs_res = self.depletion_pval_scan( gmm1, gene_prior, debug=True, - c_values=( - rc_c_values - if rc_c_values is not None + lambda_values=( + rc_lambda_values + if rc_lambda_values is not None else np.concatenate([[1.0], np.geomspace(1.05, 3.0, 12)]) ), soft_bound=rc_soft_bound, min_p0_abs=rc_min_p0_abs, min_expected=rc_min_expected, min_abs_deficit=rc_min_abs_deficit, - n_trials_cap=rc_n_trials_cap, + n_trials_cap=rc_n_trials_cap_eff, weight_mode=rc_weight_mode, p_floor=rc_p_floor, n_eff_scale=rc_n_eff_scale, rho_bb=rc_rho_bb, eps_rel=rc_eps_rel, ) - localization_pval = _safe_p(cs_res["p_value"]) + depletion_pval = _safe_p(cs_res["p_value"]) concentration_pval = _safe_p(float(normal_sf(zscore, 0.0, 1.0))) - p_cauchy = cauchy_combine([localization_pval, concentration_pval]) - p_size = _safe_p(1.0 - np.exp(-1.0 / (sample_size + 1.0))) - p_sens = _safe_p(1.0 - (sens_score + 1e-9)) - p_final = 1.0 - (1.0 - p_cauchy) * (1.0 - 0.05 * p_size) * ( - 1.0 - 0.12 * p_sens + p_cauchy = cauchy_combine([depletion_pval, concentration_pval]) + h_size = _safe_p(1.0 - np.exp(-1.0 / (sample_size + 1.0))) + h_sens = _safe_p(1.0 - (sens_score + 1e-9)) + p_final = 1.0 - (1.0 - p_cauchy) * (1.0 - 0.05 * h_size) * ( + 1.0 - 0.12 * h_sens ) p_final = float(smooth_qvals(np.array([_safe_p(p_final)]))[0]) @@ -846,11 +862,14 @@ def gmm_scan_new( "bic": self.bic_score(gmm1, gene_prior), "zscore": zscore, "sens_score": sens_score, - "localization_pval": localization_pval, + "depletion_pval": depletion_pval, "concentration_pval": concentration_pval, + "h_size": h_size, + "h_sens": h_sens, "pval": p_final, - "n_components": n_comp, + "K_components": n_comp, "sample_size": sample_size, + "depletion_scan": cs_res, "depl_scan": cs_res, } @@ -873,7 +892,7 @@ def get_genes_indices(self, genes): return inclgenes def get_gene_prior(self, i_gene, weights_transform): - gene_prior = self.X[:, i_gene] + gene_prior = self.W_t[:, i_gene] if weights_transform is not None: gene_prior = weights_transform(gene_prior) return gene_prior @@ -917,7 +936,8 @@ def gmm_loglikelihoodtest(self, genes=None, weights_transform=None, max_freq=0.5 f1 = self.signal_pdf(weights=gene_prior, n_comp=n_comp) - df = (bkg_df - n_comp) + 1 + # Keep df in the valid chi-square domain in edge cases. + df = max(1, int((bkg_df - n_comp) + 1)) ix = gene_prior_gt0 & (f1 > 0) if np.sum(ix) > 5: ix = ix & bkg_pdf_gt0 @@ -1157,7 +1177,7 @@ def gmm_local_scan(self, 'localization_pval': localization_pval, 'concentration_pval': concentration_pval, 'local_pvalue': localization_pval + concentration_pval - localization_pval*concentration_pval, - 'n_components': n_comp + 'K_components': n_comp } except Exception as e: logger.exception(e) # prints full traceback @@ -1378,12 +1398,12 @@ def cauchy_combine(pvals, weights=None): def summarize_rc_debug(cs_res, top=8): """ Convenience helper to inspect per-threshold diagnostics from - localization_pval_dep_scan(..., debug=True). + depletion_pval_scan(..., debug=True). """ - if "per_c" not in cs_res: + if "per_lambda" not in cs_res: print("No per-threshold diagnostics captured. Run with debug=True.") return - rows = cs_res["per_c"] + rows = cs_res["per_lambda"] if not rows: print("No thresholds scanned.") return @@ -1397,7 +1417,7 @@ def summarize_rc_debug(cs_res, top=8): print(f"\nTop {min(top, len(cand))} failing thresholds by n_eff_expected:") for r in cand[:top]: print( - f" c={r['c']:.3f} p0_abs={r['p0_abs']:.3f} " + f" lambda={r['lambda']:.3f} p0_abs={r['p0_abs']:.3f} " f"obs_prop={r['obs_prop']:.3f} n_eff_expected={r['n_eff_expected']:.1f} " f"reason={r['reason']}" ) diff --git a/locat/plotting_and_other_methods.py b/locat/plotting_and_other_methods.py index 35e4ff6..e92560b 100755 --- a/locat/plotting_and_other_methods.py +++ b/locat/plotting_and_other_methods.py @@ -267,7 +267,7 @@ def plot_gene_localization_summary( f"{gene}\nLocalized: {loc_mask.sum()} | " f"Unloc: {unloc_mask.sum()} | " f"ExpUnloc: {rec.get('expected_unlocalized', np.nan):.1f}\n" - f"LocPval: {rec['localization_pval']:.2e}" + f"DepPval: {rec['depletion_pval']:.2e}" ), fontsize=11 ) @@ -347,7 +347,7 @@ def plotgenes( f"{gene}\n" f"pval: {d0.loc[gene]['pval']:.2e}\n" f"conc_pval: {d0.loc[gene]['concentration_pval']:.2e}\n" - f"loca_pval: {d0.loc[gene]['localization_pval']:.2e}", + f"dep_pval: {d0.loc[gene]['depletion_pval']:.2e}", fontsize=text_size, pad=10 ) @@ -380,4 +380,3 @@ def plotgenes( plt.tight_layout() fig.subplots_adjust(hspace=0.4) plt.show() -