Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions locat/locat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numba
import numpy as np
from loguru import logger
from numpy.random import Generator
from scanpy import AnnData
from scipy.interpolate import PchipInterpolator
from scipy.special import logsumexp
Expand Down Expand Up @@ -52,6 +53,7 @@ class LOCAT:
_X = None
_n_components_waypoints = None
_disable_progress_info = True
_rng: Generator = None

def __init__(
self,
Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(
"""
self._disable_progress_info = not show_progress
self._adata = adata

self.init_rng()
emb = np.asarray(cell_embedding)
emb = (emb - emb.mean(0)) / (emb.std(0) + emb.dtype.type(1e-6))
self._embedding = emb
Expand Down Expand Up @@ -175,6 +177,18 @@ def W_t(self):
def show_progress(self, show_progress=True):
self._disable_progress_info = not show_progress

def init_rng(self, seed: int = 0):
"""
Initialize the random number generator.

Parameters
----------
seed: int, optional
The seed to use

"""
self._rng = np.random.default_rng(seed)

# ------------------------------------------------------------------
# Background GMM and LTST null
# ------------------------------------------------------------------
Expand Down Expand Up @@ -238,7 +252,7 @@ def background_n_components_init(self, weights_transform=None, min_points=10, n_
o = np.arange(self.n_cells)
wgs = np.zeros(shape=(self.n_cells, n_reps))
for i_rep in range(n_reps):
np.random.shuffle(o)
self._rng.shuffle(o)
wgs[o < n, i_rep] = weights[o < n]
wgs = wgs / np.sum(wgs, axis=0)

Expand Down Expand Up @@ -522,7 +536,7 @@ def estimate_null_parameters(self, fractions=None, n_reps=50):

for _ in range(n_reps):
mask = np.zeros(self.n_cells, dtype=bool)
mask[np.random.choice(self.n_cells, n_pos, replace=False)] = True
mask[self._rng.choice(self.n_cells, n_pos, replace=False)] = True
gene_prior = mask.astype(self._dtype)

comp_gene_prior = self._auto_n_effective_weights(gene_prior)
Expand Down Expand Up @@ -801,7 +815,7 @@ def gmm_scan(
include_depletion_scan: bool = False,
) -> dict[str, LocatResult]:
"""

Runs LOCAT and identifies localized genes.

Parameters
----------
Expand Down Expand Up @@ -849,8 +863,9 @@ def gmm_scan(

"""
if verbose:
logger.info("gmm_scan_new: using depletion scan for depletion_pval (depletion_pval_scan)")
logger.info("gmm_scan: using depletion scan for depletion_pval (depletion_pval_scan)")

self.init_rng()
if n_bootstrap_inits is not None:
self.n_bootstrap_inits = int(n_bootstrap_inits)
rc_n_trials_cap_eff = (
Expand Down Expand Up @@ -1338,13 +1353,13 @@ def random_pdf(
# ----------------------------------------------------------------------
# JIT-accelerated scoring helpers
# ----------------------------------------------------------------------
@numba.jit(nopython=True)
@numba.njit()
def ltst_score_func(f0, f1, p):
q = np.sqrt(p)
return (q * (1 - q)) * (f1 - f0) / (f1 + f0)


@numba.jit(nopython=True)
@numba.njit()
def sens_score_func(f0, f1, i):
return np.mean(f1[i > 0] > f0[i > 0])

Expand Down
18 changes: 7 additions & 11 deletions locat/rgmm.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from functools import partial
import numpy as np
from sklearn.cluster import kmeans_plusplus
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import tensorflow_probability.substrates.jax as jaxp
import tensorflow_probability.substrates.jax.distributions as jaxd

jax.config.update("jax_enable_x64", True)

from sklearn.cluster import kmeans_plusplus
from loguru import logger

for d in jax.local_devices():
logger.info(f'Found device: {d}')
jaxd = jaxp.distributions


def _weighted_kmeans_init(X, w, n_c, n_inits):
return map(
Expand Down Expand Up @@ -134,7 +128,7 @@ def softbootstrap_gmm(X, raw_weights, n_components, n_inits=100, reg_covar=0.0,
rand_weights = rand_weights[o_back,:]

n = rand_weights.shape[0]
boot_weights = np.random.geometric(1 / n, size=rand_weights.shape)
boot_weights = rng.geometric(1 / n, size=rand_weights.shape)
boot_weights = (boot_weights / np.sum(boot_weights, axis=0)[None, :])
weights = rand_weights * boot_weights
return rgmm(X, weights, n_components, n_inits, reg_covar, rand_weights)
Expand All @@ -148,10 +142,11 @@ def hardbootstrap_gmm(X, raw_weights, n_components, fraction, n_inits=30, reg_co
fraction = np.clip(fraction, 0, 1)
n_points = X.shape[0]
n_samples = np.maximum(1, int(n_points * fraction))
rng = np.random.default_rng(seed)

weights = np.zeros(shape=(n_points, n_inits))
for i in range(n_inits):
sampled_indices = np.random.choice(
sampled_indices = rng.choice(
n_points,
size=n_samples,
replace=True,
Expand All @@ -170,10 +165,11 @@ def simplebootstrap_gmm(X, n_components, fraction, n_inits=30, reg_covar=0.0, se
fraction = np.clip(fraction, 0, 1)
n_points = X.shape[0]
n_samples = np.maximum(1, int(n_points * fraction))
rng = np.random.default_rng(seed)

weights = np.zeros(shape=(n_points, n_inits))
for i in range(n_inits):
sampled_indices = np.random.choice(
sampled_indices = rng.choice(
n_points,
size=n_samples,
replace=False,
Expand Down
9 changes: 5 additions & 4 deletions locat/utils/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ def create_anndata(matrix, cell_names=None, gene_names=None):


def simulate_blob_data(
n_samples:int = 5000,
n_tests:int = 200,
n_total = 50,
n_samples: int = 5000,
n_tests: int = 200,
n_total: int = 50,
seed: int = 0,
) -> AnnData:
coords, clusts, centers = make_blobs(
n_samples=[n_samples],
Expand All @@ -26,7 +27,7 @@ def simulate_blob_data(
cluster_std=[1.]
)

rng = np.random.default_rng(0)
rng = np.random.default_rng(seed)

# ----------------------------
# Fixed radius, vary in/out fraction
Expand Down
17 changes: 3 additions & 14 deletions locat/wgmms.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,13 @@
from functools import partial

import numpy as np
from sklearn.cluster import kmeans_plusplus
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import tensorflow_probability.substrates.jax as jaxp

#from jax.config import config as jax_config

from concurrent.futures import ThreadPoolExecutor
import tensorflow_probability.substrates.jax.distributions as jaxd

jax.config.update("jax_enable_x64", True)

from sklearn.cluster import kmeans_plusplus
from loguru import logger

for d in jax.local_devices():
logger.info(f'Found device: {d}')
jaxd = jaxp.distributions


def _weighted_kmeans_init(X, w, n_c, n_inits):
return map(
Expand Down
9 changes: 4 additions & 5 deletions tests/utils/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
class SimulationTestCase(unittest.TestCase):

def test_simulated_data(self):
np.random.seed(0)
adata = simulate_blob_data(n_samples=500, n_tests=20, n_total=50)
gene_name = 'Gene_0'

Expand All @@ -19,12 +18,12 @@ def test_simulated_data(self):
gene_0_results = locat_results[gene_name]

self.assertEqual(gene_name, gene_0_results.gene_name)
self.assertAlmostEqual(-1.512298, gene_0_results.bic, places=4)
self.assertLessEqual(0., gene_0_results.zscore) # This changes randomly
self.assertAlmostEqual(-1.512298, gene_0_results.bic, places=5)
self.assertAlmostEqual(46.893, gene_0_results.zscore, places=2)
self.assertAlmostEqual(1.0, gene_0_results.sens_score, places=5)
self.assertLess(gene_0_results.depletion_pval, 1e-6)
self.assertAlmostEqual(3.98127e-8, gene_0_results.depletion_pval, places=12)
self.assertAlmostEqual(1e-15, gene_0_results.concentration_pval, places=14)
self.assertLess(gene_0_results.pval, 1e-2)
self.assertAlmostEqual( 0.000970843, gene_0_results.pval, places=8)
self.assertEqual(1, gene_0_results.K_components)
self.assertEqual(50, gene_0_results.sample_size)

Expand Down