Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Deep Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines on Well-Calibrated Metrics
# Deep Learning-Based Genetic Perturbation Models *Do* Outperform Uninformative Baselines on Well-Calibrated Metrics

**Preprint:** https://www.biorxiv.org/content/10.1101/2025.10.20.683304v1

Expand Down
1 change: 0 additions & 1 deletion cellsimbench/configs/modelgroup/simplebenchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

models:
- sclambda
- presage
- fmlp_esm2

description: "Compare some different perturbation response models"
14 changes: 6 additions & 8 deletions cellsimbench/core/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,14 +560,13 @@ def _add_delta_calculations(self, predictions: sc.AnnData,
control_pred = fold_baselines['control'] if 'control' in fold_baselines else universal_baselines['control']
delta_ctrl = np.zeros_like(predictions.X)
# Find intersection of control and predictions var_names
# Use sorted list to ensure deterministic, reproducible ordering
common_var_names = sorted(set(control_pred.var_names) & set(predictions.var_names))
common_var_names = set(control_pred.var_names) & set(predictions.var_names)
if not common_var_names:
raise ValueError("No common variable names found between control and predictions")

# Filter control and predictions to only include common var_names
control_pred = control_pred[:, common_var_names]
predictions = predictions[:, common_var_names]
control_pred = control_pred[:, list(common_var_names)]
predictions = predictions[:, list(common_var_names)]

for i, cov in enumerate(tqdm(predictions.obs['covariate'], desc="Adding delta from control")):
# Find matching control - MUST exist
Expand All @@ -586,14 +585,13 @@ def _add_delta_calculations(self, predictions: sc.AnnData,
dataset_mean_pred = fold_baselines['dataset_mean']
delta_mean = np.zeros_like(predictions.X)
# Find intersection of dataset mean and predictions var_names
# Use sorted list to ensure deterministic, reproducible ordering
common_var_names = sorted(set(dataset_mean_pred.var_names) & set(predictions.var_names))
common_var_names = set(dataset_mean_pred.var_names) & set(predictions.var_names)
if not common_var_names:
raise ValueError("No common variable names found between dataset mean and predictions")

# Filter dataset mean and predictions to only include common var_names
dataset_mean_pred = dataset_mean_pred[:, common_var_names]
predictions = predictions[:, common_var_names]
dataset_mean_pred = dataset_mean_pred[:, list(common_var_names)]
predictions = predictions[:, list(common_var_names)]

for i, cov in enumerate(tqdm(predictions.obs['covariate'], desc="Adding delta from dataset mean")):
# Find matching dataset mean - MUST exist
Expand Down
29 changes: 17 additions & 12 deletions cellsimbench/core/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,28 +245,29 @@ def get_available_splits(self) -> List[str]:
split_columns = [col for col in self.adata.obs.columns if 'split' in col.lower()]
return split_columns

def get_deg_weights(self, covariate_value: str, perturbation: str, gene_order: List[str]) -> np.ndarray:
def get_deg_weights(self, covariate_value: str, perturbation: str, common_var_names: np.ndarray=None) -> np.ndarray:
"""
Get DEG-based weights for a specific covariate-perturbation combination.

Args:
covariate_value: Value of the covariate (e.g., donor ID)
perturbation: Perturbation identifier
gene_order: Ordered list of gene names to align weights to.

Returns:
Array of weights aligned with gene_order
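Array of weights aligned with adata.var_names; when common_var_names is
given, only those genes are kept, in the order supplied. Note that the
no-DEG-data fallback returns zeros of length adata.n_vars regardless.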
"""
cov_pert_key = f"{covariate_value}_{perturbation}"

if cov_pert_key in self.pert_normalized_abs_scores_vsrest:
weights = self.pert_normalized_abs_scores_vsrest_df[cov_pert_key]
else:
# Return zero weights if no DEG data available
return np.zeros(len(gene_order))
weights = np.zeros(self.adata.n_vars)
return weights

# Reindex weights to match the target gene order
weights = weights.reindex(gene_order, fill_value=0.0)
# Filter the weights to only include the common var_names
if common_var_names is not None:
weights = weights[list(common_var_names)]

return weights.values

Expand All @@ -277,20 +278,24 @@ def get_deg_mask(self, covariate_value: str, perturbation: str, gene_order: List
Args:
covariate_value: Value of the covariate (e.g., donor ID)
perturbation: Perturbation identifier
gene_order: Ordered list of gene names to align the mask to.
pval_threshold: P-value threshold for significance
topn: Number of top DEGs to use. If None, all DEGs are used.
Returns:
Boolean array indicating DEG positions, aligned to gene_order
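Boolean array indicating DEG positions, aligned to adata.var_names
(or to adata.var_names[common_var_names_mask] when a mask is given).
Note that the missing-key early return has length adata.n_vars regardless.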
"""
cov_pert_key = f"{covariate_value}_{perturbation}"

if cov_pert_key not in self.deg_pvals_dict:
return np.zeros(len(gene_order), dtype=bool)
return np.zeros(self.adata.n_vars, dtype=bool)

# Get p-values and gene names (these are in DEG rank order, not var_names order)
# Get p-values and gene names
pvals = self.deg_pvals_dict[cov_pert_key]
gene_names = self.deg_names_dict[cov_pert_key]

if common_var_names_mask is not None:
pvals = pvals[common_var_names_mask]
gene_names = gene_names[common_var_names_mask]


# Create boolean mask for significant genes
sig_mask = pvals < pval_threshold
Expand All @@ -315,8 +320,8 @@ def get_deg_mask(self, covariate_value: str, perturbation: str, gene_order: List
pvals_aggregated = pvals_df.groupby('gene')['pval'].min()
deg_mask_aggregated = pvals_aggregated < pval_threshold

# Reindex to match the target gene order
deg_mask = deg_mask_aggregated.reindex(gene_order, fill_value=False)
# Reindex to match adata.var_names
deg_mask = deg_mask_aggregated.reindex(self.adata.var_names[common_var_names_mask] if common_var_names_mask is not None else self.adata.var_names, fill_value=False)

return deg_mask.values

Expand Down
13 changes: 7 additions & 6 deletions cellsimbench/core/metrics_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,16 @@ def calculate_all_metrics(
) -> Dict[str, Dict[str, float]]:

# Ensure predictions and ground truth have the same var_names
# Use sorted list to ensure deterministic, reproducible ordering
common_var_names = sorted(set(predictions.columns) & set(ground_truth.columns))
common_var_names = set(predictions.columns) & set(ground_truth.columns)
common_var_names_mask = ground_truth.columns.isin(common_var_names)

if not common_var_names:
raise ValueError("Predictions and ground truth have different var_names")
predictions = predictions[common_var_names]
ground_truth = ground_truth[common_var_names]
predictions_deltas = {key: df[common_var_names] for key, df in predictions_deltas.items()}
ground_truth_deltas = {key: df[common_var_names] for key, df in ground_truth_deltas.items()}
predictions = predictions[list(common_var_names)]
ground_truth = ground_truth[list(common_var_names)]
predictions_deltas = {key: df[list(common_var_names)] for key, df in predictions_deltas.items()}
ground_truth_deltas = {key: df[list(common_var_names)] for key, df in ground_truth_deltas.items()}
# common_var_names_mask (computed above) flags which ground_truth columns are in the common set; use it for filtering


# Calculate nir scores (needs full dataset) - only if enabled
Expand Down
Loading