diff --git a/README.md b/README.md index 8569d51..b44bfc6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Deep Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines on Well-Calibrated Metrics +# Deep Learning-Based Genetic Perturbation Models *Do* Outperform Uninformative Baselines on Well-Calibrated Metrics **Preprint:** https://www.biorxiv.org/content/10.1101/2025.10.20.683304v1 diff --git a/cellsimbench/configs/modelgroup/simplebenchmark.yaml b/cellsimbench/configs/modelgroup/simplebenchmark.yaml index 32162b9..135b5f6 100644 --- a/cellsimbench/configs/modelgroup/simplebenchmark.yaml +++ b/cellsimbench/configs/modelgroup/simplebenchmark.yaml @@ -3,7 +3,6 @@ models: - sclambda - - presage - fmlp_esm2 description: "Compare some different perturbation response models" \ No newline at end of file diff --git a/cellsimbench/core/benchmark.py b/cellsimbench/core/benchmark.py index 7669c2b..592843f 100644 --- a/cellsimbench/core/benchmark.py +++ b/cellsimbench/core/benchmark.py @@ -560,14 +560,13 @@ def _add_delta_calculations(self, predictions: sc.AnnData, control_pred = fold_baselines['control'] if 'control' in fold_baselines else universal_baselines['control'] delta_ctrl = np.zeros_like(predictions.X) # Find intersection of control and predictions var_names - # Use sorted list to ensure deterministic, reproducible ordering - common_var_names = sorted(set(control_pred.var_names) & set(predictions.var_names)) + common_var_names = set(control_pred.var_names) & set(predictions.var_names) if not common_var_names: raise ValueError("No common variable names found between control and predictions") # Filter control and predictions to only include common var_names - control_pred = control_pred[:, common_var_names] - predictions = predictions[:, common_var_names] + control_pred = control_pred[:, list(common_var_names)] + predictions = predictions[:, list(common_var_names)] for i, cov in enumerate(tqdm(predictions.obs['covariate'], desc="Adding delta from control")): # Find matching control - MUST exist @@ -586,14 +585,13 @@ def _add_delta_calculations(self, predictions: sc.AnnData, dataset_mean_pred = fold_baselines['dataset_mean'] delta_mean = np.zeros_like(predictions.X) # Find intersection of dataset mean and predictions var_names - # Use sorted list to ensure deterministic, reproducible ordering - common_var_names = sorted(set(dataset_mean_pred.var_names) & set(predictions.var_names)) + common_var_names = set(dataset_mean_pred.var_names) & set(predictions.var_names) if not common_var_names: raise ValueError("No common variable names found between dataset mean and predictions") # Filter dataset mean and predictions to only include common var_names - dataset_mean_pred = dataset_mean_pred[:, common_var_names] - predictions = predictions[:, common_var_names] + dataset_mean_pred = dataset_mean_pred[:, list(common_var_names)] + predictions = predictions[:, list(common_var_names)] for i, cov in enumerate(tqdm(predictions.obs['covariate'], desc="Adding delta from dataset mean")): # Find matching dataset mean - MUST exist diff --git a/cellsimbench/core/data_manager.py b/cellsimbench/core/data_manager.py index 45ffda6..5d70d6a 100644 --- a/cellsimbench/core/data_manager.py +++ b/cellsimbench/core/data_manager.py @@ -245,17 +245,16 @@ def get_available_splits(self) -> List[str]: split_columns = [col for col in self.adata.obs.columns if 'split' in col.lower()] return split_columns - def get_deg_weights(self, covariate_value: str, perturbation: str, gene_order: List[str]) -> np.ndarray: + def get_deg_weights(self, covariate_value: str, perturbation: str, common_var_names: np.ndarray=None) -> np.ndarray: """ Get DEG-based weights for a specific covariate-perturbation combination. Args: covariate_value: Value of the covariate (e.g., donor ID) perturbation: Perturbation identifier - gene_order: Ordered list of gene names to align weights to. Returns: - Array of weights aligned with gene_order + Array of weights aligned with adata.var_names """ cov_pert_key = f"{covariate_value}_{perturbation}" @@ -263,10 +262,12 @@ def get_deg_weights(self, covariate_value: str, perturbation: str, gene_order: L weights = self.pert_normalized_abs_scores_vsrest_df[cov_pert_key] else: # Return zero weights if no DEG data available - return np.zeros(len(gene_order)) + weights = np.zeros(self.adata.n_vars) + return weights - # Reindex weights to match the target gene order - weights = weights.reindex(gene_order, fill_value=0.0) + # Filter the weights to only include the common var_names + if common_var_names is not None: + weights = weights[list(common_var_names)] return weights.values @@ -277,20 +278,24 @@ def get_deg_mask(self, covariate_value: str, perturbation: str, gene_order: List Args: covariate_value: Value of the covariate (e.g., donor ID) perturbation: Perturbation identifier - gene_order: Ordered list of gene names to align the mask to. pval_threshold: P-value threshold for significance topn: Number of top DEGs to use. If None, all DEGs are used. Returns: - Boolean array indicating DEG positions, aligned to gene_order + Boolean array indicating DEG positions """ cov_pert_key = f"{covariate_value}_{perturbation}" if cov_pert_key not in self.deg_pvals_dict: - return np.zeros(len(gene_order), dtype=bool) + return np.zeros(self.adata.n_vars, dtype=bool) - # Get p-values and gene names (these are in DEG rank order, not var_names order) + # Get p-values and gene names pvals = self.deg_pvals_dict[cov_pert_key] gene_names = self.deg_names_dict[cov_pert_key] + + if common_var_names_mask is not None: + pvals = pvals[common_var_names_mask] + gene_names = gene_names[common_var_names_mask] + # Create boolean mask for significant genes sig_mask = pvals < pval_threshold @@ -315,8 +320,8 @@ def get_deg_mask(self, covariate_value: str, perturbation: str, gene_order: List pvals_aggregated = pvals_df.groupby('gene')['pval'].min() deg_mask_aggregated = pvals_aggregated < pval_threshold - # Reindex to match the target gene order - deg_mask = deg_mask_aggregated.reindex(gene_order, fill_value=False) + # Reindex to match adata.var_names + deg_mask = deg_mask_aggregated.reindex(self.adata.var_names[common_var_names_mask] if common_var_names_mask is not None else self.adata.var_names, fill_value=False) return deg_mask.values diff --git a/cellsimbench/core/metrics_engine.py b/cellsimbench/core/metrics_engine.py index 348f5b9..b16b1dc 100644 --- a/cellsimbench/core/metrics_engine.py +++ b/cellsimbench/core/metrics_engine.py @@ -81,15 +81,16 @@ def calculate_all_metrics( ) -> Dict[str, Dict[str, float]]: # Ensure predictions and ground truth have the same var_names - # Use sorted list to ensure deterministic, reproducible ordering - common_var_names = sorted(set(predictions.columns) & set(ground_truth.columns)) + common_var_names = set(predictions.columns) & set(ground_truth.columns) + common_var_names_mask = ground_truth.columns.isin(common_var_names) if not common_var_names: raise ValueError("Predictions and ground truth have different var_names") - predictions = predictions[common_var_names] - ground_truth = ground_truth[common_var_names] - predictions_deltas = {key: df[common_var_names] for key, df in predictions_deltas.items()} - ground_truth_deltas = {key: df[common_var_names] for key, df in ground_truth_deltas.items()} + predictions = predictions[list(common_var_names)] + ground_truth = ground_truth[list(common_var_names)] + predictions_deltas = {key: df[list(common_var_names)] for key, df in predictions_deltas.items()} + ground_truth_deltas = {key: df[list(common_var_names)] for key, df in ground_truth_deltas.items()} + # Get a mask of the common var_names found in the predictions to use for filtering # Calculate nir scores (needs full dataset) - only if enabled