diff --git a/.vscode/settings.json b/.vscode/settings.json index 9b38853..c263d5d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,6 @@ "tests" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "python.terminal.activateEnvironment": false } \ No newline at end of file diff --git a/README.md b/README.md index 0743819..2205029 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ These are the required ANNOVAR components for ContextScore: ## User Workflow ```bash -contextscore --input input.vcf --output scored.vcf --sample-coverage 30 --buildver {hg38,hg19} --threshold 0.2 \ +contextscore --input input.vcf --output scored.vcf --sample-coverage 30 --buildver {hg38,hg19} \ --annovar /path/to/annovar --annovar-db /path/to/humandb ``` diff --git a/conda/meta.yaml b/conda/meta.yaml index 88d049d..c84c3fe 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -25,7 +25,7 @@ requirements: - scikit-learn =1.6.1 # For consistency with model training environment - joblib - bedtools - - contextscore-models + - contextscore-models ==1.0.0 about: home: https://github.com/WGLab/ContextScore diff --git a/contextscore/extract_features.py b/contextscore/extract_features.py index 1e090ee..42b74bc 100644 --- a/contextscore/extract_features.py +++ b/contextscore/extract_features.py @@ -133,6 +133,7 @@ def normalize_chrom_label(chrom): return chrom_str.upper() def extract_features(input_bed, annovar_path, db_path, outdiranno, buildversion='hg38', sample_coverage=None): + # ...existing code... """Extract the features from the BED file, columns are in the first row: chrom, start, end, sv_type, sv_length, genotype, read_depth, hmm_llh, aln_type, cluster_size @@ -170,6 +171,12 @@ def extract_features(input_bed, annovar_path, db_path, outdiranno, buildversion= # Normalize SV length to a positive magnitude. bed_df['sv_length'] = bed_df['sv_length'].abs() + # Add SV length interval features (4 non-type-specific intervals, one-hot encoded) + svlen = bed_df['sv_length'] + bed_df['svlen_50_500'] = ((svlen >= 50) & (svlen < 500)).astype(int) + bed_df['svlen_500_5000'] = ((svlen >= 500) & (svlen < 5000)).astype(int) + bed_df['svlen_5000_50000'] = ((svlen >= 5000) & (svlen < 50000)).astype(int) + bed_df['svlen_50000_plus'] = (svlen >= 50000).astype(int) # Drop the genotype column and cn_state columns (due to redundancy). bed_df.drop(columns=['genotype', 'cn_state'], inplace=True) @@ -240,62 +247,78 @@ def extract_features(input_bed, annovar_path, db_path, outdiranno, buildversion= # Drop the cn_state column from the data. bed_df = bed_df.drop(columns=['cn_state'], errors='ignore') - # Add distance to nearest other SV call, clustered false positives often appear near real SVs. - logging.info('Computing distance to nearest other SV call (same chromosome)...') - bed_df['dist_to_nearest_sv'] = np.nan - for chrom, idx in bed_df.groupby('chrom', sort=False).groups.items(): - chrom_df = bed_df.loc[idx, ['start', 'end']].sort_values(['start', 'end']) - n = chrom_df.shape[0] - - if n <= 1: + # Compute robust local-neighborhood SV features by chromosome + SV type. + # Nearest distance excludes overlaps and is stabilized with log and length-relative transforms. + logging.info('Computing robust nearest-SV and local SV-density features (same SV type)...') + bed_df['dist_nearest_nonoverlap_same_type'] = np.nan + bed_df['dist_nearest_nonoverlap_same_type_log1p'] = np.nan + bed_df['dist_nearest_nonoverlap_same_type_rel_log1p'] = np.nan + bed_df['local_same_type_count_1kb'] = 0 + bed_df['local_same_type_count_10kb'] = 0 + bed_df['local_same_type_count_100kb'] = 0 + + for (_, _), idx in bed_df.groupby(['chrom', 'sv_type_str'], sort=False).groups.items(): + idx_list = list(idx) + group_df = bed_df.loc[idx_list, ['start', 'end', 'sv_length']].copy() + n = group_df.shape[0] + + if n == 0: continue - starts = chrom_df['start'].to_numpy(dtype=np.int64) - ends = chrom_df['end'].to_numpy(dtype=np.int64) - - # Previous interval summary. - prev_max_end = np.maximum.accumulate(ends) - prev_max_end_excl = np.empty(n, dtype=np.int64) - prev_max_end_excl[0] = np.iinfo(np.int64).min - prev_max_end_excl[1:] = prev_max_end[:-1] - - # Next interval summary. - next_start_excl = np.empty(n, dtype=np.int64) - next_start_excl[:-1] = starts[1:] - next_start_excl[-1] = np.iinfo(np.int64).max - - # Overlap checks with prior/next intervals. - overlap_prev = prev_max_end_excl > starts - overlap_next = ends > next_start_excl - overlap_any = overlap_prev | overlap_next - - # Gap to closest left/right neighbor (touching intervals yield 0). - left_gap = starts - prev_max_end_excl - right_gap = next_start_excl - ends - - # No-left/no-right sentinels. - left_gap[0] = np.iinfo(np.int64).max - right_gap[-1] = np.iinfo(np.int64).max - - nearest = np.minimum(left_gap, right_gap).astype(np.float64) - nearest[overlap_any] = 0.0 - - # Any remaining sentinel values are undefined (should only happen in degenerate cases). - sentinel = float(np.iinfo(np.int64).max) - nearest[nearest >= sentinel] = np.nan - - bed_df.loc[chrom_df.index, 'dist_to_nearest_sv'] = nearest - - logging.info('Distance to nearest SV calculated. Coverage: %.1f%%', (bed_df['dist_to_nearest_sv'].notna().sum() / len(bed_df) * 100)) - - # Print statistics about the distance to nearest SV feature. - logging.info('Distance to nearest SV - mean: %.2f, median: %.2f, std: %.2f', bed_df['dist_to_nearest_sv'].mean(), bed_df['dist_to_nearest_sv'].median(), bed_df['dist_to_nearest_sv'].std()) - - # Normalize by SV size - bed_df['dist_nearest_sv_per_kb'] = np.where( - bed_df['sv_length'] > 0, - bed_df['dist_to_nearest_sv'] / (bed_df['sv_length'] / 1000.0), - bed_df['dist_to_nearest_sv'] + starts = pd.to_numeric(group_df['start'], errors='coerce').to_numpy(dtype=np.int64) + ends = pd.to_numeric(group_df['end'], errors='coerce').to_numpy(dtype=np.int64) + lengths = pd.to_numeric(group_df['sv_length'], errors='coerce').to_numpy(dtype=np.float64) + + # Ensure interval ordering is valid even for malformed coordinates. + left = np.minimum(starts, ends) + right = np.maximum(starts, ends) + + nearest = np.full(n, np.nan, dtype=np.float64) + if n > 1: + sorted_ends = np.sort(right) + sorted_starts = np.sort(left) + + left_pos = np.searchsorted(sorted_ends, left, side='right') - 1 + right_pos = np.searchsorted(sorted_starts, right, side='left') + + left_gap = np.full(n, np.inf, dtype=np.float64) + valid_left = left_pos >= 0 + left_gap[valid_left] = (left[valid_left] - sorted_ends[left_pos[valid_left]]).astype(np.float64) + + right_gap = np.full(n, np.inf, dtype=np.float64) + valid_right = right_pos < n + right_gap[valid_right] = (sorted_starts[right_pos[valid_right]] - right[valid_right]).astype(np.float64) + + nearest = np.minimum(left_gap, right_gap) + nearest[np.isinf(nearest)] = np.nan + + # Length-relative normalization keeps comparability across 100bp to 100kb+ SVs. + length_scale = np.maximum(np.abs(lengths), 100.0) + nearest_rel = nearest / length_scale + + bed_df.loc[idx_list, 'dist_nearest_nonoverlap_same_type'] = nearest + bed_df.loc[idx_list, 'dist_nearest_nonoverlap_same_type_log1p'] = np.log1p(nearest) + bed_df.loc[idx_list, 'dist_nearest_nonoverlap_same_type_rel_log1p'] = np.log1p(nearest_rel) + + # Local center-based same-type SV density counts. + centers = ((left + right) // 2).astype(np.int64) + sorted_centers = np.sort(centers) + for window_bp, col_name in [ + (1000, 'local_same_type_count_1kb'), + (10000, 'local_same_type_count_10kb'), + (100000, 'local_same_type_count_100kb'), + ]: + lo = np.searchsorted(sorted_centers, centers - window_bp, side='left') + hi = np.searchsorted(sorted_centers, centers + window_bp, side='right') + counts = (hi - lo - 1).astype(np.int32) # Exclude self. + bed_df.loc[idx_list, col_name] = counts + + coverage_pct = bed_df['dist_nearest_nonoverlap_same_type'].notna().mean() * 100 + logging.info( + 'Nearest same-type non-overlap distance computed. Coverage: %.2f%%, mean(log1p): %.3f, mean(rel_log1p): %.3f', + coverage_pct, + bed_df['dist_nearest_nonoverlap_same_type_log1p'].mean(skipna=True), + bed_df['dist_nearest_nonoverlap_same_type_rel_log1p'].mean(skipna=True), ) # Return the features dataframe. diff --git a/contextscore/predict.py b/contextscore/predict.py index 4b63591..6b79dfa 100644 --- a/contextscore/predict.py +++ b/contextscore/predict.py @@ -21,13 +21,13 @@ import numpy as np import joblib import pandas as pd +from sklearn.mixture import GaussianMixture try: from .extract_features import extract_features except ImportError: from extract_features import extract_features - USER_PREFIX = "[ContextScore]" DEFAULT_MODEL_ENV_VAR = 'CONTEXTSCORE_MODEL_PATH' DEFAULT_MODEL_INSTALL_PATH = os.path.join( @@ -47,8 +47,12 @@ def user_message(message): def configure_logging(verbose=False, debug=False): """Configure logging output level based on user-selected mode.""" level = logging.DEBUG if debug else (logging.INFO if verbose else logging.WARNING) - logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - %(message)s') - + logging.basicConfig( + level=level, + format='%(asctime)s - %(levelname)s - %(message)s', + stream=sys.stdout, + force=True, + ) def resolve_annovar_paths(annovar_path, annovar_db_path): """Resolve ANNOVAR paths from CLI flags or environment variables.""" @@ -237,36 +241,117 @@ def _extract_sample_field(row, field_name): logging.info('Created BED file: %s', output_bed) return skipped_chrom_ids -def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, - threshold_del=None, threshold_dup=None, threshold_ins=None, threshold_inv=None, - sample_coverage=None, large_cutoff=10000, annovar_path=None, annovar_db_path=None, - debug_plot=False): +def add_confidence_to_info(line, confidence_score): + fields = line.rstrip('\n').split('\t') + fields[7] += f';CONFSCORE={confidence_score:.4f}' + return '\t'.join(fields) + '\n' + +def gmm_threshold(scores, fallback=0.2, max_threshold=0.5, min_samples=20): + """Fit a robust 2-component GMM threshold, with safeguards for unimodal distributions.""" + fallback = float(np.clip(fallback, 0.0, max_threshold)) + scores = np.asarray(scores, dtype=np.float64) + scores = scores[np.isfinite(scores)] + scores = np.clip(scores, 0.0, 1.0) + + if scores.size < min_samples: + logging.info('GMM threshold skipped: only %d samples (min=%d). Using fallback %.4f', scores.size, min_samples, fallback) + return fallback + + try: + x = scores.reshape(-1, 1) + gmm_1 = GaussianMixture(n_components=1, random_state=42) + gmm_2 = GaussianMixture(n_components=2, random_state=42) + gmm_1.fit(x) + gmm_2.fit(x) + + bic_1 = float(gmm_1.bic(x)) + bic_2 = float(gmm_2.bic(x)) + bic_gain = bic_1 - bic_2 + + # If 2-component fit is not clearly better, treat as unimodal. + if bic_gain < 10.0: + logging.info( + 'GMM threshold fallback: weak 2-component evidence (BIC gain %.2f). Using fallback %.4f', + bic_gain, + fallback, + ) + return fallback + + means = gmm_2.means_.flatten() + weights = gmm_2.weights_.flatten() + stds = np.sqrt(gmm_2.covariances_.flatten()) + low_idx, high_idx = np.argsort(means) + + mean_gap = float(means[high_idx] - means[low_idx]) + separation = mean_gap / float(stds[high_idx] + stds[low_idx] + 1e-9) + + # Guard against pseudo-bimodal fits where one component just captures a tail. + if mean_gap < 0.08 or separation < 1.0 or float(weights.min()) < 0.10: + logging.info( + 'GMM threshold fallback: weak separation (gap=%.3f, sep=%.3f, min_weight=%.3f). Using fallback %.4f', + mean_gap, + separation, + float(weights.min()), + fallback, + ) + return fallback + + xs = np.linspace(means[low_idx], means[high_idx], 800).reshape(-1, 1) + posteriors = gmm_2.predict_proba(xs) + cross = np.where(posteriors[:, high_idx] >= 0.5)[0] + if len(cross) == 0: + logging.info('GMM threshold fallback: no posterior crossing found. Using fallback %.4f', fallback) + return fallback + + threshold = float(xs[cross[0]]) + + # Quantile guardrails prevent over-aggressive cutoffs when one high peak dominates. + # Hard cap at max_threshold as requested by filtering policy. + low_guard = min(float(np.quantile(scores, 0.05)), max_threshold) + high_guard = min(float(np.quantile(scores, 0.85)), max_threshold) + if high_guard < low_guard: + low_guard = high_guard + + threshold_clipped = float(np.clip(threshold, low_guard, high_guard)) + if threshold_clipped != threshold: + logging.info( + 'GMM threshold clipped from %.4f to %.4f using guards [%.4f, %.4f].', + threshold, + threshold_clipped, + low_guard, + high_guard, + ) + + return threshold_clipped + except Exception as exc: + logging.warning('GMM threshold fallback after error: %s. Using fallback %.4f', str(exc), fallback) + return fallback + +def score(model, input_vcf, output_vcf, buildver='hg38', + sample_coverage=None, annovar_path=None, annovar_db_path=None, + debug_plot=False, sample_name=None): """Score the structural variants using the binary classification model. Args: model (str): Path to the model file. input_vcf (str): Path to the input VCF file. output_vcf (str): Path to the output VCF file. - threshold (float): Default threshold for SV types not specified. - threshold_del (float): Optional. Threshold for DEL variants. If None, uses default threshold. - threshold_dup (float): Optional. Threshold for DUP variants. If None, uses default threshold. - threshold_ins (float): Optional. Threshold for INS variants. If None, uses default threshold. - threshold_inv (float): Optional. Threshold for INV variants. If None, uses default threshold. sample_coverage (float): Required. Mean read depth coverage for the sample. - large_cutoff (int): SV size cutoff in bp; variants larger than this are always kept (default: 50000). + sample_name (str): Optional. Name shown in debug probability plot title. """ - # Build threshold dictionary with type-specific values + # Threshold policy: per-type GMM when valid, otherwise fallback values. + + gmm_fallback_threshold = 0.2 + max_threshold = 0.3 threshold_by_type = { - 'DEL': threshold_del if threshold_del is not None else threshold, - 'DUP': threshold_dup if threshold_dup is not None else threshold, - 'INS': threshold_ins if threshold_ins is not None else threshold, - 'INV': threshold_inv if threshold_inv is not None else threshold, + 'DEL': gmm_fallback_threshold, + 'DUP': gmm_fallback_threshold, + 'INS': gmm_fallback_threshold, + 'INV': gmm_fallback_threshold, } - - prob_threshold = threshold - logging.info('Using confidence threshold policy:') - for svtype, thr in sorted(threshold_by_type.items()): - logging.info(' %s: %.3f', svtype, thr) + + prob_threshold = gmm_fallback_threshold + logging.info('Using confidence threshold policy: GMM per SV type, fallback=%.4f, max=%.4f', gmm_fallback_threshold, max_threshold) output_dir = os.path.dirname(os.path.abspath(output_vcf)) or '.' os.makedirs(output_dir, exist_ok=True) @@ -304,11 +389,16 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, 'end': pd.to_numeric(feature_df['end'], errors='coerce').astype('Int64').values if 'end' in feature_df.columns else pd.Series([pd.NA] * len(id_col), dtype='Int64').values, 'sv_type_str': feature_df['sv_type_str'].astype(str).values if 'sv_type_str' in feature_df.columns else np.nan, 'sv_length': pd.to_numeric(feature_df['sv_length'], errors='coerce').astype('Int64').values if 'sv_length' in feature_df.columns else pd.Series([pd.NA] * len(id_col), dtype='Int64').values, + # Only the 4 non-type-specific interval features + 'svlen_50_500': feature_df['svlen_50_500'].values if 'svlen_50_500' in feature_df.columns else np.nan, + 'svlen_500_5000': feature_df['svlen_500_5000'].values if 'svlen_500_5000' in feature_df.columns else np.nan, + 'svlen_5000_50000': feature_df['svlen_5000_50000'].values if 'svlen_5000_50000' in feature_df.columns else np.nan, + 'svlen_50000_plus': feature_df['svlen_50000_plus'].values if 'svlen_50000_plus' in feature_df.columns else np.nan, }) predictions_meta['sv_length_abs'] = predictions_meta['sv_length'].abs() - # Remove other non-feature columns before prediction. - # Keep normalized *_per_kb features; remove raw versions. + # Remove non-feature columns before prediction. + # Keep normalized *_per_kb features and keep raw sv_length for length-aware models. for col in ['chrom', 'start', 'end', 'sv_type_str', 'cluster_size', 'dist_to_nearest_sv', 'read_depth']: if col in feature_df.columns: feature_df.pop(col) @@ -342,16 +432,41 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, predictions_df.to_csv(predictions_tsv, sep='\t', index=False) logging.info('Saved per-variant predictions to %s', predictions_tsv) + # --- Adaptive GMM thresholds (per SV type) --- + logging.info('Fitting per-SV-type GMM thresholds with safeguards (fallback/max=%.4f/%.4f)...', gmm_fallback_threshold, max_threshold) + gmm_thresholds = {} + for svtype in ['DEL', 'DUP', 'INS', 'INV']: + mask = predictions_df['sv_type_str'] == svtype + scores = predictions_df.loc[mask, 'confidence_score'].values + if len(scores) >= 20: + t = gmm_threshold(scores, fallback=gmm_fallback_threshold, max_threshold=max_threshold, min_samples=20) + logging.info(' %s threshold from GMM: %.4f (n=%d, capped <= %s)', svtype, t, len(scores), max_threshold) + else: + t = gmm_fallback_threshold + logging.info(' Too few %s variants (%d), using fallback %.4f', svtype, len(scores), t) + gmm_thresholds[svtype] = t + threshold_by_type = gmm_thresholds # override static thresholds + + logging.info('Final thresholds after GMM fitting:') + for svtype, thr in sorted(threshold_by_type.items()): + logging.info(' %s: %.4f', svtype, thr) + if debug_plot: plt, sns = try_import_plotting_libs() if plt is None or sns is None: logging.warning('Debug plotting requested but matplotlib/seaborn are not installed. Skipping plot generation.') else: + dataset_name = sample_name if sample_name else os.path.basename(input_vcf) + if dataset_name.endswith('.vcf.gz'): + dataset_name = dataset_name[:-7] + elif dataset_name.endswith('.vcf'): + dataset_name = dataset_name[:-4] + _, ax = plt.subplots() sns.histplot(y_pred[:, 1], bins=20, ax=ax) ax.set_xlabel('Confidence Score') ax.set_ylabel('Count') - ax.set_title('Probability Distribution') + ax.set_title(f'{dataset_name} Probability Distribution') plot_path = os.path.join(output_dir, 'prob_dist.svg') plt.savefig(plot_path) plt.close() @@ -364,15 +479,15 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, logging.info('Built variant lookup with %d entries for type-specific filtering', len(variant_lookup)) - # For backward compatibility, also track variants below the default threshold + # Track variants below the fallback threshold for logging/debugging. filtered_indices = np.where(y_pred[:, 1] < prob_threshold)[0] - logging.info('Number of variants under the default probability threshold %.2f: %d', prob_threshold, len(filtered_indices)) + logging.info('Number of variants under the fallback probability threshold %.2f: %d', prob_threshold, len(filtered_indices)) # Get the IDs of the filtered variants (for logging/debugging) filtered_ids = id_col.iloc[filtered_indices].values filtered_ids_file = os.path.join(output_dir, 'filtered_ids.txt') np.savetxt(filtered_ids_file, filtered_ids, fmt='%s') - logging.info('Saved the filtered IDs (using default threshold) to %s', filtered_ids_file) + logging.info('Saved filtered IDs (using fallback threshold) to %s', filtered_ids_file) # Create a VCF file with only the filtered variants removed_svs_vcf = os.path.join(output_dir, 'removed_svs.vcf') @@ -391,6 +506,10 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, with open_vcf_text(input_vcf) as vcf_in, open(output_vcf, 'w', encoding='utf-8') as vcf_out, open(removed_svs_vcf, 'w', encoding='utf-8') as removed_out: for line in vcf_in: if line.startswith('#'): + # Add CONFSCORE INFO header before the #CHROM line + if line.startswith('#CHROM'): + vcf_out.write('##INFO=\n') + # Write the header lines as they are vcf_out.write(line) removed_out.write(line) @@ -418,7 +537,7 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, type_filter_stats[svtype_for_stats] = {'total': 0, 'kept': 0, 'filtered': 0} type_filter_stats[svtype_for_stats]['total'] += 1 type_filter_stats[svtype_for_stats]['kept'] += 1 - vcf_out.write(line) + vcf_out.write(add_confidence_to_info(line, -1.0)) pass_count += 1 total_records += 1 current_record += 1 @@ -431,19 +550,20 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, svtype = svtype_match if svtype_match else predicted_svtype else: # Variant not in predictions (shouldn't happen, but handle gracefully) - logging.warning('Variant %d not found in predictions lookup, using default threshold', current_record) + logging.warning('Variant %d not found in predictions lookup, using fallback threshold policy', current_record) confidence_score = 0.0 svtype = svtype_match if svtype_match else 'UNKNOWN' # Get the appropriate threshold for this SV type type_threshold = threshold_by_type.get(svtype, prob_threshold) - # Determine if variant should be kept - is_large_sv = svlen_match is not None and abs(svlen_match) > large_cutoff - passes_threshold = confidence_score >= type_threshold - - # Keep if: (large SV) OR (passes type-specific threshold) - should_keep = is_large_sv or passes_threshold + # Relax threshold for larger SVs + abs_svlen = abs(svlen_match) + if abs_svlen is not None and abs_svlen > 10000: + type_threshold = 0.1 * type_threshold + + # Keep if larger than threshold or >100kb and not deletion + should_keep = confidence_score >= type_threshold or (abs_svlen is not None and abs_svlen > 100000) # Track statistics by type if svtype not in type_filter_stats: @@ -451,12 +571,12 @@ def score(model, input_vcf, output_vcf, buildver='hg38', threshold=0.05, type_filter_stats[svtype]['total'] += 1 if should_keep: - vcf_out.write(line) + vcf_out.write(add_confidence_to_info(line, confidence_score)) pass_count += 1 type_filter_stats[svtype]['kept'] += 1 else: # Write the line to the removed_svs.vcf file if filtered - removed_out.write(line) + removed_out.write(add_confidence_to_info(line, confidence_score)) filter_count += 1 type_filter_stats[svtype]['filtered'] += 1 @@ -494,20 +614,10 @@ def main(argv=None): help='Path to the model file. Optional if CONTEXTSCORE_MODEL_PATH is set or default packaged model is installed.') parser.add_argument('--buildver', type=str, default='hg38', help='Genome build version (default: hg38).') - parser.add_argument('--threshold', type=float, default=0.2, - help='Default threshold for filtering predictions (default: 0.2). Used for SV types without specific thresholds.') - parser.add_argument('--threshold-del', type=float, default=None, - help='Threshold for DEL variants (default: uses --threshold value).') - parser.add_argument('--threshold-dup', type=float, default=None, - help='Threshold for DUP variants (default: uses --threshold value).') - parser.add_argument('--threshold-ins', type=float, default=None, - help='Threshold for INS variants (default: uses --threshold value).') - parser.add_argument('--threshold-inv', type=float, default=None, - help='Threshold for INV variants (default: uses --threshold value).') parser.add_argument('--sample-coverage', type=float, required=True, help='Mean read depth coverage for the sample (required, used to normalize read_depth).') - parser.add_argument('--large-cutoff', type=int, default=10000, - help='SV size cutoff in bp; variants larger than this are always kept (default: 50000).') + parser.add_argument('--sample-name', type=str, default=None, + help='Optional sample/dataset name used in debug probability plot title.') parser.add_argument('--annovar', type=str, default=None, help='Path to ANNOVAR installation directory. Can also be set via ANNOVAR_PATH.') parser.add_argument('--annovar-db', type=str, default=None, @@ -581,12 +691,11 @@ def main(argv=None): # Run the scoring function summary = score(model, input_vcf, output_vcf, buildver=buildver, - threshold=args.threshold, sample_coverage=args.sample_coverage, - threshold_del=args.threshold_del, threshold_dup=args.threshold_dup, - threshold_ins=args.threshold_ins, threshold_inv=args.threshold_inv, - large_cutoff=args.large_cutoff, annovar_path=annovar_path, + sample_coverage=args.sample_coverage, + annovar_path=annovar_path, annovar_db_path=annovar_db_path, - debug_plot=args.debug_plot) + debug_plot=args.debug_plot, + sample_name=args.sample_name) user_message( f"Completed. Kept {summary['passed_records']}/{summary['total_records']} variants; filtered {summary['filtered_records']}." diff --git a/contextscore/train_full_model.py b/contextscore/train_full_model.py index d1f2d64..0db5b5a 100644 --- a/contextscore/train_full_model.py +++ b/contextscore/train_full_model.py @@ -28,6 +28,12 @@ # Manuscript-friendly display labels for model features. FEATURE_DISPLAY_NAMES = { 'dist_nearest_sv_per_kb': 'Nearest SV distance / kb', + 'dist_nearest_nonoverlap_same_type': 'Nearest non-overlap same-type distance (bp)', + 'dist_nearest_nonoverlap_same_type_log1p': 'Nearest non-overlap same-type distance log1p', + 'dist_nearest_nonoverlap_same_type_rel_log1p': 'Nearest non-overlap same-type distance rel-log1p', + 'local_same_type_count_1kb': 'Same-type local SV count (1kb)', + 'local_same_type_count_10kb': 'Same-type local SV count (10kb)', + 'local_same_type_count_100kb': 'Same-type local SV count (100kb)', 'cluster_size_per_kb': 'Cluster size / kb', 'sv_length': 'SV length (bp)', 'read_depth_normalized': 'Normalized depth', @@ -43,7 +49,11 @@ 'fragile_site': 'Fragile-site overlap', 'phastCons': 'phastCons score', 'hmm_llh': 'HMM log-likelihood', - 'aln_offset': 'Alignment offset' + 'aln_offset': 'Alignment offset', + 'svlen_50_500': 'SV length 50-500bp', + 'svlen_500_5000': 'SV length 500-5,000bp', + 'svlen_5000_50000': 'SV length 5,000-50,000bp', + 'svlen_50000_plus': 'SV length ≥50,000bp', } ENABLE_SHAP = False @@ -85,6 +95,39 @@ def get_cv_splits(y, max_splits=5): return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) + +def assign_sv_length_bin(sv_length_series): + """Assign SV lengths to manuscript-aligned bins used for weighting and reporting.""" + length_labels = ['50-500bp', '500-5kb', '5-50kb', '>=50kb'] + sv_length_numeric = pd.to_numeric(sv_length_series, errors='coerce').fillna(50).clip(lower=50) + return pd.cut( + sv_length_numeric, + bins=[50, 500, 5000, 50000, float('inf')], + labels=length_labels, + include_lowest=True + ) + + +def compute_length_aware_sample_weights(feature_df, labels): + """Compute sample weights using SV-length rarity only.""" + if 'sv_length' not in feature_df.columns: + logging.warning('sv_length not available; using uniform sample weights.') + return np.ones(len(labels), dtype='float64') + + length_bins = assign_sv_length_bin(feature_df['sv_length']).astype(str).reset_index(drop=True) + + # Inverse-frequency length-bin weights to upweight sparse large-SV bins. + bin_counts = length_bins.value_counts() + bin_weights = length_bins.map(lambda b: len(length_bins) / (len(bin_counts) * bin_counts[b])) + + weights = bin_weights + weights = weights / weights.mean() + + logging.info('SV length-bin counts in training set: %s', bin_counts.to_dict()) + logging.info('Computed length-aware sample weights (min=%.3f, max=%.3f, mean=%.3f)', + float(weights.min()), float(weights.max()), float(weights.mean())) + return weights.to_numpy(dtype='float64') + def balance_tp_fp_datasets(tp_data, fp_data): """Balance the true positive and false positive datasets by undersampling the lower-count class.""" tp_count = tp_data.shape[0] @@ -169,9 +212,18 @@ def stratified_undersample_fp(fp_data, target_count, random_state=42): # Create length bins for stratification fp_data_temp = fp_data.copy() - fp_data_temp['length_bin'] = pd.cut(fp_data_temp['sv_length'], - bins=[0, 1000, 10000, 100000, float('inf')], - labels=['<1kb', '1-10kb', '10-100kb', '>100kb']) + # fp_data_temp['length_bin'] = pd.cut(fp_data_temp['sv_length'], + # bins=[0, 1000, 10000, 100000, float('inf')], + # labels=['<1kb', '1-10kb', '10-100kb', '>100kb']) + + # Use the same four bins used in training/evaluation (all SVs are >=50bp). + sv_length_numeric = pd.to_numeric(fp_data_temp['sv_length'], errors='coerce').fillna(50).clip(lower=50) + fp_data_temp['length_bin'] = pd.cut( + sv_length_numeric, + bins=[50, 500, 5000, 50000, float('inf')], + labels=['50-500bp', '500-5kb', '5-50kb', '>=50kb'], + include_lowest=True + ) # Create stratification column combining SV type and length bin fp_data_temp['stratum'] = fp_data_temp['sv_type'].astype(str) + '_' + fp_data_temp['length_bin'].astype(str) @@ -319,7 +371,7 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp chrom_col = data.pop('chrom') # Drop columns that are not needed for training. - # Keep normalized *_per_kb features; remove raw versions. + # Keep normalized *_per_kb features and keep raw sv_length for length-aware learning. data = data.drop(columns=['start', 'end', 'sv_type_str', 'cluster_size', 'dist_to_nearest_sv', 'read_depth'], errors='ignore') logging.info('Columns list after preprocessing: %s', data.columns.tolist()) @@ -421,6 +473,7 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp X_train_chrom_processed = preprocess_feature_matrix(X_train_chrom) X_test_chrom_processed = preprocess_feature_matrix(X_test_chrom) + sample_weights_chrom = compute_length_aware_sample_weights(X_train_chrom, y_train_chrom) fold_cv = get_cv_splits(y_train_chrom) if fold_cv is None: @@ -438,7 +491,7 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp scoring='precision', n_jobs=-1 ) - grid_search.fit(X_train_chrom_processed, y_train_chrom) + grid_search.fit(X_train_chrom_processed, y_train_chrom, classifier__sample_weight=sample_weights_chrom) best_model = grid_search.best_estimator_ logging.info( 'Best hyperparameters for %s on held-out chromosome %s: %s', @@ -484,23 +537,29 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp ) from exc metrics = ['F1 Score', 'Precision', 'Recall'] for model_name in pipelines.keys(): + model_name_label = model_name.replace("_", " ") # Save a plot with F1, Precision, and Recall scores for chrY if 'chrY' in chromosomes: - logging.info('Plotting scores for %s model on chrY.', model_name) + logging.info('Plotting scores for %s model on chrY.', model_name_label) # Create a bar plot for the F1 scores by chromosome. chry_f1 = f1_scores.get((model_name, 'chrY'), 0) chry_precision = precision_scores.get((model_name, 'chrY'), 0) chry_recall = recall_scores.get((model_name, 'chrY'), 0) - plt.figure(figsize=(6, 4)) + plt.figure(figsize=(5, 3)) # Plot F1, Precision, and Recall scores for chrY. - sns.barplot(x=['F1 Score', 'Precision', 'Recall'], y=[chry_f1, chry_precision, chry_recall], color='black') + ax = sns.barplot( + x=['F1 Score', 'Precision', 'Recall'], + y=[chry_f1, chry_precision, chry_recall], + color='#1f77b4' + ) # plt.xlabel('Metric') plt.ylabel('Score') - plt.title('%s Scores for %s Model on chrY' % (model_name, model_name)) + ax.set_ylim(0, 1.0) + plt.title('%s Scores for %s Model on chrY' % (model_name_label, model_name_label)) plt.xticks(rotation=45) plt.tight_layout() # Save the plot to the output directory. @@ -510,15 +569,20 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp logging.info('Saved the scores plot for chrY to %s', score_plot_path) for metric, scores in zip(metrics, [f1_scores, precision_scores, recall_scores]): - logging.info('Plotting %s for %s model by chromosome.', metric, model_name) + logging.info('Plotting %s for %s model by chromosome.', metric, model_name_label) # Create a bar plot for the F1 scores by chromosome. model_scores = {chrom: scores[(model_name, chrom)] for chrom in chromosomes if (model_name, chrom) in scores} - plt.figure(figsize=(10, 6)) - ax = sns.barplot(x=list(model_scores.keys()), y=list(model_scores.values()), color='black') + plt.figure(figsize=(8, 4)) + ax = sns.barplot( + x=list(model_scores.keys()), + y=list(model_scores.values()), + color='#1f77b4' + ) plt.xlabel('Chromosome') plt.ylabel(metric) - plt.title('%s for %s Model by Chromosome' % (metric, model_name)) + ax.set_ylim(0, 1.0) + plt.title('%s for %s Model by Chromosome' % (metric, model_name_label)) plt.xticks(rotation=45) plt.tight_layout() score_plot_path = os.path.join(output_directory, model_name + '_%s_by_chromosome.svg' % metric.lower().replace(' ', '_')) @@ -535,7 +599,8 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp raise ValueError('Unable to run training: need at least two classes with at least two samples each for stratified CV.') for model_name, pipeline in pipelines.items(): - logging.info('Training model class %s', model_name) + model_name_label = model_name.replace("_", " ") + logging.info('Training model class %s', model_name_label) model_name_fp = "contextscore_" + model_name.lower() + "_leaveout_" + leave_out if split_80_20: @@ -544,9 +609,10 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp # Perform grid search to find the best hyperparameters for the model, optimizing for precision to prioritize reducing false positives. X_train_processed = preprocess_feature_matrix(X_train) X_test_processed = preprocess_feature_matrix(X_test) + sample_weights = compute_length_aware_sample_weights(X_train, y_train) grid_search = GridSearchCV(estimator=pipeline, param_grid=param_grids[model_name], cv=cv, scoring='precision', n_jobs=-1) - grid_search.fit(X_train_processed, y_train) + grid_search.fit(X_train_processed, y_train, classifier__sample_weight=sample_weights) logging.info('Best hyperparameters for %s: %s', model_name, grid_search.best_params_) # Get predicted probabilities for the training and testing sets. @@ -584,7 +650,6 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') - model_name_label = model_name.replace("_", " ") plt.title('{} Receiver Operating Characteristic (Training Set)'.format(model_name_label)) plt.legend(loc='lower right') roc_plot_path = os.path.join(output_directory, model_name_fp + '_roc_curve_train.svg') @@ -611,13 +676,13 @@ def train(tp_hg002_grch37, fp_hg002_grch37, tp_visor_grch38, fp_visor_grch38, tp # Save the model to the output directory as a pickle file. model_path = os.path.join(output_directory, model_name_fp + '_model.pkl') joblib.dump(best_model, model_path) - logging.info('Saved the %s model to %s', model_name, model_path) + logging.info('Saved the %s model to %s', model_name_label, model_path) - logging.info('Completed training and evaluation for %s model.', model_name) + logging.info('Completed training and evaluation for %s model.', model_name_label) # Run SHAP if full analysis and no leave-outs (SHAP is slow) if not split_80_20 and no_leave_out: - logging.info('Running feature importance analysis for %s model.', model_name) + logging.info('Running feature importance analysis for %s model.', model_name_label) classifier = best_model.named_steps['classifier'] # For Random Forest, use both native importance and SHAP (with aggressive sampling) diff --git a/tests/test_predict_io.py b/tests/test_predict_io.py index ef6153d..b50d759 100644 --- a/tests/test_predict_io.py +++ b/tests/test_predict_io.py @@ -115,9 +115,7 @@ def test_score_generates_outputs_in_tests_output(monkeypatch): model='tests/fixtures/dummy_model.pkl', input_vcf=str(FIXTURE_VCF_GZ), output_vcf=str(FILTERED_VCF), - threshold=0.2, sample_coverage=30, - large_cutoff=10000, annovar_path='unused', annovar_db_path='unused', )