From 7642452900d4a523130fb43ec8a8a067ba946365 Mon Sep 17 00:00:00 2001 From: millerh1 Date: Wed, 28 Jan 2026 00:24:42 +0000 Subject: [PATCH 1/2] Revert "Fixed bug causing DEGs order shuffling" This reverts commit 3ff329a5a85118cb5bdec5c75ce8eff123539530. --- README.md | 3 -- .../configs/modelgroup/simplebenchmark.yaml | 1 - cellsimbench/core/benchmark.py | 14 ++++---- cellsimbench/core/data_manager.py | 32 +++++++++++-------- cellsimbench/core/metrics_engine.py | 17 +++++----- 5 files changed, 33 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 48505fc..7ceceb5 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,6 @@ uv run cellsimbench train model=fmlp_esm2 dataset=norman19 # Run benchmark (prediction + evaluation) uv run cellsimbench benchmark model=fmlp_esm2 dataset=norman19 -# Enable NIR (Nearest In-distribution Reference) analysis (slow) -uv run cellsimbench benchmark model=fmlp_esm2 dataset=norman19 +run_nir_analysis=true - # Train and benchmark across multiple datasets for dataset in norman19 wessels23; do uv run cellsimbench train model=fmlp_esm2 dataset=$dataset diff --git a/cellsimbench/configs/modelgroup/simplebenchmark.yaml b/cellsimbench/configs/modelgroup/simplebenchmark.yaml index 32162b9..135b5f6 100644 --- a/cellsimbench/configs/modelgroup/simplebenchmark.yaml +++ b/cellsimbench/configs/modelgroup/simplebenchmark.yaml @@ -3,7 +3,6 @@ models: - sclambda - - presage - fmlp_esm2 description: "Compare some different perturbation response models" \ No newline at end of file diff --git a/cellsimbench/core/benchmark.py b/cellsimbench/core/benchmark.py index fd1db38..8309fcc 100644 --- a/cellsimbench/core/benchmark.py +++ b/cellsimbench/core/benchmark.py @@ -558,14 +558,13 @@ def _add_delta_calculations(self, predictions: sc.AnnData, control_pred = fold_baselines['control'] if 'control' in fold_baselines else universal_baselines['control'] delta_ctrl = np.zeros_like(predictions.X) # Find intersection of control and predictions var_names - # Use sorted list to ensure deterministic, reproducible ordering - common_var_names = sorted(set(control_pred.var_names) & set(predictions.var_names)) + common_var_names = set(control_pred.var_names) & set(predictions.var_names) if not common_var_names: raise ValueError("No common variable names found between control and predictions") # Filter control and predictions to only include common var_names - control_pred = control_pred[:, common_var_names] - predictions = predictions[:, common_var_names] + control_pred = control_pred[:, list(common_var_names)] + predictions = predictions[:, list(common_var_names)] for i, cov in enumerate(tqdm(predictions.obs['covariate'], desc="Adding delta from control")): # Find matching control - MUST exist @@ -584,14 +583,13 @@ def _add_delta_calculations(self, predictions: sc.AnnData, dataset_mean_pred = fold_baselines['dataset_mean'] delta_mean = np.zeros_like(predictions.X) # Find intersection of dataset mean and predictions var_names - # Use sorted list to ensure deterministic, reproducible ordering - common_var_names = sorted(set(dataset_mean_pred.var_names) & set(predictions.var_names)) + common_var_names = set(dataset_mean_pred.var_names) & set(predictions.var_names) if not common_var_names: raise ValueError("No common variable names found between dataset mean and predictions") # Filter dataset mean and predictions to only include common var_names - dataset_mean_pred = dataset_mean_pred[:, common_var_names] - predictions = predictions[:, common_var_names] + dataset_mean_pred = dataset_mean_pred[:, list(common_var_names)] + predictions = predictions[:, list(common_var_names)] for i, cov in enumerate(tqdm(predictions.obs['covariate'], desc="Adding delta from dataset mean")): # Find matching dataset mean - MUST exist diff --git a/cellsimbench/core/data_manager.py b/cellsimbench/core/data_manager.py index 7b0a15f..9cbedb5 100644 --- a/cellsimbench/core/data_manager.py +++ b/cellsimbench/core/data_manager.py @@ -213,17 +213,16 @@ def get_available_splits(self) -> List[str]: split_columns = [col for col in self.adata.obs.columns if 'split' in col.lower()] return split_columns - def get_deg_weights(self, covariate_value: str, perturbation: str, gene_order: List[str]) -> np.ndarray: + def get_deg_weights(self, covariate_value: str, perturbation: str, common_var_names: np.ndarray=None) -> np.ndarray: """ Get DEG-based weights for a specific covariate-perturbation combination. Args: covariate_value: Value of the covariate (e.g., donor ID) perturbation: Perturbation identifier - gene_order: Ordered list of gene names to align weights to. Returns: - Array of weights aligned with gene_order + Array of weights aligned with adata.var_names """ cov_pert_key = f"{covariate_value}_{perturbation}" @@ -231,34 +230,40 @@ def get_deg_weights(self, covariate_value: str, perturbation: str, gene_order: L weights = self.pert_normalized_abs_scores_vsrest_df[cov_pert_key] else: # Return zero weights if no DEG data available - return np.zeros(len(gene_order)) + weights = np.zeros(self.adata.n_vars) + return weights - # Reindex weights to match the target gene order - weights = weights.reindex(gene_order, fill_value=0.0) + # Filter the weights to only include the common var_names + if common_var_names is not None: + weights = weights[list(common_var_names)] return weights.values - def get_deg_mask(self, covariate_value: str, perturbation: str, gene_order: List[str], pval_threshold: float = 0.05) -> np.ndarray: + def get_deg_mask(self, covariate_value: str, perturbation: str, pval_threshold: float = 0.05, common_var_names_mask: np.ndarray=None) -> np.ndarray: """ Get DEG mask for a specific covariate-perturbation combination. Args: covariate_value: Value of the covariate (e.g., donor ID) perturbation: Perturbation identifier - gene_order: Ordered list of gene names to align the mask to. pval_threshold: P-value threshold for significance Returns: - Boolean array indicating DEG positions, aligned to gene_order + Boolean array indicating DEG positions """ cov_pert_key = f"{covariate_value}_{perturbation}" if cov_pert_key not in self.deg_pvals_dict: - return np.zeros(len(gene_order), dtype=bool) + return np.zeros(self.adata.n_vars, dtype=bool) - # Get p-values and gene names (these are in DEG rank order, not var_names order) + # Get p-values and gene names pvals = self.deg_pvals_dict[cov_pert_key] gene_names = self.deg_names_dict[cov_pert_key] + + if common_var_names_mask is not None: + pvals = pvals[common_var_names_mask] + gene_names = gene_names[common_var_names_mask] + # Create boolean mask for significant genes sig_mask = pvals < pval_threshold @@ -273,10 +278,9 @@ def get_deg_mask(self, covariate_value: str, perturbation: str, gene_order: List # Group by gene and take minimum p-value, then check significance pvals_aggregated = pvals_df.groupby('gene')['pval'].min() deg_mask_aggregated = pvals_aggregated < pval_threshold - - # Reindex to match the target gene order - deg_mask = deg_mask_aggregated.reindex(gene_order, fill_value=False) + # Reindex to match adata.var_names + deg_mask = deg_mask_aggregated.reindex(self.adata.var_names[common_var_names_mask] if common_var_names_mask is not None else self.adata.var_names, fill_value=False) return deg_mask.values diff --git a/cellsimbench/core/metrics_engine.py b/cellsimbench/core/metrics_engine.py index 16faa00..201e1c5 100644 --- a/cellsimbench/core/metrics_engine.py +++ b/cellsimbench/core/metrics_engine.py @@ -40,15 +40,16 @@ def calculate_all_metrics( ) -> Dict[str, Dict[str, float]]: # Ensure predictions and ground truth have the same var_names - # Use sorted list to ensure deterministic, reproducible ordering - common_var_names = sorted(set(predictions.columns) & set(ground_truth.columns)) + common_var_names = set(predictions.columns) & set(ground_truth.columns) + common_var_names_mask = ground_truth.columns.isin(common_var_names) if not common_var_names: raise ValueError("Predictions and ground truth have different var_names") - predictions = predictions[common_var_names] - ground_truth = ground_truth[common_var_names] - predictions_deltas = {key: df[common_var_names] for key, df in predictions_deltas.items()} - ground_truth_deltas = {key: df[common_var_names] for key, df in ground_truth_deltas.items()} + predictions = predictions[list(common_var_names)] + ground_truth = ground_truth[list(common_var_names)] + predictions_deltas = {key: df[list(common_var_names)] for key, df in predictions_deltas.items()} + ground_truth_deltas = {key: df[list(common_var_names)] for key, df in ground_truth_deltas.items()} + # Get a mask of the common var_names found in the predictions to use for filtering # Calculate nir scores (needs full dataset) - only if enabled @@ -89,8 +90,8 @@ def calculate_all_metrics( # Get DEG weights and mask using covariate and perturbation - weights = self.data_manager.get_deg_weights(covariate_value, condition, gene_order=common_var_names) - deg_mask = self.data_manager.get_deg_mask(covariate_value, condition, gene_order=common_var_names) + weights = self.data_manager.get_deg_weights(covariate_value, condition, common_var_names) + deg_mask = self.data_manager.get_deg_mask(covariate_value, condition, common_var_names_mask=common_var_names_mask) condition_metrics[cov_pert_key] = { 'mse': self._calculate_mse(pred_expression, truth_expression), 'wmse': self._calculate_wmse(pred_expression, truth_expression, weights), From 08db3c1021d0b09cae8d53528552efda775893cd Mon Sep 17 00:00:00 2001 From: millerh1 Date: Wed, 28 Jan 2026 00:25:44 +0000 Subject: [PATCH 2/2] Fixed minor README issue --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ceceb5..83da16b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Deep Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines on Well-Calibrated Metrics +# Deep Learning-Based Genetic Perturbation Models *Do* Outperform Uninformative Baselines on Well-Calibrated Metrics **Preprint:** https://www.biorxiv.org/content/10.1101/2025.10.20.683304v1