diff --git a/locat/locat.py b/locat/locat.py index 5cb4dbe..9da8f19 100755 --- a/locat/locat.py +++ b/locat/locat.py @@ -4,6 +4,7 @@ import numba import numpy as np from loguru import logger +from numpy.random import Generator from scanpy import AnnData from scipy.interpolate import PchipInterpolator from scipy.special import logsumexp @@ -52,6 +53,7 @@ class LOCAT: _X = None _n_components_waypoints = None _disable_progress_info = True + _rng: Generator = None def __init__( self, @@ -97,7 +99,7 @@ def __init__( """ self._disable_progress_info = not show_progress self._adata = adata - + self.init_rng() emb = np.asarray(cell_embedding) emb = (emb - emb.mean(0)) / (emb.std(0) + emb.dtype.type(1e-6)) self._embedding = emb @@ -175,6 +177,18 @@ def W_t(self): def show_progress(self, show_progress=True): self._disable_progress_info = not show_progress + def init_rng(self, seed: int = 0): + """ + Initialized the random number generator. + + Parameters + ---------- + seed: int, optional + The seed to use + + """ + self._rng = np.random.default_rng(seed) + # ------------------------------------------------------------------ # Background GMM and LTST null # ------------------------------------------------------------------ @@ -238,7 +252,7 @@ def background_n_components_init(self, weights_transform=None, min_points=10, n_ o = np.arange(self.n_cells) wgs = np.zeros(shape=(self.n_cells, n_reps)) for i_rep in range(n_reps): - np.random.shuffle(o) + self._rng.shuffle(o) wgs[o < n, i_rep] = weights[o < n] wgs = wgs / np.sum(wgs, axis=0) @@ -522,7 +536,7 @@ def estimate_null_parameters(self, fractions=None, n_reps=50): for _ in range(n_reps): mask = np.zeros(self.n_cells, dtype=bool) - mask[np.random.choice(self.n_cells, n_pos, replace=False)] = True + mask[self._rng.choice(self.n_cells, n_pos, replace=False)] = True gene_prior = mask.astype(self._dtype) comp_gene_prior = self._auto_n_effective_weights(gene_prior) @@ -801,7 +815,7 @@ def gmm_scan( include_depletion_scan: bool = False, ) -> dict[str, LocatResult]: """ - + Runs Locat and identifies Localized genes Parameters ---------- @@ -849,8 +863,9 @@ def gmm_scan( """ if verbose: - logger.info("gmm_scan_new: using depletion scan for depletion_pval (depletion_pval_scan)") + logger.info("gmm_scan: using depletion scan for depletion_pval (depletion_pval_scan)") + self.init_rng() if n_bootstrap_inits is not None: self.n_bootstrap_inits = int(n_bootstrap_inits) rc_n_trials_cap_eff = ( @@ -1338,13 +1353,13 @@ def random_pdf( # ---------------------------------------------------------------------- # JIT-accelerated scoring helpers # ---------------------------------------------------------------------- -@numba.jit(nopython=True) +@numba.njit() def ltst_score_func(f0, f1, p): q = np.sqrt(p) return (q * (1 - q)) * (f1 - f0) / (f1 + f0) -@numba.jit(nopython=True) +@numba.njit() def sens_score_func(f0, f1, i): return np.mean(f1[i > 0] > f0[i > 0]) diff --git a/locat/rgmm.py b/locat/rgmm.py index 9e4d7da..69629f9 100755 --- a/locat/rgmm.py +++ b/locat/rgmm.py @@ -1,19 +1,13 @@ from functools import partial import numpy as np +from sklearn.cluster import kmeans_plusplus import jax import jax.numpy as jnp import jax.scipy as jsp -import tensorflow_probability.substrates.jax as jaxp +import tensorflow_probability.substrates.jax.distributions as jaxd jax.config.update("jax_enable_x64", True) -from sklearn.cluster import kmeans_plusplus -from loguru import logger - -for d in jax.local_devices(): - logger.info(f'Found device: {d}') -jaxd = jaxp.distributions - def _weighted_kmeans_init(X, w, n_c, n_inits): return map( @@ -134,7 +128,7 @@ def softbootstrap_gmm(X, raw_weights, n_components, n_inits=100, reg_covar=0.0, rand_weights = rand_weights[o_back,:] n = rand_weights.shape[0] - boot_weights = np.random.geometric(1 / n, size=rand_weights.shape) + boot_weights = rng.geometric(1 / n, size=rand_weights.shape) boot_weights = (boot_weights / np.sum(boot_weights, axis=0)[None, :]) weights = rand_weights * boot_weights return rgmm(X, weights, n_components, n_inits, reg_covar, rand_weights) @@ -148,10 +142,11 @@ def hardbootstrap_gmm(X, raw_weights, n_components, fraction, n_inits=30, reg_co fraction = np.clip(fraction, 0, 1) n_points = X.shape[0] n_samples = np.maximum(1, int(n_points * fraction)) + rng = np.random.default_rng(seed) weights = np.zeros(shape=(n_points, n_inits)) for i in range(n_inits): - sampled_indices = np.random.choice( + sampled_indices = rng.choice( n_points, size=n_samples, replace=True, @@ -170,10 +165,11 @@ def simplebootstrap_gmm(X, n_components, fraction, n_inits=30, reg_covar=0.0, se fraction = np.clip(fraction, 0, 1) n_points = X.shape[0] n_samples = np.maximum(1, int(n_points * fraction)) + rng = np.random.default_rng(seed) weights = np.zeros(shape=(n_points, n_inits)) for i in range(n_inits): - sampled_indices = np.random.choice( + sampled_indices = rng.choice( n_points, size=n_samples, replace=False, diff --git a/locat/utils/simulations.py b/locat/utils/simulations.py index 49d7fdb..78cace0 100644 --- a/locat/utils/simulations.py +++ b/locat/utils/simulations.py @@ -13,9 +13,10 @@ def create_anndata(matrix, cell_names=None, gene_names=None): def simulate_blob_data( - n_samples:int = 5000, - n_tests:int = 200, - n_total = 50, + n_samples: int = 5000, + n_tests: int = 200, + n_total: int = 50, + seed: int = 0, ) -> AnnData: coords, clusts, centers = make_blobs( n_samples=[n_samples], @@ -26,7 +27,7 @@ def simulate_blob_data( cluster_std=[1.] ) - rng = np.random.default_rng(0) + rng = np.random.default_rng(seed) # ---------------------------- # Fixed radius, vary in/out fraction diff --git a/locat/wgmms.py b/locat/wgmms.py index c1cb608..140ab02 100755 --- a/locat/wgmms.py +++ b/locat/wgmms.py @@ -1,24 +1,13 @@ from functools import partial - +import numpy as np +from sklearn.cluster import kmeans_plusplus import jax import jax.numpy as jnp import jax.scipy as jsp -import numpy as np -import tensorflow_probability.substrates.jax as jaxp - -#from jax.config import config as jax_config - -from concurrent.futures import ThreadPoolExecutor +import tensorflow_probability.substrates.jax.distributions as jaxd jax.config.update("jax_enable_x64", True) -from sklearn.cluster import kmeans_plusplus -from loguru import logger - -for d in jax.local_devices(): - logger.info(f'Found device: {d}') -jaxd = jaxp.distributions - def _weighted_kmeans_init(X, w, n_c, n_inits): return map( diff --git a/tests/utils/test_simulations.py b/tests/utils/test_simulations.py index a411d1b..dea6ed3 100644 --- a/tests/utils/test_simulations.py +++ b/tests/utils/test_simulations.py @@ -8,7 +8,6 @@ class SimulationTestCase(unittest.TestCase): def test_simulated_data(self): - np.random.seed(0) adata = simulate_blob_data(n_samples=500, n_tests=20, n_total=50) gene_name = 'Gene_0' @@ -19,12 +18,12 @@ def test_simulated_data(self): gene_0_results = locat_results[gene_name] self.assertEqual(gene_name, gene_0_results.gene_name) - self.assertAlmostEqual(-1.512298, gene_0_results.bic, places=4) - self.assertLessEqual(0., gene_0_results.zscore) # This changes randomly + self.assertAlmostEqual(-1.512298, gene_0_results.bic, places=5) + self.assertAlmostEqual(46.893, gene_0_results.zscore, places=2) self.assertAlmostEqual(1.0, gene_0_results.sens_score, places=5) - self.assertLess(gene_0_results.depletion_pval, 1e-6) + self.assertAlmostEqual(3.98127e-8, gene_0_results.depletion_pval, places=12) self.assertAlmostEqual(1e-15, gene_0_results.concentration_pval, places=14) - self.assertLess(gene_0_results.pval, 1e-2) + self.assertAlmostEqual( 0.000970843, gene_0_results.pval, places=8) self.assertEqual(1, gene_0_results.K_components) self.assertEqual(50, gene_0_results.sample_size)