diff --git a/.gitignore b/.gitignore index b7478d26..ece555be 100644 --- a/.gitignore +++ b/.gitignore @@ -7,10 +7,6 @@ *.o *.obj -# SWIG files -src/swig_wrapper.cpp -lib/contextsv.py - # Pycache __pycache__/ @@ -50,12 +46,8 @@ __pycache__/ *.code-workspace CMakeSettings.json -# Shell scripts -*.sh - # Output folder output/ -python/ # Doxygen docs/html/ @@ -64,8 +56,6 @@ docs/html/ *.sif # Test directories -python/dbscan -python/agglo linktoscripts tests/data tests/cpp_module_out @@ -85,11 +75,6 @@ data/sv_scoring_dataset/ data/hg38ToHg19.over.chain.gz data/hg19ToHg38.over.chain.gz -# Test images -python/dbscan_clustering*.png -python/dist_plots -upset_plot*.png - # Temporary files lib/.nfs* valgrind.log diff --git a/Dockerfile b/Dockerfile index 95933f28..a5c3e7ba 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,16 +6,24 @@ ARG CONTEXTSV_VERSION WORKDIR /app -RUN apt-get update -RUN conda update conda +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates && rm -rf /var/lib/apt/lists/* +RUN conda update -y conda + +# Install ContextSV and plotting dependencies. +RUN conda config --add channels wglab \ + && conda config --add channels conda-forge \ + && conda config --add channels bioconda \ + && conda create -y -n contextsv python=3.10 \ + && conda install -y -n contextsv -c wglab -c conda-forge -c bioconda \ + contextsv=${CONTEXTSV_VERSION} plotly python-kaleido \ + && conda clean -afy + +# Smoke test both commands at build time. +RUN conda run -n contextsv contextsv --help \ + && conda run -n contextsv contextsv-cnv-plot --help -# Install ContextSV -RUN conda config --add channels wglab -RUN conda config --add channels conda-forge -RUN conda config --add channels bioconda -RUN conda create -n contextsv python=3.9 -RUN echo "conda activate contextsv" >> ~/.bashrc SHELL ["/bin/bash", "--login", "-c"] -RUN conda install -n contextsv -c wglab -c conda-forge -c bioconda contextsv=${CONTEXTSV_VERSION} && conda clean -afy -ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "contextsv", "contextsv"] +# Default command remains contextsv, but this allows overriding with contextsv-cnv-plot. +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "contextsv"] +CMD ["contextsv"] diff --git a/Doxyfile b/Doxyfile index a11d2045..2faa3e0e 100644 --- a/Doxyfile +++ b/Doxyfile @@ -1063,7 +1063,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* -EXCLUDE_PATTERNS = *test* *swig* khmm.cpp kc.cpp khmm.h kc.h +EXCLUDE_PATTERNS = *test* khmm.cpp kc.cpp khmm.h kc.h # The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names # (namespaces, classes, functions, etc.) that should be excluded from the @@ -1071,7 +1071,7 @@ EXCLUDE_PATTERNS = *test* *swig* khmm.cpp kc.cpp khmm.h kc.h # wildcard * is used, a substring. Examples: ANamespace, AClass, # ANamespace::AClass, ANamespace::*Test -EXCLUDE_SYMBOLS = *SWIG* +EXCLUDE_SYMBOLS = # The EXAMPLE_PATH tag can be used to specify one or more files or directories # that contain example code fragments that are included (see the \include diff --git a/Makefile b/Makefile index ae39b32b..ee5442f2 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ CONDA_LIB_DIR := $(CONDA_PREFIX)/lib # Compiler and Flags CXX := g++ -CXXFLAGS := -std=c++17 -g -I$(INCL_DIR) -I$(CONDA_INCL_DIR) -Wall -Wextra -pedantic +CXXFLAGS := -std=c++17 -O3 -DNDEBUG -I$(INCL_DIR) -I$(CONDA_INCL_DIR) -Wall -Wextra -pedantic # Linker Flags # Ensure that the library paths are set correctly for linking @@ -19,15 +19,17 @@ LDFLAGS := -L$(LIB_DIR) -L$(CONDA_LIB_DIR) -Wl,-rpath=$(CONDA_LIB_DIR) # Add rp LDLIBS := -lhts # Link with libhts.a or libhts.so # Sources and Output -SOURCES := $(filter-out $(SRC_DIR)/swig_wrapper.cpp, $(wildcard $(SRC_DIR)/*.cpp)) # Filter out the SWIG wrapper from the sources +SOURCES := $(wildcard $(SRC_DIR)/*.cpp) OBJECTS := $(patsubst $(SRC_DIR)/%.cpp,$(BUILD_DIR)/%.o,$(SOURCES)) TARGET := $(BUILD_DIR)/contextsv +PREFIX ?= $(CONDA_PREFIX) +BINDIR ?= $(PREFIX)/bin # Default target all: $(TARGET) # Debug target -debug: CXXFLAGS += -DDEBUG +debug: CXXFLAGS := -std=c++17 -g -O0 -DDEBUG -I$(INCL_DIR) -I$(CONDA_INCL_DIR) -Wall -Wextra -pedantic debug: all # Link the executable @@ -43,3 +45,13 @@ $(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp # Clean the build directory clean: rm -rf $(BUILD_DIR) + +# Install binaries and helper scripts +install: $(TARGET) + @if [ -z "$(PREFIX)" ]; then \ + echo "Error: PREFIX is empty. Activate a conda env or run 'make install PREFIX=/your/prefix'."; \ + exit 1; \ + fi + install -d $(BINDIR) + install -m 755 $(TARGET) $(BINDIR)/contextsv + install -m 755 python/cnv_plots_json.py $(BINDIR)/contextsv-cnv-plot diff --git a/README.md b/README.md index 3a7e0cda..24476fdb 100644 --- a/README.md +++ b/README.md @@ -16,65 +16,120 @@ Class documentation is available at File containing per-chromosome population allele frequency filepaths as described in this documentation + --assembly-gaps Assembly gaps file in BED format available from UCSC Genome Browser (https://hgdownload.soe.ucsc.edu/goldenPath/hg38/database/gap.txt.gz for GRCh38) + --save-cnv Save CNV data in JSON for downstream plotting with contextsv-cnv-plot --debug Debug mode with verbose logging --version Print version and exit -h, --help Print usage and exit @@ -95,7 +150,7 @@ Download links for genome VCF files are located here (last updated April 3, ### Script for downloading gnomAD VCFs -``` +```bash download_dir="~/data/gnomad/v4.0.0/" chr_list=("1" "2" "3" "4" "5" "6" "7" "8" "9" "10" "11" "12" "13" "14" "15" "16" "17" "18" "19" "20" "21" "22" "X" "Y") @@ -110,7 +165,7 @@ Finally, create a text file that specifies the chromosome and its corresponding gnomAD filepath. This file will be passed in as an argument: **gnomadv4_filepaths.txt** -``` +```bash 1=~/data/gnomad/v4.0.0/gnomad.genomes.v4.0.sites.chr1.vcf.bgz 2=~/data/gnomad/v4.0.0/gnomad.genomes.v4.0.sites.chr2.vcf.bgz 3=~/data/gnomad/v4.0.0/gnomad.genomes.v4.0.sites.chr3.vcf.bgz diff --git a/conda/build.sh b/conda/build.sh new file mode 100644 index 00000000..265df296 --- /dev/null +++ b/conda/build.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +set -e + +echo "Building ContextSV..." +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${PREFIX}/lib +export CONDA_PREFIX=$PREFIX +export CXXFLAGS="-I$PREFIX/include $CXXFLAGS" +export LDFLAGS="-L$PREFIX/lib $LDFLAGS" + +echo "Checking for HTSLib..." +ls -la $PREFIX/include/htslib/ || echo "HTSLib headers not found" +pkg-config --exists htslib && echo "✓ HTSLib found" || echo "⚠ HTSLib not via pkg-config" + +echo "Compiling ContextSV..." +make + +echo "Installing ContextSV..." +mkdir -p ${PREFIX}/bin +cp build/contextsv ${PREFIX}/bin/ +chmod +x ${PREFIX}/bin/contextsv +cp python/cnv_plots_json.py ${PREFIX}/bin/contextsv-cnv-plot +chmod +x ${PREFIX}/bin/contextsv-cnv-plot + +echo "Verifying ContextSV installation..." +$PREFIX/bin/contextsv --help +$PREFIX/bin/contextsv --version + +echo "Verifying CNV plotting command installation..." +test -x ${PREFIX}/bin/contextsv-cnv-plot +${PREFIX}/bin/contextsv-cnv-plot --help diff --git a/conda/meta.yaml b/conda/meta.yaml index d217ee43..b088db9d 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -10,6 +10,7 @@ source: git_lfs: false channels: + - wglab - conda-forge - bioconda - defaults @@ -25,12 +26,18 @@ requirements: - htslib=1.20 run: - htslib=1.20 + - contextscore + - python >=3.9 + - plotly + - python-kaleido test: commands: - - contextsv --help - - test -f $PREFIX/bin/contextsv - - contextsv --version + - test -x $PREFIX/bin/contextsv + - $PREFIX/bin/contextsv --help + - $PREFIX/bin/contextsv --version + - test -x $PREFIX/bin/contextsv-cnv-plot + - $PREFIX/bin/contextsv-cnv-plot --help about: home: https://github.com/WGLab/ContextSV license: MIT diff --git a/include/fasta_query.h b/include/fasta_query.h index 4486bdb0..339a30c7 100644 --- a/include/fasta_query.h +++ b/include/fasta_query.h @@ -24,7 +24,7 @@ class ReferenceGenome { public: ReferenceGenome(std::shared_mutex& shared_mutex) : shared_mutex(shared_mutex) {} - int setFilepath(std::string fasta_filepath); + int read(std::string fasta_filepath); std::string getFilepath() const; std::string_view query(const std::string& chr, uint32_t pos_start, uint32_t pos_end) const; bool compare(const std::string& chr, uint32_t pos_start, uint32_t pos_end, const std::string& compare_seq, float match_threshold) const; diff --git a/include/input_data.h b/include/input_data.h index 32c78ce2..bc300208 100644 --- a/include/input_data.h +++ b/include/input_data.h @@ -55,13 +55,9 @@ class InputData { void setAssemblyGaps(std::string filepath); std::string getAssemblyGaps() const; - // Set the sample size for HMM predictions. - void setSampleSize(int sample_size); - int getSampleSize() const; - - // Set the minimum CNV length to use for copy number predictions. - void setMinCNVLength(int min_cnv_length); - uint32_t getMinCNVLength() const; + // Set/get a target chromosome for single-chromosome analysis. + void setChromosome(std::string chr); + std::string getChromosome() const; // Set the epsilon parameter for DBSCAN clustering. void setDBSCAN_Epsilon(double epsilon); @@ -72,11 +68,6 @@ class InputData { void setDBSCAN_MinPtsPct(double min_pts_pct); double getDBSCAN_MinPtsPct() const; - // Set the chromosome to analyze. - void setChromosome(std::string chr); - std::string getChromosome() const; - bool isSingleChr() const; - // Set the output directory where the results will be written. void setOutputDir(std::string dirpath); std::string getOutputDir() const; @@ -104,8 +95,6 @@ class InputData { std::string ethnicity; std::unordered_map pfb_filepaths; // Map of population frequency VCF filepaths by chromosome std::string output_dir; - int sample_size; - uint32_t min_cnv_length; int min_reads; double dbscan_epsilon; double dbscan_min_pts_pct; diff --git a/include/sv_caller.h b/include/sv_caller.h index a0883e6c..f0178519 100644 --- a/include/sv_caller.h +++ b/include/sv_caller.h @@ -91,7 +91,6 @@ class SVCaller { void runSplitReadCopyNumberPredictions(const std::string& chr, std::vector& split_sv_calls, const CNVCaller &cnv_caller, const CHMM &hmm, double mean_chr_cov, const std::vector &pos_depth_map, const InputData &input_data); void saveToVCF(const std::unordered_map> &sv_calls, const InputData &input_data, const ReferenceGenome &ref_genome, const std::unordered_map> &chr_pos_depth_map) const; - // void saveToVCF(const std::unordered_map> &sv_calls, const std::string &output_dir, const ReferenceGenome &ref_genome, const std::unordered_map>& chr_pos_depth_map) const; // Query the read depth (INFO/DP) at a position int getReadDepth(const std::vector& pos_depth_map, uint32_t start) const; diff --git a/include/swig_interface.h b/include/swig_interface.h deleted file mode 100644 index 578f4653..00000000 --- a/include/swig_interface.h +++ /dev/null @@ -1,17 +0,0 @@ -// -// swig_interface.h: -// Declare the C++ functions that will be wrapped by SWIG -// - -#ifndef SWIG_INTERFACE_H -#define SWIG_INTERFACE_H - -#include "input_data.h" - -/// @cond -#include -/// @endcond - -int run(const InputData& input_data); - -#endif // SWIG_INTERFACE_H diff --git a/python/cluster_params.py b/python/cluster_params.py deleted file mode 100644 index ff1484ca..00000000 --- a/python/cluster_params.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -test_cluster_params.py: Test the cluster parameters for the cluster + merge -pipeline. - -Usage: - python test_cluster_params.py - - - benchmark_file_path: Path to the benchmark file. - cluster_type: Either 'dbscan' or 'agglo' - output_dir: Directory for saving the output plots. - -""" - -import os -import sys - -import matplotlib.pyplot as plt - -def get_precision_recall(file_path, sv_type='DEL'): - """Parse text file containing epsilon, precision, and recall values.""" - epsilon_values = [] - precision_values = [] - recall_values = [] - fp_counts = [] - fn_counts = [] - comp_counts = [] - base_counts = [] - - with open(file_path, 'r', encoding='utf-8') as file: - lines = file.readlines() - - epsilon = None - sv_section_found = False - for i, line in enumerate(lines): - - if "#EPSILON=" in line: - epsilon = float(line.split('=')[1]) - epsilon_values.append(epsilon) - - # SV type sections - elif "Running truvari" in line: - if sv_type in line: - sv_section_found = True - - # Get the number of SVs in the callset vs. the benchmark - elif "Zipped" in line and "Counter" in line and sv_section_found: - # [INFO] Zipped 269 variants Counter({'comp': 204, 'base': 65}) - - # Split the line by 'Counter' - line = line.split('Counter')[1] - - # Get the value after 'comp': - comp_count = line.split("'comp':")[1] - comp_count = comp_count.split(',')[0] - comp_count = comp_count.split('}')[0] - comp_count = int(comp_count) - - # Get the value after 'base': - base_count = line.split("'base':")[1] - base_count = base_count.split(',')[0] - base_count = base_count.split('}')[0] - base_count = int(base_count) - - # Add the counts to the lists - comp_counts.append(comp_count) - base_counts.append(base_count) - - # Get the number of FPs - elif "FP" in line and sv_section_found: - # Get the value after the ':' - fp = line.split(':')[1] - - # Clean up the string - fp = fp.replace('\n', '') - fp = fp.replace(',', '') - fp = int(fp) - fp_counts.append(fp) - - # Get the number of FNs - elif "FN" in line and sv_section_found: - # Get the value after the ':' - fn = line.split(':')[1] - - # Clean up the string - fn = fn.replace('\n', '') - fn = fn.replace(',', '') - fn = int(fn) - fn_counts.append(fn) - - elif "precision" in line and sv_section_found: - # Get the value after the ':' - p = line.split(':')[1] - - # Clean up the string - p = p.replace('\n', '') - p = p.replace(',', '') - p = float(p) - precision_values.append(p) - - elif "recall" in line and sv_section_found: - # Get the value after the ':' - r = line.split(':')[1] - - # Clean up the string - r = r.replace('\n', '') - r = r.replace(',', '') - r = float(r) - recall_values.append(r) - - # Reset epsilon and sv_section_found - epsilon = None - sv_section_found = False - - print(f'SV Type: {sv_type}') - - ##### Maximizing F1 ##### - # Get the maximum F1 score and the corresponding epsilon, precision, recall - f1_scores = [] - for i, precision in enumerate(precision_values): - recall = recall_values[i] - f1 = 2 * (precision * recall) / (precision + recall) - f1_scores.append(f1) - - max_f1 = max(f1_scores) - max_f1_index = f1_scores.index(max_f1) - - # Print the maximum F1 score - print(f'Maximum F1: {max_f1}') - - # Print the maximum precision and recall values - print(f'Precision at F1: {precision_values[max_f1_index]}') - print(f'Recall at F1: {recall_values[max_f1_index]}') - - # Print the parameter value at the maximum F1 score - print(f'Parameter value : {epsilon_values[max_f1_index]}') - - # Print the FP and FN counts at the maximum F1 score - print(f'FP Count: {fp_counts[max_f1_index]}') - print(f'FN Count: {fn_counts[max_f1_index]}') - - # Print the number of SVs in the callset and benchmark at the maximum F1 - # score - print(f'Number of {sv_type}s in Callset: {comp_counts[max_f1_index]}') - print(f'Number of {sv_type}s in Benchmark: {base_counts[max_f1_index]}') - - ##### Maximizing recall ##### - # # Get the maximum recall value, and then the maximum precision value at that - # # recall value - # max_recall = max(recall_values) - # max_precision = None - # max_index = None # Index of the maximum recall and corresponding precision - # for i, recall in enumerate(recall_values): - # if recall == max_recall: - # if max_precision is None: - # max_precision = precision_values[i] - # max_index = i - # elif precision_values[i] > max_precision: - # max_precision = precision_values[i] - # max_index = i - - # # Print the maximum precision and recall values - # print(f'Maximum Recall: {max_recall}') - # print(f'Maximum Precision at Maximum Recall: {max_precision}') - - # # Print the parameter value at the maximum recall and corresponding precision - # print(f'{parameter_name} at Maximum Recall: {epsilon_values[max_index]}') - - # # Print the FP and FN counts at the maximum recall and corresponding - # # precision - # print(f'FP Count at Maximum Recall: {fp_counts[max_index]}') - # print(f'FN Count at Maximum Recall: {fn_counts[max_index]}') - - # # Print the number of SVs in the callset and benchmark at the maximum recall - # # and corresponding precision - # print(f'Number of {sv_type}s in Callset: {comp_counts[max_index]}') - # print(f'Number of {sv_type}s in Benchmark: {base_counts[max_index]}') - - return epsilon_values, precision_values, recall_values - - -def get_f1_scores(file_path, sv_type='DEL'): - """Parse text file containing epsilon and F1 scores.""" - epsilon_values = [] - f1_values = [] - - with open(file_path, 'r', encoding='utf-8') as file: - lines = file.readlines() - - epsilon = None - sv_section_found = False - for i, line in enumerate(lines): - - if "#EPSILON=" in line: - epsilon = float(line.split('=')[1]) - epsilon_values.append(epsilon) - - # SV type sections - elif "Running truvari" in line: - if sv_type in line: - sv_section_found = True - - elif "f1" in line and sv_section_found: - # Get the value after the ':' - f1 = line.split(':')[1] - - # Clean up the string - f1 = f1.replace('\n', '') - f1 = f1.replace(',', '') - f1 = float(f1) - f1_values.append(f1) - - # Reset epsilon and sv_section_found - epsilon = None - sv_section_found = False - - return epsilon_values, f1_values - - -def plot_precision_recall(epsilon, precision, recall, title="Precision and Recall vs. Epsilon", parameter_name='Epsilon'): - """Plot precision and recall values.""" - # Create figure - plt.figure() - - # Plot precision and recall vs. epsilon on same plot but different axes - ax1 = plt.gca() - ax2 = ax1.twinx() - - # Plot precision vs. epsilon on ax1 - ax1.plot(epsilon, precision, label='Precision', color='black') - - # Plot recall vs. epsilon on ax2 - ax2.plot(epsilon, recall, label='Recall', color='blue') - - # # Show ticks for all epsilon values - # ax1.set_xticks(epsilon) - - # # Make X-ticks vertical - # plt.xticks(rotation=90) - - # # Double the figure width - # plt.gcf().set_size_inches(18.5, 10.5) - - # Add axis labels - ax1.set_xlabel(parameter_name, color='black') - ax1.set_ylabel('Precision', color='black') - ax2.set_xlabel('Epsilon', color='black') - ax2.set_ylabel('Recall', color='blue') - - # Set tick colors - ax1.tick_params(axis='y', colors='black') - ax2.tick_params(axis='y', colors='blue') - - # Add title - plt.title(title) - - return plt - - -def plot_f1(epsilon, f1_scores, title="F1 vs. Epsilon", parameter_name='Epsilon'): - """Plot F1 values.""" - # Create figure - plt.figure() - - # Plot F1 vs. epsilon - plt.plot(epsilon, f1_scores, label='F1') - - # # Show ticks for all epsilon values - # plt.xticks(epsilon) - - # # Make X-ticks vertical - # plt.xticks(rotation=90) - - # # Double the figure width - # plt.gcf().set_size_inches(18.5, 10.5) - - # Add axis labels - plt.xlabel(parameter_name) - plt.ylabel('F1') - - # Add title - plt.title(title) - - # Return figure - return plt - - -if __name__ == '__main__': - # Take in benchmark file path as command line argument - file_path = sys.argv[1] - - print(f'Input file path: {file_path}') - - # Take in cluster type as command line argument - cluster_type = sys.argv[2] - if cluster_type not in ['dbscan', 'agglo']: - print(f"Invalid cluster type: {cluster_type}") - sys.exit(1) - - # Take in output directory name as command line argument - output_dir = sys.argv[3] - - # Create the directory if it doesn't exist - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - # Get the cluster type string - cluster_string = 'DBSCAN' if cluster_type == 'dbscan' else 'Agglomerative' - - # Determine the parameter to test - parameter_name = 'Epsilon' if cluster_type == 'dbscan' else 'Distance Threshold' - - # Create the plot title - plot_title = cluster_string + ' Cluster + Merge' - - # Plot precision and recall values - # Deletions - eps, prec, rec = get_precision_recall(file_path, sv_type='DEL') - fig = plot_precision_recall(eps, prec, rec, title=plot_title + ' (DEL)', parameter_name=parameter_name) - fig.savefig(output_dir + '/Precision_Recall_DEL.png') - - # Duplications - eps, prec, rec = get_precision_recall(file_path, sv_type='DUP') - fig = plot_precision_recall(eps, prec, rec, title=plot_title + ' (DUP)', parameter_name=parameter_name) - fig.savefig(output_dir + '/Precision_Recall_DUP.png') - - # Insertions - eps, prec, rec = get_precision_recall(file_path, sv_type='INS') - fig = plot_precision_recall(eps, prec, rec, title=plot_title + ' (INS)', parameter_name=parameter_name) - fig.savefig(output_dir + '/Precision_Recall_INS.png') - - # Plot F1 scores - # Deletions - eps, f1 = get_f1_scores(file_path, sv_type='DEL') - fig = plot_f1(eps, f1, title=plot_title + ' (DEL)', parameter_name=parameter_name) - fig.savefig(output_dir + '/F1_DEL.png') - - # Duplications - eps, f1 = get_f1_scores(file_path, sv_type='DUP') - fig = plot_f1(eps, f1, title=plot_title + ' (DUP)', parameter_name=parameter_name) - fig.savefig(output_dir + '/F1_DUP.png') - - # Insertions - eps, f1 = get_f1_scores(file_path, sv_type='INS') - fig = plot_f1(eps, f1, title=plot_title + ' (INS)', parameter_name=parameter_name) - fig.savefig(output_dir + '/F1_INS.png') diff --git a/python/cnv_plots.py b/python/cnv_plots.py deleted file mode 100644 index 67c831c6..00000000 --- a/python/cnv_plots.py +++ /dev/null @@ -1,307 +0,0 @@ -"""Plot the copy number variants and their log2_ratio, BAF values.""" - -import os -import sys -import logging as log -import plotly -from plotly.subplots import make_subplots -import pandas as pd - -try: - from .utils import parse_region, get_info_field_column, get_info_field_value -except ImportError: - from utils import parse_region, get_info_field_column, get_info_field_value - -MIN_CNV_LENGTH = 10000 - -# Set up logging. -log.basicConfig( - level=log.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - log.StreamHandler(sys.stdout) - ] -) - -def parse_region(region): - """ - Parses the region string to get the chromosome, start position, and end - position. - - Args: - region (str): The region string in the format "chr:start-end". - - Returns: - tuple: A tuple containing the chromosome, start position, and end - position. - """ - - # Split the region string by ":" and "-". - region_parts = region.split(":") - chromosome = region_parts[0] - region_parts = region_parts[1].split("-") - start_position = int(region_parts[0]) - end_position = int(region_parts[1]) - - return chromosome, start_position, end_position - -def run(cnv_data_file, output_html): - """ - Saves a plot of the CNVs and their log2 ratio and B-allele frequency - values. - - Args: - vcf_path (str): The path to the VCF file. - cnv_data_path (str): The path to the CNV data file. - output_path (str): The path to the output directory. - - Returns: - None - """ - - # Filter the CNV data to the region using pandas, and make the chromosome - # column a string. - log.info("Loading CNV data from %s", cnv_data_file) - - # Read the first 3 lines of the file to get metadata. - # Metadata is formatted as follows: - # "SVTYPE=" - # "POS=" - # "HMM_LOGLH=" - metadata = {} - metadata_row_count = 3 - with open(cnv_data_file, "r", encoding="utf-8") as f: - # Read the first 3 lines of the file. - for _ in range(metadata_row_count): - line = f.readline().strip() - if '=' in line: - key, value = line.split("=") - # log.info("Metadata: %s=%s", key, value) - value = value.strip() - metadata[key] = value - - sv_type = metadata["SVTYPE"] - position = metadata["POS"] - chromosome, start_position, end_position = parse_region(position) - hmm_loglh = float(metadata["HMM_LOGLH"]) - - # Extract information from the metadata. - log.info("SV type: %s, chromosome: %s, start position: %d, end position: %d, HMM log likelihood: %f", sv_type, chromosome, start_position, end_position, hmm_loglh) - - # Read the CNV data from the file. - sv_data = pd.read_csv(cnv_data_file, sep="\t", header=metadata_row_count, dtype={"chromosome": str}) - if len(sv_data) == 0: - log.info("No predictions found in %s", cnv_data_file) - return - else: - log.info("Found %d predictions in %s", len(sv_data), cnv_data_file) - - # Create an output html file where we will append the CNV plots. - if start_position is not None and end_position is not None: - html_filename = f"cnv_plots_{chromosome}_{start_position}_{end_position}.html" - else: - html_filename = f"cnv_plots_{chromosome}.html" - - # Create the output html file. - if os.path.exists(output_html): - os.remove(output_html) - - with open(output_html, "w", encoding="utf-8") as output_html_file: - - # Use absolute value of CNV length (deletions are negative). - # cnv_length = abs(cnv_length) - cnv_length = end_position - start_position + 1 - - # Return if the CNV length is less than the minimum CNV length. - if cnv_length < MIN_CNV_LENGTH: - log.info("Skipping CNV %s:%d-%d due to length < %d.", chromosome, start_position, end_position, MIN_CNV_LENGTH) - return - - # Get the plot range as the minimum and maximum positions in the CNV - # data. - plot_start_position = sv_data["position"].min() - plot_end_position = sv_data["position"].max() - - # Get the CNV state, log2 ratio, and BAF values for all SNPs in the - # plot range. - log.info("Getting SNPs in CNV %s:%d-%d.", chromosome, plot_start_position, plot_end_position) - - # If there are no SNPs in the plot range, skip the CNV. - if len(sv_data) == 0: - log.info("No SNPs found in CNV %s:%d-%d.", chromosome, start_position, end_position) - # continue - else: - log.info("Found %d SNPs in CNV %s:%d-%d.", len(sv_data), chromosome, start_position, end_position) - - # Get the marker colors for the state sequence. - marker_colors = [] - for state in sv_data["cnv_state"]: - if state in [1, 2]: - marker_colors.append("red") - elif state in [3, 4]: - marker_colors.append("black") - elif state in [5, 6]: - marker_colors.append("blue") - - # [TEST] Set the marker colors for the SNPs before and after the CNV to - # gray (no state prediction). - # for i in range(len(sv_data)): - # if sv_data["position"].iloc[i] < start_position or sv_data["position"].iloc[i] > end_position: - # marker_colors[i] = "gray" - - - # Use row['snp'] to get whether SNP or not (0=not SNP, 1=SNP). - marker_symbols = ["circle" if snp == 1 else "circle-open" for snp in sv_data["snp"]] - # marker_symbols = marker_symbols_before + marker_symbols + marker_symbols_after - - # Concatenate the SNP data before, during, and after the CNV. - # sv_data = pd.concat([sv_data_before, sv_data, sv_data_after]) - - # Set all -1 B-allele frequency values to 0. - sv_data.loc[sv_data["b_allele_freq"] == -1, "b_allele_freq"] = 0 - - # Get the hover text for the state sequence markers. - hover_text = [] - for _, row in sv_data.iterrows(): - hover_text.append(f"SNP: {row['snp']}
TYPE: {'NA'}
CHR: {row['chromosome']}
POS: {row['position']}
L2R: {row['log2_ratio']}
BAF: {row['b_allele_freq']}
PFB: {row['population_freq']}
STATE: {row['cnv_state']}") - # hover_text.append(f"TYPE: {cnv_types[row['cnv_state']]}
CHR: {row['chromosome']}
POS: {row['position']}
L2R: {row['log2_ratio']}
BAF: {row['b_allele_freq']}
PFB: {row['population_freq']}") - - # Create the log2 ratio trace. - log2_ratio_trace = plotly.graph_objs.Scatter( - x = sv_data["position"], - y = sv_data["log2_ratio"], - mode = "markers+lines", - name = r'Log2 Ratio', - text = hover_text, - hoverinfo = "text", - marker = dict( - color = marker_colors, - size = 10, - symbol = marker_symbols - ), - line = dict( - color = "black", - width = 0 - ), - showlegend = False - ) - - # Create the B-allele frequency trace. - baf_trace = plotly.graph_objs.Scatter( - x = sv_data["position"], - y = sv_data["b_allele_freq"], - mode = "markers+lines", - name = "B-Allele Frequency", - text = hover_text, - marker = dict( - color = marker_colors, - size = 10, - symbol = marker_symbols - ), - line = dict( - color = "black", - width = 0 - ), - showlegend = False - ) - - # Create a subplot for the CNV plot and the BAF plot. - fig = make_subplots( - rows=2, - cols=1, - shared_xaxes=True, - vertical_spacing=0.05, - subplot_titles=(r"SNP Log2 Ratio", "SNP B-Allele Frequency") - ) - - # Add the traces to the figure. - fig.append_trace(log2_ratio_trace, 1, 1) - fig.append_trace(baf_trace, 2, 1) - - # Set the x-axis title. - fig.update_xaxes( - title_text = "Chromosome Position", - row = 2, - col = 1 - ) - - # Set the y-axis titles. - fig.update_yaxes( - title_text = r"Log2 Ratio", - row = 1, - col = 1 - ) - - fig.update_yaxes( - title_text = "B-Allele Frequency", - row = 2, - col = 1 - ) - - # Set the Y-axis range for the log2 ratio plot. - fig.update_yaxes( - range = [-2.0, 2.0], - row = 1, - col = 1 - ) - - # Set the Y-axis range for the BAF plot. - fig.update_yaxes( - range = [-0.2, 1.2], - row = 2, - col = 1 - ) - - # Set the figure title. - # fig.update_layout( - # title_text = f"{svtype} (SUPPORT={read_support}, LEN={cnv_length}bp) at {chromosome}:{start_position}-{end_position} [ALN={aln}]", - # ) - - # Create a shaded rectangle for the CNV, layering it below the CNV - # trace and labeling it with the CNV type. - fig.add_vrect( - x0 = start_position, - x1 = end_position, - fillcolor = "Black", - layer = "below", - line_width = 0, - opacity = 0.1, - annotation_text = '', - annotation_position = "top left", - annotation_font_size = 20, - annotation_font_color = "black" - ) - - # Add vertical lines at the start and end positions of the CNV. - fig.add_vline( - x = start_position, - line_width = 2, - line_color = "black", - layer = "below" - ) - - fig.add_vline( - x = end_position, - line_width = 2, - line_color = "black", - layer = "below" - ) - - # Append the figure to the output html file. - output_html_file.write(fig.to_html(full_html=False, include_plotlyjs="cdn")) - log.info("Plotted CNV %s %s:%d-%d.", 'SVType', chromosome, start_position, end_position) - - # Increment the CNV count. - # cnv_count += 1 - - # # Break if the maximum number of CNVs has been reached. - # if cnv_count == max_cnvs: - # break - - log.info("Saved CNV plots to %s.", output_html) - -if __name__ == "__main__": - cnv_data_file = sys.argv[1] - output_path = sys.argv[2] - - run(cnv_data_file, output_path) diff --git a/python/cnv_plots_json.py b/python/cnv_plots_json.py index 9f2c9458..b9fd6849 100644 --- a/python/cnv_plots_json.py +++ b/python/cnv_plots_json.py @@ -1,242 +1,301 @@ -import os +#!/usr/bin/env python3 import argparse import json -import numpy as np -import plotly -from plotly.subplots import make_subplots - -min_sv_length = 60000 # Minimum SV length in base pairs - -# Set up argument parser -parser = argparse.ArgumentParser(description='Generate CNV plots from JSON data.') -parser.add_argument('json_file', type=str, help='Path to the JSON file containing SV data') -parser.add_argument('chromosome', type=str, help='Chromosome to filter the SVs by (e.g., "chr3")', nargs='?', default=None) -args = parser.parse_args() - -# Load your JSON data -with open(args.json_file) as f: - sv_data = json.load(f) - -# State marker colors -# https://community.plotly.com/t/plotly-colours-list/11730/6 -state_colors_dict = { - '1': 'darkred', - '2': 'red', - '3': 'gray', - '4': 'green', - '5': 'blue', - '6': 'darkblue', +import os + +DEFAULT_MIN_SV_LENGTH = 50000 +MARKER_SIZE = 8 +STATIC_FORMATS = {"svg", "pdf", "png", "jpg", "jpeg", "webp", "eps"} +ALLOWED_FORMATS = STATIC_FORMATS.union({"html"}) + +STATE_COLORS = { + "1": "darkred", + "2": "red", + "3": "gray", + "4": "green", + "5": "blue", + "6": "darkblue", } -sv_type_dict = { - 'DEL': 'Deletion', - 'DUP': 'Duplication', - 'INV': 'Inversion' +SV_TYPE_LABELS = { + "DEL": "Deletion", + "DUP": "Duplication", + "INV": "Inversion", } -# Loop through each SV (assuming your JSON contains multiple SVs) -for sv in sv_data: +REQUIRED_SECTION_KEYS = { + "positions", + "b_allele_freq", + "population_freq", + "log2_ratio", + "is_snp", +} - # If a chromosome is specified, filter the SVs by that chromosome - if args.chromosome and sv['chromosome'] != args.chromosome: - print(f"Skipping SV {sv['chromosome']}:{sv['start']}-{sv['end']} of type {sv['sv_type']} (not on chromosome {args.chromosome})") - continue +REQUIRED_SV_KEYS = { + "chromosome", + "start", + "end", + "sv_type", + "size", + "before_sv", + "sv", + "after_sv", +} - # Filter out SVs that are smaller than the minimum length - if np.abs(sv['size']) < min_sv_length: - print(f"Skipping SV {sv['chromosome']}:{sv['start']}-{sv['end']} of type {sv['sv_type']} with size {sv['size']} bp (smaller than {min_sv_length} bp)") - continue - # Extract data for plotting - positions_before = sv['before_sv']['positions'] - b_allele_freq_before = sv['before_sv']['b_allele_freq'] - positions_after = sv['after_sv']['positions'] - b_allele_freq_after = sv['after_sv']['b_allele_freq'] +def parse_args(): + parser = argparse.ArgumentParser(description="Generate CNV plots from JSON data.") + parser.add_argument("json_file", type=str, help="Path to the JSON file containing SV data") + parser.add_argument( + "chromosome", + type=str, + nargs="?", + default=None, + help="Chromosome to filter SVs by (e.g., chr3)", + ) + parser.add_argument( + "--formats", + type=str, + default="html", + help="Comma-separated output formats (e.g., html,svg,pdf,png)", + ) + parser.add_argument("--width", type=int, default=1200, help="Figure width in pixels for static exports") + parser.add_argument("--height", type=int, default=800, help="Figure height in pixels for static exports") + parser.add_argument("--scale", type=float, default=2.0, help="Scale factor for raster exports") + parser.add_argument( + "--min-sv-length", + type=int, + default=DEFAULT_MIN_SV_LENGTH, + help="Minimum SV length in base pairs to plot", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory to save output plots (default: ./CNV_Plots)", + ) + return parser.parse_args() - # Create a subplot for the CNV plot and the BAF plot. - fig = make_subplots( - rows=2, - cols=1, - shared_xaxes=True, - vertical_spacing=0.05, - subplot_titles=(r"SNP Log2 Ratio", "SNP B-Allele Frequency") - ) - # Get the chromosome, start, end, and sv_type from the SV data - chromosome = sv['chromosome'] - start = sv['start'] - end = sv['end'] - sv_type = sv['sv_type'] - likelihood = sv['likelihood'] - sv_length = sv['size'] +def parse_formats(formats_arg): + formats = [fmt.strip().lower() for fmt in formats_arg.split(",") if fmt.strip()] + invalid = [fmt for fmt in formats if fmt not in ALLOWED_FORMATS] + if invalid: + allowed = ", ".join(sorted(ALLOWED_FORMATS)) + bad = ", ".join(invalid) + raise ValueError(f"Unsupported format(s): {bad}. Allowed formats are: {allowed}") + return formats + + +def load_json_records(path): + if not os.path.isfile(path): + raise FileNotFoundError(f"Input JSON file not found: {path}") + + with open(path, encoding="utf-8") as handle: + data = json.load(handle) + + if not isinstance(data, list): + raise ValueError("Input JSON must contain a top-level list of SV records.") + + return data + + +def validate_record(record, idx): + missing_sv_keys = REQUIRED_SV_KEYS - set(record.keys()) + if missing_sv_keys: + missing = ", ".join(sorted(missing_sv_keys)) + raise ValueError(f"Record {idx} missing required top-level key(s): {missing}") - # Plot the data for 'before_sv', 'sv', and 'after_sv' for section in ["before_sv", "sv", "after_sv"]: - positions = sv[section]['positions'] - b_allele_freq = sv[section]['b_allele_freq'] - population_freq = sv[section]['population_freq'] - log2_ratio = sv[section]['log2_ratio'] - is_snp = sv[section]['is_snp'] + missing_section_keys = REQUIRED_SECTION_KEYS - set(record[section].keys()) + if missing_section_keys: + missing = ", ".join(sorted(missing_section_keys)) + raise ValueError(f"Record {idx}, section {section} missing key(s): {missing}") - # Set all b-allele frequencies to NaN if not SNPs - b_allele_freq = [freq if is_snp_val else float('nan') for freq, is_snp_val in zip(b_allele_freq, is_snp)] +def build_hover_text(section, positions, states, log2_ratio, is_snp, b_allele_freq, population_freq): + hover_text = [] + for i, position in enumerate(positions): if section == "sv": - # is_snp = sv[section]['is_snp'] - states = sv[section]['states'] - state_colors = [state_colors_dict[str(state)] for state in states] - marker_symbols = ['circle' if is_snp_val else 'circle-open' for is_snp_val in is_snp] - - # Set the hover text - hover_text = [] - for i, position in enumerate(positions): - # Add hover text for each point - hover_text.append( - f"Position: {position}
" - f"State: {states[i]}
" - f"Log2 Ratio: {log2_ratio[i]}
" - f"SNP: {is_snp[i]}
" - f"BAF: {b_allele_freq[i]}
" - f"Population Frequency: {population_freq[i]}
" - ) + hover_text.append( + f"Position: {position}
" + f"State: {states[i]}
" + f"Log2 Ratio: {log2_ratio[i]}
" + f"SNP: {is_snp[i]}
" + f"BAF: {b_allele_freq[i]}
" + f"Population Frequency: {population_freq[i]}
" + ) else: - # is_snp = sv[section]['is_snp'] - state_colors = ['black'] * len(positions) - # marker_symbols = ['circle-open'] * len(positions) - marker_symbols = ['circle' if is_snp_val else 'circle-open' for is_snp_val in is_snp] - hover_text = [] - for i, position in enumerate(positions): - # Add hover text for each point - hover_text.append( - f"Position: {position}
" - f"Log2 Ratio: {log2_ratio[i]}
" - f"BAF: {b_allele_freq[i]}
" - f"Population Frequency: {population_freq[i]}
" - ) - - # Create the log2 trace - log2_trace = plotly.graph_objs.Scatter( - x=positions, - y=log2_ratio, - mode='markers+lines', - name=r'Log2 Ratio', - text=hover_text, - hoverinfo='text', - marker=dict( - color=state_colors, - size=5, - symbol=marker_symbols, - ), - line=dict( - color='black', - width=0 - ), - showlegend=False - ) + hover_text.append( + f"Position: {position}
" + f"Log2 Ratio: {log2_ratio[i]}
" + f"BAF: {b_allele_freq[i]}
" + f"Population Frequency: {population_freq[i]}
" + ) + return hover_text - # Create the BAF trace - baf_trace = plotly.graph_objs.Scatter( - x=positions, - y=b_allele_freq, - mode='markers+lines', - name='B-Allele Frequency', - text=hover_text, - hoverinfo='text', - marker=dict( - color=state_colors, - size=5, - symbol=marker_symbols, - ), - line=dict( - color='black', - width=0 - ), - showlegend=False - ) - if section == "sv": - # Create a shaded rectangle for the CNV, layering it below the CNV - # trace and labeling it with the CNV type. - fig.add_vrect( - x0 = start, - x1 = end, - fillcolor = "Black", - layer = "below", - line_width = 0, - opacity = 0.1, - annotation_text = '', - annotation_position = "top left", - annotation_font_size = 20, - annotation_font_color = "black" - ) +def add_section_traces(fig, record, section, start, end): + positions = record[section]["positions"] + b_allele_freq = record[section]["b_allele_freq"] + population_freq = record[section]["population_freq"] + log2_ratio = record[section]["log2_ratio"] + is_snp = record[section]["is_snp"] - # Add vertical lines at the start and end positions of the CNV. - fig.add_vline( - x = start, - line_width = 2, - line_color = "black", - layer = "below" - ) + b_allele_freq = [freq if snp_flag else float("nan") for freq, snp_flag in zip(b_allele_freq, is_snp)] - fig.add_vline( - x = end, - line_width = 2, - line_color = "black", - layer = "below" - ) + if section == "sv": + states = record[section].get("states", ["NA"] * len(positions)) + state_colors = [STATE_COLORS.get(str(state), "black") for state in states] + marker_symbols = ["circle" if snp_flag else "circle-open" for snp_flag in is_snp] + else: + states = ["NA"] * len(positions) + state_colors = ["black"] * len(positions) + marker_symbols = ["circle" if snp_flag else "circle-open" for snp_flag in is_snp] - # Add traces to the figure - fig.append_trace(log2_trace, row=1, col=1) - fig.append_trace(baf_trace, row=2, col=1) - - # Set the x-axis title. - fig.update_xaxes( - title_text = "Chromosome Position", - row = 2, - col = 1 - ) + hover_text = build_hover_text(section, positions, states, log2_ratio, is_snp, b_allele_freq, population_freq) - # Set the y-axis titles. - fig.update_yaxes( - title_text = r"Log2 Ratio", - row = 1, - col = 1 - ) + import plotly - fig.update_yaxes( - title_text = "B-Allele Frequency", - row = 2, - col = 1 + log2_trace = plotly.graph_objs.Scatter( + x=positions, + y=log2_ratio, + mode="markers+lines", + name="Log2 Ratio", + text=hover_text, + hoverinfo="text", + marker={"color": state_colors, "size": MARKER_SIZE, "symbol": marker_symbols}, + line={"color": "black", "width": 0}, + showlegend=False, ) - # Set the Y-axis range for the log2 ratio plot. - fig.update_yaxes( - range = [-2.0, 2.0], - row = 1, - col = 1 + baf_trace = plotly.graph_objs.Scatter( + x=positions, + y=b_allele_freq, + mode="markers+lines", + name="B-Allele Frequency", + text=hover_text, + hoverinfo="text", + marker={"color": state_colors, "size": MARKER_SIZE, "symbol": marker_symbols}, + line={"color": "black", "width": 0}, + showlegend=False, ) - # Set the Y-axis range for the BAF plot. - fig.update_yaxes( - range = [-0.2, 1.2], - row = 2, - col = 1 + if section == "sv": + fig.add_vrect(x0=start, x1=end, fillcolor="black", layer="below", line_width=0, opacity=0.1) + fig.add_vline(x=start, line_width=2, line_color="black", layer="below") + fig.add_vline(x=end, line_width=2, line_color="black", layer="below") + + fig.append_trace(log2_trace, row=1, col=1) + fig.append_trace(baf_trace, row=2, col=1) + + +def build_figure(record, width, height): + from plotly.subplots import make_subplots + + chromosome = record["chromosome"] + start = record["start"] + end = record["end"] + sv_type = record["sv_type"] + sv_length = record["size"] + + fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + vertical_spacing=0.05, + subplot_titles=("SNP Log2 Ratio", "SNP B-Allele Frequency"), ) - # Set the title of the plot. + for section in ["before_sv", "sv", "after_sv"]: + add_section_traces(fig, record, section, start, end) + + fig.update_xaxes(title_text="Chromosome Position", row=2, col=1) + fig.update_yaxes(title_text="Log2 Ratio", range=[-2.0, 2.0], row=1, col=1) + fig.update_yaxes(title_text="B-Allele Frequency", range=[-0.2, 1.2], row=2, col=1) + + title_label = SV_TYPE_LABELS.get(sv_type, sv_type) fig.update_layout( - title_text = f"{sv_type_dict[sv_type]} at {chromosome}:{start}-{end} ({sv_length} bp)", - title_x = 0.5, - showlegend = False, + title_text=f"{title_label} at {chromosome}:{start}-{end} ({sv_length} bp)", + title_x=0.5, + showlegend=False, + template="simple_white", + font={"family": "Arial", "size": 20, "color": "black"}, + width=width, + height=height, + margin={"l": 100, "r": 30, "t": 120, "b": 90}, + ) + fig.update_xaxes(showline=True, linewidth=2, linecolor="black", mirror=True, ticks="outside") + fig.update_yaxes(showline=True, linewidth=2, linecolor="black", mirror=True, ticks="outside") + return fig + + +def write_outputs(fig, base_name, output_dir, formats, width, height, scale): + if "html" in formats: + html_path = os.path.join(output_dir, f"{base_name}.html") + fig.write_html(html_path) + print(f"Plot saved as {html_path}") + + requested_static_formats = [fmt for fmt in formats if fmt in STATIC_FORMATS] + if requested_static_formats: + try: + for fmt in requested_static_formats: + out_path = os.path.join(output_dir, f"{base_name}.{fmt}") + fig.write_image(out_path, format=fmt, width=width, height=height, scale=scale) + print(f"Plot saved as {out_path}") + except ValueError as err: + print("Static image export requires Kaleido. Install with: pip install -U kaleido") + raise err + + +def main(): + args = parse_args() + try: + import plotly # noqa: F401 + except ImportError as err: + raise ImportError( + "Missing required dependency 'plotly'. Install with: conda install -c conda-forge plotly" + ) from err + + formats = parse_formats(args.formats) + records = load_json_records(args.json_file) + + output_dir = args.output_dir if args.output_dir else os.path.join(os.getcwd(), "CNV_Plots") + os.makedirs(output_dir, exist_ok=True) + + skip_count_chrom = 0 + skip_count_length = 0 + save_count = 0 + + for idx, record in enumerate(records, start=1): + validate_record(record, idx) + + if args.chromosome and record["chromosome"] != args.chromosome: + skip_count_chrom += 1 + continue + + if abs(record["size"]) < args.min_sv_length: + skip_count_length += 1 + continue + + fig = build_figure(record, args.width, args.height) + + sv_length = record["size"] + svlen_kb = sv_length // 1000 + base_name = ( + f"SV_{record['chromosome']}_{record['start']}_{record['end']}_" + f"{record['sv_type']}_{svlen_kb}kb" + ) + + write_outputs(fig, base_name, output_dir, formats, args.width, args.height, args.scale) + save_count += 1 + + print( + f"Finished processing {save_count} SVs. " + f"Skipped {skip_count_chrom} SVs due to chromosome filter and " + f"{skip_count_length} SVs due to length filter." ) - # height = 800, - # width = 800 - # ) - # Save the plot to an HTML file (use a unique filename per SV) - # Use the input filepath directory as the output directory - output_dir = os.path.dirname(args.json_file) - svlen_kb = sv_length // 1000 - file_name = f"SV_{chromosome}_{start}_{end}_{sv_type}_{svlen_kb}kb.html" - file_path = os.path.join(output_dir, file_name) - fig.write_html(file_path) - print(f"Plot saved as {file_path}") + + +if __name__ == "__main__": + main() diff --git a/python/environment_merge.yml b/python/environment_merge.yml deleted file mode 100644 index f5b1941c..00000000 --- a/python/environment_merge.yml +++ /dev/null @@ -1,15 +0,0 @@ -name: contextsvmerge -channels: - - bioconda - - anaconda - - conda-forge - - defaults -dependencies: - - python - - numpy - - pytest - - plotly - - pandas - - scikit-learn>=1.3 - -# conda env create --name contextsvmerge --file environment.yml diff --git a/python/extract_features.py b/python/extract_features.py deleted file mode 100644 index 37dbcb24..00000000 --- a/python/extract_features.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -extract_features.py: Extract features from the input VCF file. - -Usage: - extract_features.py - -Arguments: - Path to the input VCF file. - -Output: - A dataframe with a column for each feature. -""" - -import os -import sys -import logging -import numpy as np -import pandas as pd - - -def read_vcf(filepath): - """Read in the VCF file.""" - vcf_df = pd.read_csv(filepath, sep='\t', comment='#', header=None, usecols=[0, 1, 7], \ - names=['CHROM', 'POS', 'INFO'], \ - dtype={'CHROM': str, 'POS': np.int64, 'INFO': str}) - return vcf_df - -def extract_features(input_vcf): - """Extract the features from the VCF file's data.""" - # Read in the VCF file. - vcf_df = read_vcf(input_vcf) - - # Extract the read and clipped base support from the INFO column. - read_support = vcf_df['INFO'].str.extract(r'SUPPORT=(\d+)', expand=False).astype(np.int32) - - # Check if any read depths are missing. - if read_support.isnull().values.any(): - logging.error('Read support is missing.') - sys.exit(1) - - clipped_bases = vcf_df['INFO'].str.extract(r'CLIPSUP=(\d+)', expand=False).astype(np.int32) - - # Check if any clipped bases are missing. - if clipped_bases.isnull().values.any(): - logging.error('Clipped bases is missing.') - sys.exit(1) - - # Get the array of chromosome names. - chrom = vcf_df['CHROM'] - - # Create a key to map the chromosome names to a unique integer. - - # First, get all unique chromosome names. - chrom_unique = chrom.unique() - - # Next, create a dictionary to map the chromosome names to integers. - chrom_dict = {chrom: i for i, chrom in enumerate(chrom_unique)} - - # Finally, map the chromosome names to integers. - chrom = chrom.map(chrom_dict) - - - # Check if any chromosome names are missing. - if chrom.isnull().values.any(): - logging.error('Chromosome name is missing.') - sys.exit(1) - else: - # Print space-separated chromosome names. - logging.info('Chromosomes: ' + ' '.join(chrom.unique().astype(str))) - - # Get the start and end positions. - start = vcf_df['POS'] - - # Check if any start positions are missing. - if start.isnull().values.any(): - logging.error('Start position is missing.') - sys.exit(1) - - # Get the SV length from the INFO column. - sv_length = vcf_df['INFO'].str.extract(r'SVLEN=(-?\d+)', expand=False).astype(np.int32) - - # Check if any SV lengths are missing. - if sv_length.isnull().values.any(): - logging.error('SV length is missing.') - sys.exit(1) - - # Get the SV type from the INFO column. - sv_type = vcf_df['INFO'].str.extract(r'SVTYPE=(\w+)', expand=False) - - # If INFO/REPTYPE=DUP, then the SV type is a duplication. - sv_type[vcf_df['INFO'].str.contains('REPTYPE=DUP')] = 'DUP' - - # Convert the SV type to integers. - sv_type = sv_type.replace('DEL', '0') - sv_type = sv_type.replace('DUP', '1') - sv_type = sv_type.replace('INV', '2') - sv_type = sv_type.replace('INS', '3') - sv_type = sv_type.replace('BND', '4') - sv_type = sv_type.astype(np.int32) - - # Check if any SV types are missing. - if sv_type.isnull().values.any(): - logging.error('SV type is missing.') - sys.exit(1) - - # Loop through the columns and check if any values are missing for all of - # the feature arrays. - for col in [chrom, start, sv_length, sv_type, read_support, clipped_bases]: - if col.isnull().values.all(): - logging.error('All values are missing for a feature.') - logging.error(col) - sys.exit(1) - - # Print the first 4 rows of the features. - logging.info('Features:') - logging.info(pd.DataFrame({'chrom': chrom.head(4), 'start': start.head(4), 'sv_length': sv_length.head(4), \ - 'sv_type': sv_type.head(4), 'read_support': read_support.head(4), \ - 'clipped_bases': clipped_bases.head(4)})) - - # Check that all features have the same length. - if not all(len(col) == len(chrom) for col in [start, sv_length, sv_type, read_support, clipped_bases]): - logging.error('Features do not have the same length.') - - # Print the length of each feature. - logging.error('Chromosomes: ' + str(len(chrom))) - logging.error('Start positions: ' + str(len(start))) - logging.error('SV lengths: ' + str(len(sv_length))) - logging.error('SV types: ' + str(len(sv_type))) - logging.error('Read support: ' + str(len(read_support))) - logging.error('Clipped bases: ' + str(len(clipped_bases))) - - sys.exit(1) - - # Create a dataframe of the features. - features = pd.DataFrame({'chrom': chrom, 'start': start, 'sv_length': sv_length, 'sv_type': sv_type, \ - 'read_support': read_support, 'clipped_bases': clipped_bases}) - - # Check if any features are missing. - if features.isnull().values.any(): - logging.error('Features are missing.') - - # Get the rows with missing features. - missing_features = features[features.isnull().any(axis=1)] - - # Print the rows with missing features. - logging.error(missing_features) - sys.exit(1) - - # Return the features. - return features diff --git a/python/mendelian_error.py b/python/mendelian_error.py deleted file mode 100644 index 2cf69572..00000000 --- a/python/mendelian_error.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -mendelian_error.py: Compute the Mendelian error rate from the VCF files of a -father, mother, and son. - -Usage: - mendelian_error.py - -Arguments: - Path to the father's VCF file. - Path to the mother's VCF file. - Path to the son's VCF file. - -Output: - The Mendelian error rate (proportion of variants with Mendelian errors). - -Example: - mendelian_error.py father.vcf mother.vcf son.vcf -""" - -import sys -import logging -import numpy as np -import pandas as pd - -def get_genotype(sample): - """ - Parse the genotype (GT) field from the SAMPLE column of a VCF file. - """ - genotype = sample.split(':')[0] - - if genotype == './.': - return None - else: - return genotype - - -def compute_mendelian_error_rates(father_file, mother_file, child_file): - """ - Compute the Mendelian error rate from the VCF files of a father, mother, - and child. - """ - # Read the VCF files into pandas dataframes - father_df = pd.read_csv(father_file, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}, nrows=10000) - - mother_df = pd.read_csv(mother_file, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}, nrows=10000) - - child_df = pd.read_csv(child_file, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}, nrows=10000) - - # Parse the genotype (GT) fields and compute the Mendelian error rates - total_variants = len(child_df) - mendelian_errors = 0 - - for i in range(total_variants): - # Loop through the father's variants and compare with the mother's and - # child's variants - - # Get the current variant's location - chrom = child_df['CHROM'][i] - pos = child_df['POS'][i] - svlen = child_df['INFO'][i].split(';')[0].split('=')[1] - - #print(f"Chrom: {chrom}, Pos: {pos}, SVLEN: {svlen}") - - # Find the same variant in the mother's and father's VCF files - mother_df = mother_df[(mother_df['CHROM'] == chrom) & (mother_df['POS'] == pos)] - father_df = father_df[(father_df['CHROM'] == chrom) & (father_df['POS'] == pos)] - - # Check if the variant is present in the mother's and child's VCF files - if mother_df.empty or father_df.empty: - #logging.warning("Variant not found in mother's or child's VCF file") - continue - else: - print("Variant found in mother's and father's VCF file at %s:%d" % (chrom, pos)) - - # Get the samples - child_sample = child_df['SAMPLE'][i] - mother_sample = mother_df['SAMPLE'].values[0] - father_sample = father_df['SAMPLE'].values[0] - - # Get the genotypes - father_genotype = get_genotype(father_sample) - mother_genotype = get_genotype(mother_sample) - child_genotype = get_genotype(child_sample) - - # Skip if any of the genotypes are missing - if father_genotype is None or mother_genotype is None or child_genotype is None: - logging.warning("Missing genotype(s) for variant at %s:%d", chrom, pos) - continue - - # Print the genotypes - print(f"Father: {father_genotype}, Mother: {mother_genotype}, Child: {child_genotype}") - - # Mendelian error: Child's genotype is inconsistent with inheritance of - # exactly one allele from each parent. - # Scenario 1: Father and mother have the same genotype, but the child's - # genotype is different. - # Scenario 2: Father and mother have different genotypes, but the - # child's genotype is the same as one of the parents'. - # See Smolka et al. (2022) for more details (preprint for Sniffles2): - # https://www.biorxiv.org/content/10.1101/2022.04.04.487055v2.full - - # Scenario 1 - if father_genotype == mother_genotype and father_genotype != son_genotype: - mendelian_errors += 1 - - # Scenario 2 - if father_genotype != mother_genotype and (father_genotype == son_genotype or mother_genotype == son_genotype): - mendelian_errors += 1 - - mendelian_error_rate = mendelian_errors / total_variants - - return mendelian_error_rate - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO) - logging.info("Running mendelian_error.py") - if len(sys.argv) != 4: - logging.error("Incorrect number of arguments") - sys.exit(__doc__) - - father_file = sys.argv[1] - mother_file = sys.argv[2] - child_file = sys.argv[3] - me_rate = compute_mendelian_error_rates(father_file, mother_file, child_file) - logging.info("Mendelian error rate: %.4f", me_rate) diff --git a/python/mendelian_inheritance.py b/python/mendelian_inheritance.py deleted file mode 100644 index 128b1d1a..00000000 --- a/python/mendelian_inheritance.py +++ /dev/null @@ -1,78 +0,0 @@ -import csv -import sys - - -def read_tsv(file_path): - with open(file_path, 'r') as file: - reader = csv.reader(file, delimiter='\t') - return [row for row in reader] - -def calculate_mendelian_error(father_genotype, mother_genotype, child_genotype): - # Generate all possible child genotypes - child_genotypes = set() - for allele1 in father_genotype.split('/'): - for allele2 in mother_genotype.split('/'): - child_genotypes.add('/'.join(sorted([allele1, allele2]))) - - # Print the parent and child genotypes if invalid - if child_genotype not in child_genotypes: - print(f"ME: Father: {father_genotype}, Mother: {mother_genotype}, Child: {child_genotype}") - - # Check if the child genotype is valid - return 0 if child_genotype in child_genotypes else 1 - - -def main(father_file, mother_file, child_file): - father_records = read_tsv(father_file) - mother_records = read_tsv(mother_file) - child_records = read_tsv(child_file) - - if len(father_records) != len(mother_records) or len(father_records) != len(child_records): - raise ValueError("All files must have the same number of records") - - total_records = len(father_records) - error_count = 0 - - sv_type_dict = {} - sv_type_error_dict = {} - - for i in range(total_records): - father_genotype = father_records[i][5] - mother_genotype = mother_records[i][5] - child_genotype = child_records[i][5] - child_sv_type = child_records[i][2] - sv_type_dict[child_sv_type] = sv_type_dict.get(child_sv_type, 0) + 1 - - # Print SV size if error occurs - error_value = calculate_mendelian_error(father_genotype, mother_genotype, child_genotype) - if error_value == 1: - # print(f"SV size: {father_records[i][2]}") - sv_type_error_dict[child_sv_type] = sv_type_error_dict.get(child_sv_type, 0) + 1 - - error_count += error_value - # error_count += calculate_mendelian_error(father_genotype, mother_genotype, child_genotype) - - if total_records == 0: - error_rate = 0 - print("No records found") - else: - error_rate = error_count / total_records - - print(f"Mendelian Inheritance Error Rate: {error_rate:.2%} for {total_records} shared trio SVs") - - print("SV Type Distribution:") - for sv_type, count in sv_type_dict.items(): - error_count = sv_type_error_dict.get(sv_type, 0) - error_rate = error_count / count - print(f"{sv_type}: {error_rate:.2%} ({error_count}/{count})") - -if __name__ == "__main__": - if len(sys.argv) != 4: - print("Usage: python mendelian_inheritance.py ") - sys.exit(1) - - father_file = sys.argv[1] - mother_file = sys.argv[2] - child_file = sys.argv[3] - - main(father_file, mother_file, child_file) diff --git a/python/plot_distributions.py b/python/plot_distributions.py deleted file mode 100644 index c2644a8a..00000000 --- a/python/plot_distributions.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -plot_distributions.py: Plot the distributions of SV sizes in the input VCF file -and save the plot as a PNG file. - -Usage: - plot_distributions.py - -Arguments: - Path to the input VCF file. - Path to the output PNG file. - -Output: - A PNG file with the SV size distributions. - -Example: - python plot_distributions.py input.vcf output.png -""" - -import sys -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt - -# import plotly -import plotly.graph_objects as go - -def generate_sv_size_plot(input_vcf, output_png, plot_title="SV Caller"): - # Read VCF file into a pandas DataFrame - try: - vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}) - except Exception as e: - try: - print("[DEBUG] Caught TypeError") - # Truvari merged VCF format with different columns - vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE', 'SAMPLE2'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str, 'SAMPLE2': str}) - except Exception as e: - print("[DEBUG] Caught Exception") - # Platinum pedigree VCF format with different columns - vcf_df = pd.read_csv(input_vcf, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE', 'SAMPLE2', 'SAMPLE3', 'SAMPLE4', 'SAMPLE5', 'SAMPLE6', 'SAMPLE7'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE1': str, 'SAMPLE2': str, 'SAMPLE3': str, 'SAMPLE4': str, \ - 'SAMPLE5': str, 'SAMPLE6': str, 'SAMPLE7': str}) - - # Initialize dictionaries to store SV sizes for each type of SV - sv_sizes = {} - - # Iterate over each record in the VCF file - print("SV CALLER: ", plot_title) - for _, record in vcf_df.iterrows(): - - # Get the POS - pos = record['POS'] - - # Get the SV data by splitting semi-colon separated INFO field and - # extracting SVTYPE and SVLEN - info_fields = record['INFO'].split(';') - sv_type = None - sv_len = None # INFO/SVLEN - sv_span = None # INFO/END - POS - alignment = "NA" - for field in info_fields: - if field.startswith('SVTYPE='): - sv_type = field.split('=')[1] - elif field.startswith('SVLEN='): - sv_len = abs(int(field.split('=')[1])) - elif field.startswith('END='): - sv_span = int(field.split('=')[1]) - pos - elif field.startswith('ALN='): - alignment = field.split('=')[1] - - # Continue if SV type is BND (no SV size) - if sv_type == "BND": - continue - - # If the SV caller is DELLY, then we use the second SV size for non-INS - # (they don't have SVLEN) and the first SV size for INS - sv_size = None - if plot_title == "DELLY" and sv_type != "INS": - sv_size = sv_span - else: - sv_size = sv_len - - # If the plot title is GIAB, then we need to convert INS to DUP if - # INFO/SVTYPE is INS and INFO/REPTYPE is DUP - if "GIAB" in plot_title and sv_type == "INS": - if 'REPTYPE=DUP' in record['INFO']: - sv_type = "DUP" - - # Add the SV type if it's not in the dictionary - if sv_type not in sv_sizes: - sv_sizes[sv_type] = [] - - # Add the SV size to the dictionary - sv_sizes[sv_type].append(sv_size) - - # Create a tiled plot where each tile shows the SV size distribution for a - # different SV type - sv_type_count = len(sv_sizes) - fig, axes = plt.subplots(sv_type_count, 1, figsize=(10, 5 * sv_type_count)) - print(f'Number of SV types: {sv_type_count}') - - # Create a dictionary of SV types and their corresponding colors. - # From: https://davidmathlogic.com/colorblind/ - # WONG colors - sv_colors = {'DEL': '#E69F00', 'DUP': '#56B4E9', 'INV': '#009E73', 'INS': '#F0E442', 'INVDUP': '#D55E00', 'COMPLEX': '#CC79A7'} - - # Create a dictionary of SV types and their corresponding labels - sv_labels = {'DEL': 'Deletion', 'DUP': 'Duplication', 'INV': 'Inversion', 'INS': 'Insertion', 'INVDUP': 'Inverted Duplication', 'COMPLEX': 'Complex'} - - # Get the list of SV types and sort them in the order of the labels - sv_types = sorted(sv_sizes.keys(), key=lambda x: sv_labels[x]) - - # Print the number of SVs for each type, starting with the label - print("SV Caller: ", plot_title) - print("Total number of SVs: ", len(vcf_df)) - - print('Number of SVs for each type:') - total_sv_count = 0 - for sv_type in sv_types: - print(f'{sv_labels[sv_type]}: {len(sv_sizes[sv_type])}') - total_sv_count += len(sv_sizes[sv_type]) - - print(f'Total number of SVs (sum): {total_sv_count}') - - # Print the number of SVs for each type with size > 50kb - print('Number of SVs for each type with size > 50kb:') - for sv_type in sv_types: - print(f'{sv_labels[sv_type]}: {len([x for x in sv_sizes[sv_type] if abs(x) > 50000])}') - - # Summary statistics - all_sv_sizes = [] - for sv_type in sv_types: - all_sv_sizes.extend(sv_sizes[sv_type]) - print('Summary statistics:') - print(f'Minimum SV size: {min(all_sv_sizes)}') - print(f'Maximum SV size: {max(all_sv_sizes)}') - print(f'Mean SV size: {np.mean(all_sv_sizes)}') - print(f'Median SV size: {np.median(all_sv_sizes)}') - print(f'Standard deviation of SV sizes: {np.std(all_sv_sizes)}') - print(f'Number of SVs >10kb: {len([x for x in all_sv_sizes if abs(x) > 10000])}') - print(f'Number of SVs >50kb: {len([x for x in all_sv_sizes if abs(x) > 50000])}') - print(f'Number of SVs >100kb: {len([x for x in all_sv_sizes if abs(x) > 100000])}') - - # Plot the SV size distributions - size_scale = 1000 # Convert SV sizes from bp to kb. Use abs() to handle negative deletion sizes - for i, sv_type in enumerate(sv_types): - sizes = np.array(sv_sizes[sv_type]) - axes[i].hist(np.abs(sizes) / size_scale, bins=100, color=sv_colors[sv_type], alpha=0.7, label=sv_labels[sv_type], edgecolor='black') - axes[i].set_xlabel('SV size (kb)') - axes[i].set_ylabel('Frequency (log scale)') - axes[i].set_title(f'{plot_title}: {sv_labels[sv_type]}') - - # Use a log scale for the y-axis - axes[i].set_yscale('log') - - # In the same axis, plot a known duplication if within the range of the plot - # if sv_type == 'DUP': - # print("TEST: Found DUP") - # cnv_size = 776237 / size_scale - # x_min, x_max = axes[i].get_xlim() - # if cnv_size > x_min and cnv_size < x_max: - # axes[i].axvline(x=cnv_size, color='black', linestyle='--') - # else: - # # Print the values - # print(f'CNV size: {cnv_size}, x_min: {x_min}, x_max: {x_max}') - - # Refresh the plot - plt.draw() - - # Save the plot as a PNG file - plt.tight_layout() - plt.savefig(output_png) - - # Plot an additional plot with suffix _full.png that includes all SV types - # (using plotly to avoid overlapping histograms) - max_size = np.max(np.abs(all_sv_sizes)) - max_bin_edge = np.max([1000000, max_size]) # Set the maximum bin edge to 1Mb or the max size - bin_edges = [0, 1000, 5000, 10000, 50000, 100000, 500000, max_bin_edge] # Bin edges - bin_edges = np.array(bin_edges) / size_scale # Convert to kb - bin_labels = ['0-1kb', '1-5kb', '5-10kb', '10-50kb', '50-100kb', '100-500kb', '500kb+'] - x_values = np.arange(len(bin_edges) - 1) # x values for the histogram - - # Create histograms using the bin edges - fig = go.Figure() - for sv_type in sv_types: - sizes = np.array(np.abs(sv_sizes[sv_type])) / size_scale - - counts, _ = np.histogram(sizes, bins=bin_edges) - fig.add_trace(go.Bar(x=x_values, y=counts, name=sv_labels[sv_type], marker_color=sv_colors[sv_type])) - - - # Update the layout to group the bars side by side - fig.update_layout( - barmode='group', - title=f'{plot_title}: All SV types', - xaxis_title='SV size (kb)', - yaxis_title='Frequency (log scale)', - yaxis_type='log', - bargap=0.3, - ) - - # Add the bin edges to the x-axis ticks as a range - fig.update_xaxes(tickvals=x_values, ticktext=bin_labels) - - # Move the legend to the top right inside the plot - fig.update_layout(legend=dict( - orientation='v', - yanchor='top', - y=0.9, - xanchor='right', - x=0.9, - )) - # # Move the legend to the bottom right outside the plot - # fig.update_layout(legend=dict( - # orientation='v', - # yanchor='top', - # y=1.0, - # xanchor='right', - # x=1.15, - # )) - - # Set a larger font size for all text in the plot - fig.update_layout(font=dict(size=26)) - - # # Save the plot as a high-resolution PNG file for using in posters - fig.write_image(output_png.replace('.png', '_full.png'), width=1200, height=800) - print(f'Saved plot to {output_png.replace(".png", "_full.png")}') - - -if __name__ == '__main__': - # Get the input and output file paths from the command line arguments - input_file = sys.argv[1] - output_file = sys.argv[2] - plot_title = sys.argv[3] - - print(f'Input file: {input_file}') - print(f'Output file: {output_file}') - - # Generate the SV size plot - generate_sv_size_plot(input_file, output_file, plot_title=plot_title) diff --git a/python/plot_venn.py b/python/plot_venn.py deleted file mode 100644 index 757f4408..00000000 --- a/python/plot_venn.py +++ /dev/null @@ -1,48 +0,0 @@ -# from matplotlib_venn import venn3 -from matplotlib_venn import venn2 -import argparse - -import matplotlib.pyplot as plt - -def plot_venn(AB, Ab, aB, output, plot_title, title_Ab, title_aB): - plt.figure(figsize=(8, 8)) - - print('AB:', AB) - print('Ab:', Ab) - print('aB:', aB) - - # Create scaled subsets for the venn diagram - scaling_factor = 1000 - scaled_AB = AB / scaling_factor - scaled_Ab = Ab / scaling_factor - scaled_aB = aB / scaling_factor - - # Create a venn diagram scaled to the number of elements in each set - # venn = venn2(subsets=(AB, Ab, aB), set_labels=(title_Ab, title_aB)) - venn = venn2(subsets=(scaled_Ab, scaled_aB, scaled_AB), set_labels=(title_Ab, title_aB)) - - # Update the labels to reflect the actual counts - venn.get_label_by_id('10').set_text(str(Ab)) - venn.get_label_by_id('01').set_text(str(aB)) - venn.get_label_by_id('11').set_text(str(AB)) - - # Update the title - # plt.title("contextsv and " + title_aB + " venn diagram (all SV types)") - plt.title(plot_title) - plt.savefig(output) - plt.close() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Generate a Venn diagram.') - parser.add_argument('-a', type=int, required=True, help='Shared count') - parser.add_argument('-b', type=int, required=True, help='False positive count') - parser.add_argument('-c', type=int, required=True, help='False negative count') - parser.add_argument('-o', '--output', type=str, required=True, help='Output file path') - parser.add_argument('-a_title', type=str, required=True, help='Title for set A') - parser.add_argument('-b_title', type=str, required=True, help='Title for set B') - parser.add_argument('-c_title', type=str, required=True, help='Title for set C') - - args = parser.parse_args() - - plot_venn(args.a, args.b, args.c, args.output, args.a_title, args.b_title, args.c_title) - print(f'Venn diagram saved to {args.output}') diff --git a/python/predict.py b/python/predict.py deleted file mode 100644 index 70d186eb..00000000 --- a/python/predict.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -scoring_model.py: Score the structural variants using the binary classification -model. - -Usage: - scoring_model.py - -Arguments: - Path to the input VCF file. - Path to the model file. -""" - -import os -import sys -import logging -import numpy as np -import joblib -import pandas as pd - -import matplotlib.pyplot as plt - -from extract_features import extract_features - - -def score(model, input_vcf, output_vcf): - """Score the structural variants using the binary classification model. - - Args: - model (str): Path to the model file. - input_vcf (str): Path to the input VCF file. - output_vcf (str): Path to the output VCF file. - """ - # Load the model - clf = joblib.load(model) - - # Extract the features from the VCF file - X = extract_features(input_vcf) - - # Predict the labels and get the probabilities - y_pred = clf.predict_proba(X) - - # logging.info('Predicted labels:\n%s', y_pred) - - # Plot a histogram of the probabilities - plt.hist(y_pred[:, 1], bins=20) - plt.xlabel('Probability') - plt.ylabel('Count') - - # # Save the plot to the input VCF file's directory - # output_dir = os.path.dirname(output_vcf) - # output_filepath = os.path.join(output_dir, 'probabilities.png') - # plt.savefig(output_filepath) - # logging.info('Saved the plot of the probabilities to %s.', output_filepath) - - # Save the plot to the working directory - plt.savefig('output/probabilities.png') - - -if __name__ == '__main__': - - # Model file - model = sys.argv[1] - - # Input VCF file to score - input_vcf = sys.argv[2] - - # Output VCF file - output_vcf = sys.argv[3] - - # Set up the logger - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s %(levelname)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S') - - # Score the structural variants - score(model, input_vcf, output_vcf) - \ No newline at end of file diff --git a/python/score_vcf.py b/python/score_vcf.py deleted file mode 100644 index 1e805017..00000000 --- a/python/score_vcf.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -score_vcf.py - Score structural variants in a VCF file using a binary classification model. - -This script prioritizes structural variants in a VCF file by scoring them using -a binary classification model. The model is trained using a VCF file of true -positive structural variants and a VCF file of false positive structural -variants. The model is trained using the following features extracted from the -VCF files: chromosome, start position, structural variant length, structural -variant type, read support, and clipped bases. The model is a logistic -regression model. - -Usage: - python score_vcf.py - -Arguments: - model_path: str - Path to the trained model file. - vcf_filepath: str - Path to the VCF file to score. - -Example: - python score_vcf.py model.pkl structural_variants.vcf - -""" - -import os -import sys -import logging -import numpy as np -import joblib -import pandas as pd -from sklearn.linear_model import LogisticRegression -import matplotlib.pyplot as plt - -from extract_features import extract_features - - -# Set up the logger. -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - - -def score(model_path, vcf_filepath, output_vcf): - """Load the model and VCF file and score the structural variants.""" - # Load the VCF file. - logging.info('Extracting features from the VCF file.') - features = extract_features(vcf_filepath) - - # Load the model. - logging.info('Loading the model.') - model = joblib.load(model_path) - - # Score the structural variants. - logging.info('Scoring the structural variants.') - scores = model.predict_proba(features) - - # Plot a histogram of the scores. - logging.info('Plotting the distribution of scores.') - plt.hist(scores) - plt.xlabel('Score') - plt.ylabel('Frequency') - plt.title('Distribution of Scores') - - # Save the plot as a PNG file. - output_png = "scores.png" - plt.tight_layout() - plt.savefig(output_png) - logging.info('Saved the plot as %s.', output_png) - - -if __name__ == '__main__': - # Get the command line arguments. - if len(sys.argv) != 4: - logging.error('Usage: python score_vcf.py \n') - sys.exit(1) - - # Get the model path and VCF file path. - model_path = sys.argv[1] - vcf_filepath = sys.argv[2] - output_vcf = sys.argv[3] - - # Run the program. - score(model_path, vcf_filepath, output_vcf) - logging.info('done.') diff --git a/python/sv_merger.py b/python/sv_merger.py deleted file mode 100644 index 2f5cb94f..00000000 --- a/python/sv_merger.py +++ /dev/null @@ -1,396 +0,0 @@ -""" -sv_merger.py -Use DBSCAN to merge SVs with the same breakpoint. -Mode can be 'dbscan', 'gmm', or 'agglomerative'. -https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html - -Usage: python sv_merger.py -Output: .merged.vcf -""" - -import os, sys -import numpy as np -import pandas as pd - -import logging -logging.basicConfig(level=logging.INFO) - -import matplotlib.pyplot as plt # For plotting merge behavior - -# HDBSCAN clustering algorithm -from sklearn.cluster import HDBSCAN -from sklearn.cluster import DBSCAN - - -def plot_dbscan(breakpoints, chosen_breakpoints, filename='dbscan_clustering.png'): - """ - Plot the DBSCAN clustering behavior for SV breakpoints. - """ - # logging.info the filename - logging.info(f"Plotting DBSCAN clustering behavior to {filename}...") - - # logging.info all breakpoints - logging.info(f"Breakpoints:") - for i in range(breakpoints.shape[0]): - logging.info(f"Row {i+1} - Breakpoints: {breakpoints[i, :]}") - - # Remove the chosen breakpoints from the breakpoints array - breakpoints = np.delete(breakpoints, np.where(breakpoints == chosen_breakpoints), axis=0) - - # Plot the SV breakpoints as individual lines in each row, and the chosen - # SV breakpoint as a red line at the top - # Create a new figure - plt.close() - plt.clf() - plt.cla() - plt.figure(figsize=(10, 10)) - for i in range(breakpoints.shape[0]): - row = i+1 - plt.plot(breakpoints[i, :], [row, row], 'b-') - - plt.plot(chosen_breakpoints, [0, 0], 'r-') - logging.info(f"Chosen breakpoints: {chosen_breakpoints}") - - # Set plot labels - plt.title('DBSCAN Clustering Behavior') - plt.xlabel('Breakpoint Position') - plt.ylabel('SVs') - plt.legend() - - # Save the plot - plt.savefig(filename) - - -def update_support(record, cluster_size): - """ - Set the SUPPORT field in the INFO column of a VCF record to the cluster size. - """ - # Get the INFO column - info = record['INFO'] - - # Parse the INFO columns - info_fields = info.split(';') - - # Loop and update the SUPPORT field, while creating a new INFO string - updated_info = '' - for field in info_fields: - if field.startswith('SUPPORT='): - # Get the current SUPPORT field - previous_support = int(field.split('=')[1]) - - # Add the cluster size to the SUPPORT field - updated_info += f'SUPPORT={previous_support + cluster_size};' - # updated_info += f'SUPPORT={cluster_size};' - else: - updated_info += field + ';' # Append the field to the updated INFO - - # Update the INFO column - record['INFO'] = updated_info - - return record - -def weighted_score(sv_len, hmm_score, weight_hmm): - """ - Calculate a weighted score based on read support and HMM score. - """ - return (1 - weight_hmm) * sv_len + weight_hmm * hmm_score - -def cluster_breakpoints(vcf_df, sv_type, cluster_size_min): - """ - Cluster SV breakpoints using HDBSCAN. - """ - # Set up the output DataFrame - merged_records = pd.DataFrame(columns=['INDEX', 'CHROM', 'POS', 'INFO']) - - # Format the SV breakpoints - breakpoints = None - if sv_type == 'DEL': - sv_start = vcf_df['POS'].values - sv_end = vcf_df['INFO'].str.extract(r'END=(\d+)', expand=False).astype(np.int32) - - # Format the deletion breakpoints - breakpoints = np.column_stack((sv_start, sv_end)) - - elif sv_type == 'INS/DUP': - sv_start = vcf_df['POS'].values - sv_len = vcf_df['INFO'].str.extract(r'SVLEN=(-?\d+)', expand=False).astype(np.int32) - sv_end = sv_start + sv_len - 1 - - # Format the insertion and duplication breakpoints - breakpoints = np.column_stack((sv_start, sv_end)) - else: - logging.error("Invalid SV type: %s", sv_type) - sys.exit(1) - - # Get the combined SV read and clipped base support - sv_support = vcf_df['INFO'].str.extract(r'SUPPORT=(\d+)', expand=False).astype(np.int32) - sv_clipped_base_support = vcf_df['INFO'].str.extract(r'CLIPSUP=(\d+)', expand=False).astype(np.int32) - sv_support = sv_support + sv_clipped_base_support - - # Get the HMM likelihood scores - hmm_scores = vcf_df['INFO'].str.extract(r'HMM=(-?\d+\.?\d*)', expand=False).astype(float) - - # Set all 0 values to a low negative value - hmm_scores[hmm_scores == 0] = -1e-100 - # hmm_scores[hmm_scores == 0] = np.nan - - # Cluster SV breakpoints using HDBSCAN - cluster_labels = [] - - # dbscan = DBSCAN(eps=30000, min_samples=3) - - if len(breakpoints) == 1: - return merged_records - - logging.info("Clustering %d SV breakpoints with parameters: min_cluster_size=%d", len(breakpoints), cluster_size_min) - dbscan = HDBSCAN(min_cluster_size=cluster_size_min, min_samples=2) - if len(breakpoints) > 0: - logging.info("Clustering %d SV breakpoints...", len(breakpoints)) - cluster_labels = dbscan.fit_predict(breakpoints) - - logging.info("Label counts: %d", len(np.unique(cluster_labels))) - - - # Merge SVs with the same label - unique_labels = np.unique(cluster_labels) - #logging.info("Unique labels: %s", unique_labels) - - for label in unique_labels: - - # Skip label -1 (outliers) only if there are no other clusters - if label == -1 and len(unique_labels) > 1: - continue - - # Get the indices of SVs with the same label - idx = cluster_labels == label - - # Get HMM and read support values for the cluster - # max_score_idx = 0 # Default to the first SV in the cluster - cluster_hmm_scores = np.array(hmm_scores[idx]) - # cluster_depth_scores = np.array(sv_support[idx]) - cluster_sv_lengths = np.array(breakpoints[idx][:, 1] - breakpoints[idx][:, 0] + 1) - # max_hmm = None - # max_support = None - # max_hmm_idx = None - # max_support_idx = None - - # Find the maximum HMM score - # if len(np.unique(cluster_hmm_scores)) > 1: - # max_hmm_idx = np.nanargmax(cluster_hmm_scores) - # max_hmm = cluster_hmm_scores[max_hmm_idx] - - # Find the maximum read alignment and clipped base support - # if len(np.unique(cluster_depth_scores)) > 1: - # max_support_idx = np.argmax(cluster_depth_scores) - # max_support = cluster_depth_scores[max_support_idx] - - # Normalize the HMM scores. Since the HMM scores are negative (log lh), we - # normalize them to the range [0, 1] by subtracting the minimum value - cluster_hmm_norm = (cluster_hmm_scores - np.min(cluster_hmm_scores)) / (np.max(cluster_hmm_scores) - np.min(cluster_hmm_scores)) - - # Normalize the SV lengths to the range [0, 1] - cluster_sv_lengths_norm = (cluster_sv_lengths - np.min(cluster_sv_lengths)) / (np.max(cluster_sv_lengths) - np.min(cluster_sv_lengths)) - - # Use a weighted approach to choose the best SV based on HMM and - # support. Deletions have higher priority for HMM scores, while - # insertions and duplications have higher priority for read alignment - # support. - # hmm_weight = 0.7 if sv_type == 'DEL' else 0.3 - hmm_weight = 0.5 - max_score_idx = 0 # Default to the first SV in the cluster - max_score = weighted_score(cluster_hmm_norm[max_score_idx], cluster_sv_lengths_norm[max_score_idx], hmm_weight) - # max_score = weighted_score(cluster_sv_lengths[max_score_idx], cluster_hmm_scores[max_score_idx], hmm_weight) - for k, hmm_norm in enumerate(cluster_hmm_norm): - svlen_norm = cluster_sv_lengths_norm[k] - score = weighted_score(svlen_norm, hmm_norm, hmm_weight) - if score > max_score: - max_score = score - max_score_idx = k - - # Get the VCF record with the highest score - max_record = vcf_df.iloc[idx, :].iloc[max_score_idx, :] - - # # For deletions, choose the SV with the highest HMM score if available - # if sv_type == 'DEL': - # if max_hmm is not None: - # max_score_idx = max_hmm_idx - # elif max_support is not None: - # max_score_idx = max_support_idx - - # # For insertions and duplications, choose the SV with the highest read - # # support if available - # elif sv_type == 'INS/DUP': - # if max_support is not None: - # max_score_idx = max_support_idx - # elif max_hmm is not None: - # max_score_idx = max_hmm_idx - - # Get the VCF record with the highest depth score - # max_record = vcf_df.iloc[idx, :].iloc[max_score_idx, :] - - # Get the number of SVs in this cluster - cluster_size = np.sum(idx) - # logging.info("DEL Cluster size: %s", cluster_size) - - # Update the SUPPORT field in the INFO column - max_record = update_support(max_record, cluster_size) - # pos_values = breakpoints[idx][:, 0] - - # Append the chosen record to the dataframe of records that will - # form the merged VCF file - merged_records.loc[merged_records.shape[0]] = max_record - - return merged_records - -def sv_merger(vcf_file_path, cluster_size_min=3, suffix='.merged'): - """ - Use DBSCAN to merge SVs with the same breakpoint. - Mode can be 'dbscan', 'gmm', or 'agglomerative'. - https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html - """ - - logging.info("Merging SVs in %s using HDBSCAN with minimum cluster size=%d...", vcf_file_path, cluster_size_min) - - # Read VCF file into a pandas DataFrame, using only CHROM, POS, and INFO - # columns - logging.info("Reading VCF file into a pandas DataFrame...") - vcf_df = pd.read_csv(vcf_file_path, sep='\t', comment='#', header=None, usecols=[0, 1, 7], \ - names=['CHROM', 'POS', 'INFO'], \ - dtype={'CHROM': str, 'POS': np.int64, 'INFO': str}) - - # Add a column at the beginning with the index - vcf_df.insert(0, 'INDEX', range(0, len(vcf_df))) - logging.info("Reading complete.") - - # Print total number of records - logging.info("Total number of records: %d", vcf_df.shape[0]) - - # Store a dataframe of records that will form the merged VCF file - merged_records = pd.DataFrame(columns=['INDEX', 'CHROM', 'POS', 'INFO']) - - # Create a set with each chromosome in the VCF file - chromosomes = set(vcf_df['CHROM'].values) - - # [TEST] Use only chromosome 5 - # chromosomes = ['chr5'] - - # Iterate over each chromosome - records_processed = 0 - current_chromosome = 0 - chromosome_count = len(chromosomes) - for chromosome in chromosomes: - - # Cluster deletions - logging.info("Clustering deletions on chromosome %s...", chromosome) - chr_del_df = vcf_df[(vcf_df['CHROM'] == chromosome) & (vcf_df['INFO'].str.contains('SVTYPE=DEL'))] - del_records = cluster_breakpoints(chr_del_df, 'DEL', cluster_size_min) - del chr_del_df - - # Cluster insertions and duplications - logging.info("Clustering all other SVs on chromosome %s...", chromosome) - # chr_ins_dup_df = vcf_df[(vcf_df['CHROM'] == chromosome) & - # ((vcf_df['INFO'].str.contains('SVTYPE=INS')) | - # (vcf_df['INFO'].str.contains('SVTYPE=DUP')))] - chr_non_del_df = vcf_df[(vcf_df['CHROM'] == chromosome) & (~vcf_df['INFO'].str.contains('SVTYPE=DEL'))] - ins_dup_records = cluster_breakpoints(chr_non_del_df, 'INS/DUP', cluster_size_min) - del chr_non_del_df - - # Summarize the number of deletions and insertions/duplications - del_count = del_records.shape[0] - ins_dup_count = ins_dup_records.shape[0] - records_processed += del_count + ins_dup_count - logging.info("Chromosome %s - %d deletions, %d other types merged.", chromosome, del_count, ins_dup_count) - - # Append the deletion and insertion/duplication records to the merged - # records DataFrame - merged_records = pd.concat([merged_records, del_records, ins_dup_records], ignore_index=True) - - current_chromosome += 1 - logging.info("Processed %d of %d chromosomes.", current_chromosome, chromosome_count) - - logging.info("Processed %d records of %d total records.", records_processed, vcf_df.shape[0]) - - # Free up memory - del vcf_df - - # Open a new VCF file for writing - logging.info("Writing merged VCF file...") - merged_vcf = os.path.splitext(vcf_file_path)[0] + suffix + '.vcf' - - total_records = merged_records.shape[0] - logging.info("Writing %d records to merged VCF file...", total_records) - - merge_count = 0 - index_start = 0 - with open(merged_vcf, 'w', encoding='utf-8') as merged_vcf_file: - - # Write the VCF header to the merged VCF file - with open(vcf_file_path, 'r', encoding='utf-8') as vcf_file: - for line in vcf_file: - if line.startswith('#'): - merged_vcf_file.write(line) - else: - break - - # Read the next 1000 records from the original VCF file - logging.info("Reading a chunk of 1000 records from the original VCF file...") - for chunk in pd.read_csv(vcf_file_path, sep='\t', comment='#', header=None, \ - names=['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'SAMPLE'], \ - dtype={'CHROM': str, 'POS': np.int64, 'ID': str, 'REF': str, 'ALT': str, 'QUAL': str, \ - 'FILTER': str, 'INFO': str, 'FORMAT': str, 'SAMPLE': str}, \ - chunksize=1000): - - # Add an INDEX column to the chunk - chunk.insert(0, 'INDEX', range(index_start, index_start + chunk.shape[0])) - index_start += chunk.shape[0] - - # Merge on INDEX, and use all information from the original VCF file - # (chunk) but update the INFO field with the merged INFO field. - # This is done by dropping the INFO column from the chunk so that - # the INFO column from the merged_records dataframe is used. - matching_records = pd.merge(chunk.drop(columns=['INFO']), merged_records[['INDEX', 'INFO']], on=['INDEX'], how='inner') - matching_records = matching_records.drop_duplicates(subset=['INDEX']) # Drop duplicate records - matching_records = matching_records.drop(columns=['INDEX']) # Drop the INDEX column - - # Remove the matching records from the merged records dataframe - merged_records = merged_records[~merged_records.isin(matching_records)].dropna() - - # Write the matching records to the merged VCF file - for _, matching_record in matching_records.iterrows(): - merge_count += 1 - merged_vcf_file.write(f"{matching_record['CHROM']}\t{matching_record['POS']}\t{matching_record['ID']}\t{matching_record['REF']}\t{matching_record['ALT']}\t{matching_record['QUAL']}\t{matching_record['FILTER']}\t{matching_record['INFO']}\t{matching_record['FORMAT']}\t{matching_record['SAMPLE']}\n") - - logging.info("Wrote %d of %d total records to merged VCF file...", merge_count, total_records) - - logging.info("Merged VCF file written to %s", merged_vcf) - - return merged_vcf - -if __name__ == '__main__': - import sys - if len(sys.argv) < 2: - logging.info("Usage: %s ", sys.argv[0]) - sys.exit(1) - - # Get the VCF file path from the command line - vcf_file_path = sys.argv[1] - - # Check if the file exists - if not os.path.exists(vcf_file_path): - logging.error("Error: %s not found.", vcf_file_path) - sys.exit(1) - - # Get the minimum cluster size from the command line - if len(sys.argv) > 2: - cluster_size_min = int(sys.argv[2]) - else: - cluster_size_min = 2 - - # Get the suffix from the command line - suffix = '.merged' - if len(sys.argv) > 3: - suffix += sys.argv[3] - - # DBSCAN - sv_merger(vcf_file_path, cluster_size_min=cluster_size_min, suffix=suffix) - diff --git a/python/train_model.py b/python/train_model.py deleted file mode 100644 index 1e161749..00000000 --- a/python/train_model.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -train_model.py - Train the binary classification model. - -This script trains the binary classification model using the true positive and -false positive data. The true positive data is obtained from a benchmarking -dataset. The false positive data is obtained from running the caller on data -that is known to be negative for SVs. This data can be obtained by running the -caller on a normal sample with known SVs accounted for in the reference genome. - -For example for HG002, the true positive data is obtained from the Genome in a -Bottle benchmarking dataset, and the false positive data is obtained from -running the caller on the HG002 normal sample and extracting the SV calls that -are not in the benchmarking dataset. This can be repeated for other samples such -as HG001 and HG005 as long as the known SVs are accounted for. - -In the HG002 SV v0.6 dataset, there are low-confidence regions which -are excluded from the true positive data. Thus, we must include true SVs from -other publicly available normal samples with information from complex regions, -such as those aligned to CHM13. - -The model is trained using logistic regression. The features are the LRR and -BAF values. The labels are 1 for true positives and 0 for false positives. - -The model is saved to the output directory as a pickle file. - -Usage: - python train_model.py - - - true_positives_filepath: Path to the VCF of true positive SV calls obtained - from a benchmarking dataset. - false_positives_filepath: Path to the VCF of false positive SV calls - obtained from running the caller on data that is known to be negative - for SVs. This data can be obtained by running the caller on a normal - sample with known SVs accounted for in the reference genome. - - output_directory: Path to the output directory. - -Output: - model.pkl: The binary classification model. - -Example: - python train_model.py data/sv_scoring_dataset/true_positives.vcf - sv_scoring_dataset/false_positives.vcf data/sv_scoring_dataset/model -""" - -import os -import sys -import logging -import numpy as np -import joblib -import pandas as pd -from sklearn.linear_model import LogisticRegression -import matplotlib.pyplot as plt - -from extract_features import extract_features - -# Set up the logger. -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - - -def train(true_positives_filepath, false_positives_filepath): - """Train the binary classification model.""" - - # Extract the features from the VCF files. - logging.info('Extracting features from the true positive VCF file.') - tp_data = extract_features(true_positives_filepath) - - # Check if any features are missing. - if tp_data.isnull().values.any(): - logging.error('Features are missing.') - - # Get the rows with missing features. - missing_features = tp_data[tp_data.isnull().any(axis=1)] - - # Print the rows with missing features. - logging.error(missing_features) - sys.exit(1) - - logging.info('Extracting features from the false positive VCF file.') - fp_data = extract_features(false_positives_filepath) - - # Check if any features are missing. - if fp_data.isnull().values.any(): - logging.error('Features are missing.') - - # Get the rows with missing features. - missing_features = fp_data[fp_data.isnull().any(axis=1)] - - # Print the rows with missing features. - logging.error(missing_features) - sys.exit(1) - - # Add the labels. - tp_data['label'] = 1 - fp_data['label'] = 0 - - # Print the number of true positives and false positives. - logging.info('Number of true labels: %d', tp_data.shape[0]) - logging.info('Number of false labels: %d', fp_data.shape[0]) - - # Combine the true positive and false positive data. - data = pd.concat([tp_data, fp_data]) - - # Get the features and labels. - features = data[["chrom", "start", "sv_length", "sv_type", "read_support", "clipped_bases"]] - labels = data["label"] - - # Check if any features are missing. - if features.isnull().values.any(): - logging.error('Features are missing.') - - # Get the rows with missing features. - missing_features = features[features.isnull().any(axis=1)] - - # Print the rows with missing features. - logging.error(missing_features) - sys.exit(1) - - # Check if any labels are missing. - if labels.isnull().values.any(): - logging.error('Labels are missing.') - sys.exit(1) - - # Train the model. - model = LogisticRegression() - model.fit(features, labels) - - # Return the model. - return model - -# Run the program. -def run(true_positives_filepath, false_positives_filepath, output_directory): - """Run the program.""" - # Train the model. - model = train(true_positives_filepath, false_positives_filepath) - - # Create the output directory if it does not exist. - if not os.path.exists(output_directory): - os.makedirs(output_directory) - - # Save the model - model_path = os.path.join(output_directory, "model.pkl") - joblib.dump(model, model_path) - - # Print the model. - print(model) - - # Return the model. - # return model - - -if __name__ == '__main__': - # Get the command line arguments. - if len(sys.argv) != 4: - logging.error('Usage: python train_model.py \n') - sys.exit(1) - - # Input VCF of true positive SV calls obtained from a benchmarking dataset. - tp_filepath = sys.argv[1] - - # Input VCF of false positive SV calls obtained from running the caller on - # data that is known to be negative for SVs. This data can be obtained by - # running the caller on a normal sample with known SVs accounted for in the - # reference genome. - fp_filepath = sys.argv[2] - output_dir = sys.argv[3] - - # Run the program. - logging.info('Training the model...') - run(tp_filepath, fp_filepath, output_dir) - logging.info('done.') diff --git a/python/utils.py b/python/utils.py deleted file mode 100644 index 9176dc9f..00000000 --- a/python/utils.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Utility functions for genome data analysis.""" - -def parse_region(region): - """Parse a region string into its chromosome and start and end positions.""" - region_parts = region.split(":") - chromosome = str(region_parts[0]) - - try: - start_position = int(region_parts[1].split("-")[0]) - end_position = int(region_parts[1].split("-")[1]) - except IndexError: - start_position, end_position = None, None - - return chromosome, start_position, end_position - -def get_info_field_column(vcf_data): - """Return the column index of the INFO field in a VCF file.""" - index = vcf_data.apply(lambda col: col.astype(str).str.contains("SVTYPE=").any(), axis=0).idxmax() - return index - -def get_info_field_value(info_field, field_name): - """ - Get the value of a field in the INFO field of a VCF file. - - Args: - info_field (str): The INFO field. - field_name (str): The name of the field to get the value of. - - Returns: - str: The value of the field. - """ - - # Split the INFO field into its parts. - info_field_parts = info_field.split(";") - - # Get the field value. - field_value = "" - for info_field_part in info_field_parts: - if info_field_part.startswith("{}=".format(field_name)): - field_value = info_field_part.split("=")[1] - break - - # Return the field value. - return field_value diff --git a/src/cnv_caller.cpp b/src/cnv_caller.cpp index c91d1d31..35bdbb73 100644 --- a/src/cnv_caller.cpp +++ b/src/cnv_caller.cpp @@ -19,6 +19,7 @@ #include #include // Progress bar #include // std::iota +#include #include #include #include @@ -52,8 +53,8 @@ void CNVCaller::runViterbi(const CHMM& hmm, SNPData& snp_data, std::pair& pos_depth_map, double mean_chr_cov, SNPData& snp_data, const InputData& input_data) const { - // Initialize the SNP data with default values and sample size length - int sample_size = input_data.getSampleSize(); + // Initialize SNP sampling using recommended fixed sample size + int sample_size = 20; std::vector snp_pos; std::unordered_map snp_baf_map; std::unordered_map snp_pfb_map; @@ -61,8 +62,10 @@ void CNVCaller::querySNPRegion(std::string chr, uint32_t start_pos, uint32_t end this->readSNPAlleleFrequencies(chr, start_pos, end_pos, snp_pos, snp_baf_map, snp_pfb_map, input_data); // Get the log2 ratio for evenly spaced positions in the - // region - sample_size = std::max((int) snp_pos.size(), sample_size); + // region. Scale sample size with region length to ensure sufficient + // observations for large SVs (minimum 1 observation per 1kb for better resolution) + int region_based_samples = (int)((end_pos - start_pos + 1) / 1000); + sample_size = std::max({(int) snp_pos.size(), sample_size, region_based_samples}); // Print an error if the end position is less than or equal to the start // position @@ -74,28 +77,47 @@ void CNVCaller::querySNPRegion(std::string chr, uint32_t start_pos, uint32_t end // Loop through evenly spaced positions in the region and get the log2 ratio double pos_step = static_cast(end_pos - start_pos + 1) / static_cast(sample_size); - std::unordered_map window_log2_map; + size_t depth_map_size = pos_depth_map.size(); // Cache size for bounds checking + + // Keep windows in deterministic genomic order and avoid string key conversions + std::vector window_starts; + std::vector window_ends; + std::vector window_log2; + window_starts.reserve(sample_size); + window_ends.reserve(sample_size); + window_log2.reserve(sample_size); + for (int i = 0; i < sample_size; i++) { - uint32_t window_start = (uint32_t) (start_pos + i * pos_step); - uint32_t window_end = (uint32_t) (start_pos + (i + 1) * pos_step); + uint32_t window_start = static_cast(start_pos + i * pos_step); + uint32_t window_end = static_cast(start_pos + (i + 1) * pos_step); + + if (window_start > end_pos) + { + window_start = end_pos; + } + if (window_end > end_pos) + { + window_end = end_pos; + } + if (window_end < window_start) + { + window_end = window_start; + } // Calculate the mean depth for the window double cov_sum = 0.0; int pos_count = 0; - for (int j = 0; j < pos_step; j++) + if (depth_map_size > 0 && window_start < depth_map_size) { - uint32_t pos = (uint32_t) (start_pos + i * pos_step + j); - if (pos > end_pos) + uint32_t bounded_end = std::min(window_end, static_cast(depth_map_size - 1)); + for (uint32_t pos = window_start; pos <= bounded_end; pos++) { - break; - } - if (pos < pos_depth_map.size()) { cov_sum += pos_depth_map[pos]; pos_count++; } - } + double log2_cov = 0.0; if (pos_count > 0) { @@ -104,12 +126,12 @@ void CNVCaller::querySNPRegion(std::string chr, uint32_t start_pos, uint32_t end // Use a small value to avoid division by zero cov_sum = 1e-9; } - log2_cov = log2((cov_sum / (double) pos_count) / mean_chr_cov); + log2_cov = log2((cov_sum / static_cast(pos_count)) / mean_chr_cov); } - // Store the log2 ratio for the window - std::string window_key = std::to_string(window_start) + "-" + std::to_string(window_end); - window_log2_map[window_key] = log2_cov; + window_starts.push_back(window_start); + window_ends.push_back(window_end); + window_log2.push_back(log2_cov); } // Create new vectors for the SNP data @@ -119,28 +141,62 @@ void CNVCaller::querySNPRegion(std::string chr, uint32_t start_pos, uint32_t end std::vector snp_log2_hmm; std::vector is_snp_hmm; - // Loop through the window ranges and append all SNPs in the range, using - // the log2 ratio for the window - for (const auto& window : window_log2_map) + size_t reserve_hint = std::max(static_cast(sample_size), snp_pos.size()); + snp_pos_hmm.reserve(reserve_hint); + snp_baf_hmm.reserve(reserve_hint); + snp_pfb_hmm.reserve(reserve_hint); + snp_log2_hmm.reserve(reserve_hint); + is_snp_hmm.reserve(reserve_hint); + + // Loop through the window ranges and append SNPs in each range, using + // the log2 ratio for the window. Use a two-pointer scan to avoid + // O(num_windows * num_snps) behavior. + size_t snp_idx = 0; + for (size_t w = 0; w < window_starts.size(); w++) { - uint32_t window_start = std::stoi(window.first.substr(0, window.first.find('-'))); - uint32_t window_end = std::stoi(window.first.substr(window.first.find('-') + 1)); - double log2_cov = window.second; + uint32_t window_start = window_starts[w]; + uint32_t window_end = window_ends[w]; + double log2_cov = window_log2[w]; + + while (snp_idx < snp_pos.size() && snp_pos[snp_idx] < window_start) + { + snp_idx++; + } - // Loop through the SNP positions and add them to the SNP data bool snp_found = false; - for (uint32_t pos : snp_pos) + size_t local_idx = snp_idx; + while (local_idx < snp_pos.size() && snp_pos[local_idx] <= window_end) { - if (pos >= window_start && pos <= window_end) + uint32_t pos = snp_pos[local_idx]; + double baf = -1.0; + double pfb = 0.5; + + auto baf_it = snp_baf_map.find(pos); + if (baf_it != snp_baf_map.end()) { - snp_pos_hmm.push_back(pos); - snp_baf_hmm.push_back(snp_baf_map[pos]); - snp_pfb_hmm.push_back(snp_pfb_map[pos]); - snp_log2_hmm.push_back(log2_cov); - is_snp_hmm.push_back(true); - snp_found = true; + baf = baf_it->second; } + + auto pfb_it = snp_pfb_map.find(pos); + if (pfb_it != snp_pfb_map.end()) + { + pfb = pfb_it->second; + } + + snp_pos_hmm.push_back(pos); + snp_baf_hmm.push_back(baf); + snp_pfb_hmm.push_back(pfb); + snp_log2_hmm.push_back(log2_cov); + is_snp_hmm.push_back(true); + snp_found = true; + local_idx++; } + + if (snp_found) + { + snp_idx = local_idx; + } + if (!snp_found) { // If no SNPs were found in the window, add a dummy SNP with the @@ -210,14 +266,23 @@ std::tuple CNVCaller::runCopyNumberPrediction(std std::vector& state_sequence = prediction.first; double likelihood = prediction.second; - // Get state percentages + // Get state percentages (single pass) + std::array state_counts = {0, 0, 0, 0, 0, 0, 0}; + for (int state : state_sequence) + { + if (state >= 1 && state <= 6) + { + state_counts[state]++; + } + } + std::unordered_map state_pct; - double state_count = (double) state_sequence.size(); + double state_count = static_cast(state_sequence.size()); double largest_non_neutral_pct = 0.0; int non_neutral_state = 0; for (int i = 0; i < 6; i++) { - state_pct[i+1] = (double)std::count(state_sequence.begin(), state_sequence.end(), i+1) / state_count; + state_pct[i+1] = static_cast(state_counts[i+1]) / state_count; if (i+1 != 3 && state_pct[i+1] > largest_non_neutral_pct) { largest_non_neutral_pct = state_pct[i+1]; @@ -226,7 +291,9 @@ std::tuple CNVCaller::runCopyNumberPrediction(std } // Use the state exceeding the threshold if non-neutral - double pct_threshold = 0.3; + // Adaptive threshold: regions >5kb are noisier due to coverage fragmentation + uint32_t region_length = end_pos - start_pos; + double pct_threshold = (region_length > 5000) ? 0.25 : 0.3; int max_state = 0; // Unknown state if (largest_non_neutral_pct > pct_threshold) { @@ -242,7 +309,7 @@ std::tuple CNVCaller::runCopyNumberPrediction(std SVType predicted_cnv_type = getSVTypeFromCNState(max_state); // Save the SV calls if enabled - uint32_t min_length = 30000; + uint32_t min_length = 10000; // Lowered from 30kb to include 10-30kb SVs bool copy_number_change = (predicted_cnv_type != SVType::UNKNOWN && predicted_cnv_type != SVType::NEUTRAL); if (input_data.getSaveCNVData() && copy_number_change && (end_pos - start_pos) >= min_length) { @@ -303,6 +370,7 @@ void CNVCaller::runCIGARCopyNumberPrediction(std::string chr, std::vector end if (start_pos > end_pos) @@ -311,8 +379,8 @@ void CNVCaller::runCIGARCopyNumberPrediction(std::string chr, std::vector& state_sequence = prediction.first; double likelihood = prediction.second; - // Get all the states in the SV region - std::vector sv_states; + // Get state counts in the SV region (single pass) + std::array sv_state_counts = {0, 0, 0, 0, 0, 0, 0}; + int state_count = 0; for (size_t i = 0; i < state_sequence.size(); i++) { if (snp_data.pos[i] >= start_pos && snp_data.pos[i] <= end_pos) { - sv_states.push_back(state_sequence[i]); + int state = state_sequence[i]; + if (state >= 1 && state <= 6) + { + sv_state_counts[state]++; + state_count++; + } } } + if (state_count == 0) + { + continue; + } + // Determine if there is a majority state within the SV region and if it // is greater than 50% int max_state = 0; int max_count = 0; for (int i = 0; i < 6; i++) { - int state_count = std::count(sv_states.begin(), sv_states.end(), i+1); - if (state_count > max_count) + int count_i = sv_state_counts[i+1]; + if (count_i > max_count) { max_state = i+1; - max_count = state_count; + max_count = count_i; } } - // If there is no majority state, then set the state to unknown + // If there is no majority state, then set the state to unknown. + // Use stricter HMM majority for INS->DUP conversion in 10-50kb, + // where depth-driven relabeling is noisier. double pct_threshold = 0.50; - int state_count = (int) sv_states.size(); + if (sv_call.sv_type == SVType::INS && + (max_state == 5 || max_state == 6) && + sv_length >= 10000 && sv_length <= 50000) + { + pct_threshold = 0.65; + } if ((double) max_count / (double) state_count < pct_threshold) { max_state = 0; @@ -374,6 +460,14 @@ void CNVCaller::runCIGARCopyNumberPrediction(std::string chr, std::vector CNVCaller::splitRegionIntoChunks(std::string chr, uint3 void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& chromosomes, std::unordered_map>& chr_pos_depth_map, std::unordered_map& chr_mean_cov_map, const std::string& bam_filepath, int thread_count) const { // Open the BAM file - printMessage("Opening BAM file: " + bam_filepath); samFile *bam_file = sam_open(bam_filepath.c_str(), "r"); if (!bam_file) { @@ -458,6 +551,7 @@ void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& } // Iterate through each chromosome and update the depth map + printMessage("Calculating mean chromosome coverage for copy number prediction..."); int current_chr = 0; int total_chr_count = chromosomes.size(); for (const std::string& chr : chromosomes) @@ -470,7 +564,7 @@ void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& continue; } - printMessage("(" + std::to_string(++current_chr) + "/" + std::to_string(total_chr_count) + ") Reading BAM file for chromosome: " + chr); + printMessage("(" + std::to_string(++current_chr) + "/" + std::to_string(total_chr_count) + ") Processing chromosome: " + chr); std::vector& pos_depth_map = chr_pos_depth_map[chr]; int tid = bam_name2id(bam_header, chr.c_str()); if (tid < 0) @@ -485,6 +579,11 @@ void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& printError("ERROR: Chromosome length mismatch for " + chr + ": expected " + std::to_string(chr_length) + ", found " + std::to_string(pos_depth_map.size()) + ", resizing to " + std::to_string(chr_length)); pos_depth_map.resize(chr_length, 0); } + + // Difference-array depth accumulation: O(#CIGAR ops + chr_len) + // instead of O(total aligned bases) + std::vector depth_delta(pos_depth_map.size() + 1, 0); + while (sam_itr_next(bam_file, bam_iter, bam_record) >= 0) { // Ignore UNMAP, SECONDARY, QCFAIL, and DUP reads @@ -506,15 +605,17 @@ void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& uint32_t op_len = bam_cigar_oplen(cigar[i]); if (op == BAM_CMATCH || op == BAM_CEQUAL || op == BAM_CDIFF) { - // Update the depth for each position in the alignment - for (uint32_t j = 0; j < op_len; j++) + if (ref_pos < pos_depth_map.size()) { - if (ref_pos + j >= pos_depth_map.size()) + uint32_t start_cov = ref_pos; + uint64_t end_cov_u64 = static_cast(ref_pos) + static_cast(op_len) - 1; + uint32_t end_cov = static_cast(std::min(end_cov_u64, static_cast(pos_depth_map.size() - 1))); + + depth_delta[start_cov] += 1; + if (static_cast(end_cov + 1) < depth_delta.size()) { - printError("ERROR: Reference position out of range for " + chr + ":" + std::to_string(ref_pos+j)); - continue; + depth_delta[end_cov + 1] -= 1; } - pos_depth_map[ref_pos + j]++; } } @@ -531,19 +632,38 @@ void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& } hts_itr_destroy(bam_iter); - uint64_t cum_depth = std::accumulate(pos_depth_map.begin(), pos_depth_map.end(), 0ULL); - uint32_t pos_count = std::count_if(pos_depth_map.begin(), pos_depth_map.end(), [](uint32_t depth) { return depth > 0; }); + uint64_t cum_depth = 0; + uint32_t pos_count = 0; + int64_t running_depth = 0; + if (!pos_depth_map.empty()) + { + pos_depth_map[0] = 0; + } + for (size_t pos = 1; pos < pos_depth_map.size(); pos++) + { + running_depth += depth_delta[pos]; + if (running_depth < 0) + { + running_depth = 0; + } + + uint32_t depth = static_cast(running_depth); + pos_depth_map[pos] = depth; + cum_depth += depth; + if (depth > 0) + { + pos_count++; + } + } // Calculate the mean coverage for the chromosome double mean_chr_cov = (pos_count > 0) ? static_cast(cum_depth) / static_cast(pos_count) : 0.0; - printMessage("Mean coverage for chromosome " + chr + ": " + std::to_string(mean_chr_cov)); if (mean_chr_cov != 0.0) { chr_mean_cov_map[chr] = mean_chr_cov; } } // Clean up the BAM file and index - printMessage("Closing BAM file " + bam_filepath); bam_destroy1(bam_record); hts_idx_destroy(bam_index); bam_hdr_destroy(bam_header); @@ -552,13 +672,24 @@ void CNVCaller::calculateMeanChromosomeCoverage(const std::vector& bam_index = nullptr; bam_header = nullptr; bam_file = nullptr; - printMessage("BAM file closed."); } void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, uint32_t end_pos, std::vector& snp_pos, std::unordered_map& snp_baf, std::unordered_map& snp_pfb, const InputData& input_data) const { - // Lock during reading - std::shared_lock lock(this->shared_mutex); + struct ReaderCache { + bcf_srs_t* reader = nullptr; + std::string filepath; + int thread_count = -1; + + ~ReaderCache() { + if (reader) { + bcf_sr_destroy(reader); + } + } + }; + + thread_local ReaderCache snp_cache; + thread_local ReaderCache pfb_cache; // --------- SNP file --------- const std::string snp_filepath = input_data.getSNPFilepath(); @@ -568,24 +699,52 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui return; } - // Initialize the SNP file reader - bcf_srs_t *snp_reader = bcf_sr_init(); - if (!snp_reader) - { - printError("ERROR: Could not initialize SNP reader."); - return; - } - snp_reader->require_index = 1; - - // Use multi-threading if not threading by chromosome int thread_count = input_data.getThreadCount(); - bcf_sr_set_threads(snp_reader, thread_count); + auto get_cached_reader = [&](ReaderCache& cache, const std::string& filepath, const std::string& label) -> bcf_srs_t* { + if (filepath.empty()) { + return nullptr; + } - // Add the SNP file to the reader - if (bcf_sr_add_reader(snp_reader, snp_filepath.c_str()) < 0) + bool needs_reload = (cache.reader == nullptr) || (cache.filepath != filepath) || (cache.thread_count != thread_count); + if (needs_reload) + { + if (cache.reader) + { + bcf_sr_destroy(cache.reader); + cache.reader = nullptr; + } + + cache.reader = bcf_sr_init(); + if (!cache.reader) + { + printError("ERROR: Could not initialize " + label + " reader."); + return nullptr; + } + cache.reader->require_index = 1; + + // Add the file to the reader + if (bcf_sr_add_reader(cache.reader, filepath.c_str()) < 0) + { + printError("ERROR: Could not add " + label + " file to reader: " + filepath); + bcf_sr_destroy(cache.reader); + cache.reader = nullptr; + cache.filepath.clear(); + cache.thread_count = -1; + return nullptr; + } + + bcf_sr_set_threads(cache.reader, thread_count); + cache.filepath = filepath; + cache.thread_count = thread_count; + } + + return cache.reader; + }; + + // Initialize/reuse the SNP file reader + bcf_srs_t *snp_reader = get_cached_reader(snp_cache, snp_filepath, "SNP"); + if (!snp_reader) { - bcf_sr_destroy(snp_reader); - printError("ERROR: Could not add SNP file to reader: " + snp_filepath); return; } @@ -608,7 +767,7 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui } pfb_file.close(); - bcf_srs_t *pfb_reader = bcf_sr_init(); + bcf_srs_t *pfb_reader = nullptr; std::string chr_gnomad = chr; std::string AF_key; if (use_pfb) @@ -639,31 +798,12 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui } } - // Initialize the population allele frequency reader + // Initialize/reuse the population allele frequency reader + pfb_reader = get_cached_reader(pfb_cache, pfb_filepath, "population allele frequency"); if (!pfb_reader) { - printError("ERROR: Could not initialize population allele frequency reader."); - - // Clean up - bcf_sr_destroy(snp_reader); - return; - } - pfb_reader->require_index = 1; - - // Add the population allele frequency file to the reader - if (bcf_sr_add_reader(pfb_reader, pfb_filepath.c_str()) < 0) - { - printError("ERROR: Could not add population allele frequency file to reader: " + pfb_filepath); - - // Clean up - bcf_sr_destroy(pfb_reader); - bcf_sr_destroy(snp_reader); - return; + use_pfb = false; } - - // Use multi-threading if not threading by chromosome - int thread_count = input_data.getThreadCount(); - bcf_sr_set_threads(pfb_reader, thread_count); } // Read the SNP data @@ -673,8 +813,6 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui if (bcf_sr_set_regions(snp_reader, region_str.c_str(), 0) < 0) //chr.c_str(), 0) < 0) { printError("ERROR: Could not set region for SNP reader: " + chr); - bcf_sr_destroy(snp_reader); - bcf_sr_destroy(pfb_reader); return; } @@ -706,8 +844,12 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui int32_t *dp = 0; int dp_count = 0; int dp_ret = bcf_get_format_int32(snp_reader->readers[0].header, snp_record, "DP", &dp, &dp_count); - if (dp_ret < 0 || dp[0] <= 10) + if (dp_ret < 0 || dp_count == 0 || dp[0] <= 10) { + if (dp) + { + free(dp); + } continue; } free(dp); @@ -724,6 +866,10 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui int ad_ret = bcf_get_format_int32(snp_reader->readers[0].header, snp_record, "AD", &ad, &ad_count); if (ad_ret < 0 || ad_count < 2) { + if (ad) + { + free(ad); + } continue; } @@ -746,8 +892,6 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui // Continue if no SNP was found in the region if (!snp_found) { - bcf_sr_destroy(snp_reader); - bcf_sr_destroy(pfb_reader); return; } @@ -798,14 +942,9 @@ void CNVCaller::readSNPAlleleFrequencies(std::string chr, uint32_t start_pos, ui continue; } snp_pfb[pfb_pos] = pfb; - break; } free(pfb_f); } - - // Clean up - bcf_sr_destroy(snp_reader); - bcf_sr_destroy(pfb_reader); } void CNVCaller::saveSVCopyNumberToJSON(SNPData &before_sv, SNPData &after_sv, SNPData &snp_data, std::string chr, uint32_t start, uint32_t end, std::string sv_type, double likelihood, const std::string& filepath) const diff --git a/src/fasta_query.cpp b/src/fasta_query.cpp index 237db443..12444376 100644 --- a/src/fasta_query.cpp +++ b/src/fasta_query.cpp @@ -15,7 +15,7 @@ #include "utils.h" -int ReferenceGenome::setFilepath(std::string fasta_filepath) +int ReferenceGenome::read(std::string fasta_filepath) { if (fasta_filepath == "") { @@ -45,10 +45,10 @@ int ReferenceGenome::setFilepath(std::string fasta_filepath) // Header line, indicating a new chromosome // Store the previous chromosome and sequence if (current_chr != "") - { - this->chromosomes.push_back(current_chr); // Add the chromosome to the list - this->chr_to_seq[current_chr] = sequence; // Add the sequence to the map - this->chr_to_length[current_chr] = sequence.length(); // Add the sequence length to the map + { + this->chromosomes.push_back(current_chr); // Add the chromosome to the list + this->chr_to_seq[current_chr] = sequence; // Add the sequence to the map + this->chr_to_length[current_chr] = sequence.length(); // Add the sequence length to the map sequence = ""; // Reset the sequence } diff --git a/src/input_data.cpp b/src/input_data.cpp index 936f7e62..51cbfe57 100644 --- a/src/input_data.cpp +++ b/src/input_data.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include "utils.h" #include "debug.h" // For DEBUG_PRINT @@ -22,8 +23,6 @@ InputData::InputData() this->snp_vcf_filepath = ""; this->chr = ""; this->output_dir = ""; - this->sample_size = 20; - this->min_cnv_length = 2000; // Default minimum CNV length this->min_reads = 5; this->dbscan_epsilon = 0.1; this->dbscan_min_pts_pct = 0.1; @@ -43,8 +42,6 @@ void InputData::printParameters() const DEBUG_PRINT("Reference genome: " << this->ref_filepath); DEBUG_PRINT("SNP VCF: " << this->snp_vcf_filepath); DEBUG_PRINT("Output directory: " << this->output_dir); - DEBUG_PRINT("Sample size: " << this->sample_size); - DEBUG_PRINT("Minimum CNV length: " << this->min_cnv_length); DEBUG_PRINT("DBSCAN epsilon: " << this->dbscan_epsilon); DEBUG_PRINT("DBSCAN minimum points percentage: " << this->dbscan_min_pts_pct * 100.0f << "%"); } @@ -72,6 +69,40 @@ void InputData::setLongReadBam(std::string filepath) } else { fclose(fp); } + + // Check if pgbam file is being used and warn user + if (filepath.find(".pgbam") != std::string::npos) + { + std::cerr << "================================================================================\n" + << "WARNING: Using PetaGene-compressed BAM file (.pgbam)\n" + << " This format does NOT support safe concurrent decompression.\n" + << " Multi-threaded access may cause CRC32 checksum errors.\n" + << "\n" + << "RECOMMENDED: Decompress the pgbam file to standard BAM format using:\n" + << " petasuite --decompress input.pgbam\n" + << "================================================================================\n"; + } + + // Check if BAM index file exists and is newer than BAM file + std::string index_filepath = filepath + ".bai"; + struct stat bam_stat, index_stat; + if (stat(filepath.c_str(), &bam_stat) == 0) + { + if (stat(index_filepath.c_str(), &index_stat) == 0) + { + if (index_stat.st_mtime < bam_stat.st_mtime) + { + std::cerr << "================================================================================\n" + << "WARNING: BAM index file is older than BAM file\n" + << " BAM: " << filepath << "\n" + << " Index: " << index_filepath << "\n" + << "\n" + << "RECOMMENDED: Rebuild the BAM index using:\n" + << " samtools index " << filepath << "\n" + << "================================================================================\n"; + } + } + } } } @@ -104,16 +135,6 @@ void InputData::setOutputDir(std::string dirpath) } } -int InputData::getSampleSize() const -{ - return this->sample_size; -} - -void InputData::setSampleSize(int sample_size) -{ - this->sample_size = sample_size; -} - std::string InputData::getSNPFilepath() const { return this->snp_vcf_filepath; @@ -162,14 +183,14 @@ std::string InputData::getAssemblyGaps() const return this->assembly_gaps; } -uint32_t InputData::getMinCNVLength() const +void InputData::setChromosome(std::string chr) { - return this->min_cnv_length; + this->chr = chr; } -void InputData::setMinCNVLength(int min_cnv_length) +std::string InputData::getChromosome() const { - this->min_cnv_length = (uint32_t) min_cnv_length; + return this->chr; } void InputData::setDBSCAN_Epsilon(double epsilon) @@ -192,22 +213,6 @@ double InputData::getDBSCAN_MinPtsPct() const return this->dbscan_min_pts_pct; } -void InputData::setChromosome(std::string chr) -{ - this->chr = chr; - this->single_chr = true; -} - -std::string InputData::getChromosome() const -{ - return this->chr; -} - -bool InputData::isSingleChr() const -{ - return this->single_chr; -} - void InputData::setAlleleFreqFilepaths(std::string filepath) { // Check if empty string @@ -341,7 +346,6 @@ void InputData::setHMMFilepath(std::string filepath) exit(1); } else { this->hmm_filepath = filepath; - std::cout << "Using HMM file: " << this->hmm_filepath << std::endl; } } } diff --git a/src/main.cpp b/src/main.cpp index 45984eb2..d60cbe16 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,5 +1,5 @@ -#include "swig_interface.h" +#include "contextsv.h" /// @cond DOXYGEN_IGNORE #include @@ -65,21 +65,12 @@ void runContextSV(const std::unordered_map& args) input_data.setRefGenome(args.at("ref-file")); input_data.setSNPFilepath(args.at("snps-file")); input_data.setOutputDir(args.at("output-dir")); - if (args.find("chr") != args.end()) { - input_data.setChromosome(args.at("chr")); - } if (args.find("thread-count") != args.end()) { input_data.setThreadCount(std::stoi(args.at("thread-count"))); } if (args.find("hmm-file") != args.end()) { input_data.setHMMFilepath(args.at("hmm-file")); } - if (args.find("sample-size") != args.end()) { - input_data.setSampleSize(std::stoi(args.at("sample-size"))); - } - if (args.find("min-cnv") != args.end()) { - input_data.setMinCNVLength(std::stoi(args.at("min-cnv"))); - } if (args.find("eth") != args.end()) { input_data.setEthnicity(args.at("eth")); } @@ -89,6 +80,9 @@ void runContextSV(const std::unordered_map& args) if (args.find("assembly-gaps") != args.end()) { input_data.setAssemblyGaps(args.at("assembly-gaps")); } + if (args.find("chr") != args.end()) { + input_data.setChromosome(args.at("chr")); + } if (args.find("save-cnv") != args.end()) { input_data.saveCNVData(true); } @@ -96,15 +90,6 @@ void runContextSV(const std::unordered_map& args) input_data.setVerbose(true); } - // DBSCAN parameters - if (args.find("epsilon") != args.end()) { - input_data.setDBSCAN_Epsilon(std::stod(args.at("epsilon"))); - } - - if (args.find("min-pts-pct") != args.end()) { - input_data.setDBSCAN_MinPtsPct(std::stod(args.at("min-pts-pct"))); - } - // Set up the CNV JSON file if enabled if (input_data.getSaveCNVData()) { const std::string output_dir = input_data.getOutputDir(); @@ -118,8 +103,42 @@ void runContextSV(const std::unordered_map& args) std::cout << "Saving CNV data to: " << json_filepath << std::endl; } + // Print all parameters being used + std::cout << "\n========================================" << std::endl; + std::cout << "ContextSV Parameters:" << std::endl; + std::cout << "========================================" << std::endl; + std::cout << "Input BAM: " << input_data.getLongReadBam() << std::endl; + std::cout << "Reference genome: " << input_data.getRefGenome() << std::endl; + std::cout << "SNP VCF file: " << input_data.getSNPFilepath() << std::endl; + std::cout << "Output directory: " << input_data.getOutputDir() << std::endl; + std::cout << "HMM file: " << input_data.getHMMFilepath() << std::endl; + std::cout << "Thread count: " << input_data.getThreadCount() << std::endl; + std::cout << "DBSCAN epsilon: " << input_data.getDBSCAN_Epsilon() << std::endl; + std::cout << "DBSCAN min pts pct: " << input_data.getDBSCAN_MinPtsPct() << std::endl; + if (!input_data.getEthnicity().empty()) { + std::cout << "Ethnicity: " << input_data.getEthnicity() << std::endl; + } + if (args.find("pfb-file") != args.end()) { + std::cout << "PFB file: " << args.at("pfb-file") << std::endl; + } + if (!input_data.getAssemblyGaps().empty()) { + std::cout << "Assembly gaps: " << input_data.getAssemblyGaps() << std::endl; + } + std::cout << "Save CNV data: " << (input_data.getSaveCNVData() ? "true" : "false") << std::endl; + std::cout << "Verbose mode: " << (input_data.getVerbose() ? "true" : "false") << std::endl; + std::cout << "========================================\n" << std::endl; + // Run ContextSV - run(input_data); + ContextSV contextsv; + try + { + contextsv.run(input_data); + } + catch (std::exception& e) + { + std::cerr << e.what() << std::endl; + exit(1); + } } void printUsage(const std::string& programName) { @@ -130,14 +149,9 @@ void printUsage(const std::string& programName) { << " -r, --ref Reference genome FASTA file (required)\n" << " -s, --snp SNPs VCF file (required)\n" << " -o, --outdir Output directory (required)\n" - << " -c, --chr Chromosome\n" << " -t, --threads Number of threads\n" << " -h, --hmm HMM file\n" - << " -n, --sample-size Sample size for HMM predictions\n" - << " --min-cnv Minimum CNV length\n" - << " --eps DBSCAN epsilon\n" - << " --min-pts-pct Percentage of mean chr. coverage to use for DBSCAN minimum points\n" - << " -e, --eth ETH file\n" + << " -e, --eth Ethnicity identifier (e.g. nfe, asj)\n" << " -p, --pfb PFB file\n" << " --assembly-gaps Assembly gaps file\n" << " --save-cnv Save CNV data\n" @@ -168,16 +182,6 @@ std::unordered_map parseArguments(int argc, char* argv args["thread-count"] = argv[++i]; } else if ((arg == "-h" || arg == "--hmm") && i + 1 < argc) { args["hmm-file"] = argv[++i]; - } else if ((arg == "-n" || arg == "--sample-size") && i + 1 < argc) { - args["sample-size"] = argv[++i]; - } else if (arg == "--min-cnv" && i + 1 < argc) { - args["min-cnv"] = argv[++i]; - } else if (arg == "--min-reads" && i + 1 < argc) { - args["min-reads"] = argv[++i]; - } else if (arg == "--eps" && i + 1 < argc) { - args["epsilon"] = argv[++i]; - } else if (arg == "--min-pts-pct" && i + 1 < argc) { - args["min-pts-pct"] = argv[++i]; } else if ((arg == "-e" || arg == "--eth") && i + 1 < argc) { args["eth"] = argv[++i]; } else if ((arg == "-p" || arg == "--pfb") && i + 1 < argc) { diff --git a/src/sv_caller.cpp b/src/sv_caller.cpp index d59e6598..e9fbebeb 100644 --- a/src/sv_caller.cpp +++ b/src/sv_caller.cpp @@ -37,9 +37,7 @@ int SVCaller::readNextAlignment(samFile *fp_in, hts_itr_t *itr, bam1_t *bam1) { - std::shared_lock lock(this->shared_mutex); - int ret = sam_itr_next(fp_in, itr, bam1); - return ret; + return sam_itr_next(fp_in, itr, bam1); } std::vector SVCaller::getChromosomes(const std::string &bam_filepath) @@ -78,7 +76,6 @@ void SVCaller::findSplitSVSignatures(std::unordered_map> primary_map; // TID-> qname -> primary alignment - std::unordered_map> supp_map; // qname -> supplementary alignment - bam1_t *bam1 = bam_init1(); if (!bam1) { printError("ERROR: failed to initialize BAM record"); + bam_hdr_destroy(bamHdr); + hts_idx_destroy(idx); + sam_close(fp_in); return; } - - // Set the region to the whole genome, or a user-specified chromosome - hts_itr_t *itr = nullptr; - if (input_data.isSingleChr()) { - std::string chr = input_data.getChromosome(); - itr = sam_itr_querys(idx, bamHdr, chr.c_str()); - if (!itr) { + // Build chromosome list (single chromosome if requested) + std::vector chromosomes; + const std::string target_chr = input_data.getChromosome(); + if (!target_chr.empty()) { + if (bam_name2id(bamHdr, target_chr.c_str()) < 0) { + printError("ERROR: Requested chromosome " + target_chr + " not found in BAM header"); bam_destroy1(bam1); - printError("ERROR: failed to create iterator for " + chr); + bam_hdr_destroy(bamHdr); + hts_idx_destroy(idx); + sam_close(fp_in); return; } + chromosomes.push_back(target_chr); } else { - itr = sam_itr_queryi(idx, HTS_IDX_START, 0, 0); - if (!itr) { - bam_destroy1(bam1); - printError("ERROR: failed to create iterator for the whole genome"); - return; + chromosomes.reserve(static_cast(bamHdr->n_targets)); + for (int i = 0; i < bamHdr->n_targets; i++) { + chromosomes.push_back(bamHdr->target_name[i]); } } - uint32_t primary_count = 0; - uint32_t supplementary_count = 0; + int current_chr = 0; + int total_chr = static_cast(chromosomes.size()); - // Main loop to process the alignments - printMessage("Processing alignments from " + bam_filepath); - uint32_t num_alignments = 0; - std::unordered_set alignment_tids; // All unique chromosome IDs - std::unordered_set supp_qnames; // All unique query names - while (readNextAlignment(fp_in, itr, bam1) >= 0) { + for (const auto& chr_name : chromosomes) { + current_chr++; + int primary_tid = bam_name2id(bamHdr, chr_name.c_str()); + if (primary_tid < 0) { + printError("ERROR: Chromosome " + chr_name + " not found in BAM header"); + continue; + } + + // Per-chromosome maps to avoid whole-genome materialization + std::unordered_map chr_primary_map; + std::unordered_map> supp_map; + std::unordered_set supp_qnames; - // Skip secondary and unmapped alignments, duplicates, QC failures, and low mapping quality - if (bam1->core.flag & BAM_FSECONDARY || bam1->core.flag & BAM_FUNMAP || bam1->core.flag & BAM_FDUP || bam1->core.flag & BAM_FQCFAIL || bam1->core.qual < this->min_mapq) { + hts_itr_t* itr = sam_itr_querys(idx, bamHdr, chr_name.c_str()); + if (!itr) { + printError("ERROR: failed to query chromosome " + chr_name); continue; } - const std::string qname = bam_get_qname(bam1); // Query template name - // Process primary alignments - if (!(bam1->core.flag & BAM_FSUPPLEMENTARY)) { - // Store chromosome (TID), start, and end positions (1-based) of the - // primary alignment, and the strand (true for forward, false for - // reverse) - std::pair qpos = getAlignmentReadPositions(bam1); + uint32_t num_alignments = 0; + while (readNextAlignment(fp_in, itr, bam1) >= 0) { - primary_map[bam1->core.tid][qname] = PrimaryAlignment{static_cast(bam1->core.pos + 1), static_cast(bam_endpos(bam1)), static_cast(qpos.first), static_cast(qpos.second), !(bam1->core.flag & BAM_FREVERSE), 0}; - alignment_tids.insert(bam1->core.tid); - primary_count++; + // Skip secondary and unmapped alignments, duplicates, QC failures, and low mapping quality + if (bam1->core.flag & BAM_FSECONDARY || bam1->core.flag & BAM_FUNMAP || bam1->core.flag & BAM_FDUP || bam1->core.flag & BAM_FQCFAIL || bam1->core.qual < this->min_mapq) { + continue; + } + const std::string qname = bam_get_qname(bam1); // Query template name - // Process supplementary alignments - } else if (bam1->core.flag & BAM_FSUPPLEMENTARY) { - // Store chromosome (TID), start, and end positions (1-based) of the - // supplementary alignment, and the strand (true for forward, false - // for reverse) std::pair qpos = getAlignmentReadPositions(bam1); - supp_map[qname].push_back(SuppAlignment{bam1->core.tid, static_cast(bam1->core.pos + 1), static_cast(bam_endpos(bam1)), static_cast(qpos.first), static_cast(qpos.second), !(bam1->core.flag & BAM_FREVERSE)}); - alignment_tids.insert(bam1->core.tid); - supp_qnames.insert(qname); - supplementary_count++; - } - num_alignments++; - if (num_alignments % 1000000 == 0) { - printMessage("Processed " + std::to_string(num_alignments) + " alignments"); - } - } + // Process primary alignments + if (!(bam1->core.flag & BAM_FSUPPLEMENTARY)) { + chr_primary_map[qname] = PrimaryAlignment{static_cast(bam1->core.pos + 1), static_cast(bam_endpos(bam1)), static_cast(qpos.first), static_cast(qpos.second), !(bam1->core.flag & BAM_FREVERSE), 0}; - // Clean up the iterator and alignment - hts_itr_destroy(itr); - bam_destroy1(bam1); - - // Clean up the BAM file and index - sam_close(fp_in); - hts_idx_destroy(idx); - // bam_hdr_destroy(bamHdr); - - // Remove primary alignments without supplementary alignments - std::unordered_map> to_remove; - for (auto& chr_primary : primary_map) { - std::unordered_set qnames; - for (const auto& entry : chr_primary.second) { - if (supp_qnames.find(entry.first) == supp_qnames.end()) { - to_remove[chr_primary.first].insert(entry.first); + // Process supplementary alignments + } else { + supp_map[qname].push_back(SuppAlignment{bam1->core.tid, static_cast(bam1->core.pos + 1), static_cast(bam_endpos(bam1)), static_cast(qpos.first), static_cast(qpos.second), !(bam1->core.flag & BAM_FREVERSE)}); + supp_qnames.insert(qname); } + num_alignments++; } - } - - int total_removed = 0; - for (auto& chr_primary : primary_map) { - // Remove the qnames from the primary map - total_removed += to_remove[chr_primary.first].size(); - for (const auto& qname : to_remove[chr_primary.first]) { - chr_primary.second.erase(qname); + hts_itr_destroy(itr); + + // Remove primary alignments without supplementary alignments + int removed = 0; + for (auto it = chr_primary_map.begin(); it != chr_primary_map.end();) { + if (supp_qnames.find(it->first) == supp_qnames.end()) { + it = chr_primary_map.erase(it); + removed++; + } else { + ++it; + } } - } - printMessage("Removed " + std::to_string(total_removed) + " primary alignments without supplementary alignments"); - // Process the primary alignments and find SVs - for (const auto& chr_primary : primary_map) { - int primary_tid = chr_primary.first; - std::string chr_name = bamHdr->target_name[primary_tid]; - printMessage("Processing chromosome " + chr_name + " with " + std::to_string(chr_primary.second.size()) + " primary alignments"); + printMessage("(" + std::to_string(current_chr) + "/" + std::to_string(total_chr) + ") Processing " + chr_name + " (" + std::to_string(chr_primary_map.size()) + " primary alignments)"); + + if (chr_primary_map.empty()) { + continue; + } std::vector chr_sv_calls; chr_sv_calls.reserve(1000); - const std::unordered_map& chr_primary_map = chr_primary.second; // Identify overlapping primary alignments and cluster endpoints std::unique_ptr root = nullptr; @@ -240,8 +217,9 @@ void SVCaller::findSplitSVSignatures(std::unordered_map& supp_alns = supp_map[qname]; + auto supp_it = supp_map.find(qname); + if (supp_it == supp_map.end()) { + continue; + } + const std::vector& supp_alns = supp_it->second; bool primary_strand = chr_primary_map.at(qname).strand; bool has_opposite_strand = false; for (const SuppAlignment& supp_aln : supp_alns) { @@ -262,7 +244,13 @@ void SVCaller::findSplitSVSignatures(std::unordered_map(num_supp_opposite_strand) / static_cast(num_primary) > 0.5) { + double opposite_strand_ratio = (num_primary > 0) + ? static_cast(num_supp_opposite_strand) / static_cast(num_primary) + : 0.0; + + // Classify inversion when opposite-strand support is moderate-to-strong. + // This avoids missing true mid-size inversions that can have mixed strand evidence. + if (num_primary >= 3 && num_supp_opposite_strand >= 2 && opposite_strand_ratio >= 0.5) { inversion = true; } @@ -301,9 +289,21 @@ void SVCaller::findSplitSVSignatures(std::unordered_map ref_distances; for (const std::string& qname : primary_cluster) { const PrimaryAlignment& primary_aln = chr_primary_map.at(qname); - const std::vector& supp_alns = supp_map.at(qname); + auto supp_it = supp_map.find(qname); + if (supp_it == supp_map.end()) { + continue; + } + const std::vector& supp_alns = supp_it->second; for (const SuppAlignment& supp_aln : supp_alns) { if (supp_aln.tid == primary_tid) { + bool is_opposite_strand = supp_aln.strand != primary_aln.strand; + + // For inversion clusters, only keep opposite-strand supplementary + // alignments to avoid contaminating inversion breakpoint evidence. + if (inversion && !is_opposite_strand) { + continue; + } + // Same chromosome int read_distance = 0; int ref_distance = 0; @@ -450,13 +450,14 @@ void SVCaller::findSplitSVSignatures(std::unordered_map(SVDataType::SPLITDIST1)); if (split_candidate_sv) { int aln_offset = static_cast(ref_distance - read_distance); - if (read_distance > ref_distance && read_distance >= min_length && read_distance <= max_length) { + + if (read_distance > ref_distance && read_distance >= min_length && read_distance <= max_length_noninv) { // Add an insertion SV call at the 5'-most primary position SVType sv_type = SVType::INS; SVCall sv_candidate(sv_start, sv_start + (read_distance-1), sv_type, getSVTypeSymbol(sv_type), aln_type, Genotype::UNKNOWN, 0.0, 0, aln_offset, primary_cluster_size); addSVCall(chr_sv_calls, sv_candidate); - // } - } else if (ref_distance > read_distance && ref_distance >= min_length && ref_distance <= max_length) { + // } + } else if (ref_distance > read_distance && ref_distance >= min_length && ref_distance <= max_length_noninv) { // Set it to unknown, SV type will be determined by the // HMM prediction @@ -476,9 +477,28 @@ void SVCaller::findSplitSVSignatures(std::unordered_map= min_length && sv_length <= max_length) { + int max_allowed_length = (sv_type == SVType::INV) ? max_length_inv : max_length_noninv; + if (sv_length >= min_length && sv_length <= max_allowed_length) { + // Use balanced support for inversions. + // For non-inversions, keep large events even with sparse + // split-read support because >100kb SVs often have few + // spanning split reads. + int balanced_cluster_size = std::min(primary_cluster_size, supp_cluster_size); + if (sv_type == SVType::INV) { + const int INV_MIN_LENGTH = 500; + // Size-dependent cluster threshold: large inversions (>50kb) + // may have sparse split-read support, similar to other SV types + int min_cluster = (sv_length > 50000) ? 3 : 5; + if (sv_length < INV_MIN_LENGTH || balanced_cluster_size < min_cluster) { + continue; + } + } SVEvidenceFlags aln_type; - aln_type.set(static_cast(SVDataType::SPLIT)); + if (sv_type == SVType::INV) { + aln_type.set(static_cast(SVDataType::SPLITINV)); + } else { + aln_type.set(static_cast(SVDataType::SPLIT)); + } SVCall sv_candidate(sv_start, sv_end, sv_type, alt, aln_type, Genotype::UNKNOWN, 0.0, 0, 0, cluster_size); addSVCall(chr_sv_calls, sv_candidate); } @@ -499,8 +519,11 @@ void SVCaller::findSplitSVSignatures(std::unordered_map& sv_calls, const std::vector& pos_depth_map) @@ -596,7 +619,7 @@ void SVCaller::processCIGARRecord(bam_hdr_t *header, bam1_t *alignment, std::vec cigar_sv_calls.emplace_back(sv_call); // Process clipped bases as potential insertions - } else if (op == BAM_CSOFT_CLIP) { + } else if (op == BAM_CSOFT_CLIP && op_len >= 200) { // Increased from 100bp to reduce adapter/error artifacts // Soft-clipped bases are considered as potential insertions // Skip if the position exceeds the reference genome length if (pos + 1 >= pos_depth_map.size()) { @@ -698,7 +721,6 @@ void SVCaller::processChromosome(const std::string& chr, std::vector& ch printError("ERROR: failed to open " + bam_filepath); return; } - hts_set_threads(fp_in, 1); // Load the header bam_hdr_t *bamHdr = sam_hdr_read(fp_in); @@ -708,6 +730,10 @@ void SVCaller::processChromosome(const std::string& chr, std::vector& ch return; } + // Single-threaded I/O in worker threads to prevent index contention + // (ThreadPool already provides parallelism across chromosomes) + hts_set_threads(fp_in, 1); + // Load the index hts_idx_t *idx = sam_index_load(fp_in, bam_filepath.c_str()); if (!idx) { @@ -724,12 +750,10 @@ void SVCaller::processChromosome(const std::string& chr, std::vector& ch double dbscan_min_pts_pct = input_data.getDBSCAN_MinPtsPct(); if (dbscan_min_pts_pct > 0.0) { dbscan_min_pts = (int)std::ceil(mean_chr_cov * dbscan_min_pts_pct); - printMessage(chr + ": Mean chr. cov.: " + std::to_string(mean_chr_cov) + " (DBSCAN min. pts.= " + std::to_string(dbscan_min_pts) + ", min. pts. pct.= " + std::to_string(dbscan_min_pts_pct) + ")"); - } + } // ----------------------------------------------------------------------- // Detect SVs from the CIGAR strings - printMessage(chr + ": CIGAR SVs..."); this->findCIGARSVs(fp_in, idx, bamHdr, chr, chr_sv_calls, chr_pos_depth_map); // Clean up the BAM file and index @@ -737,11 +761,10 @@ void SVCaller::processChromosome(const std::string& chr, std::vector& ch hts_idx_destroy(idx); bam_hdr_destroy(bamHdr); - printMessage(chr + ": Merging CIGAR..."); mergeSVs(chr_sv_calls, dbscan_epsilon, dbscan_min_pts, false); int region_sv_count = getSVCount(chr_sv_calls); - printMessage(chr + ": Found " + std::to_string(region_sv_count) + " SV candidates in the CIGAR string"); + printMessage(chr + ": Found " + std::to_string(region_sv_count) + " SV candidates"); } void SVCaller::run(const InputData& input_data) @@ -756,25 +779,27 @@ void SVCaller::run(const InputData& input_data) input_data.printParameters(); // Set up the reference genome - printMessage("Loading the reference genome..."); const std::string ref_filepath = input_data.getRefGenome(); std::shared_mutex ref_mutex; // Dummy mutex (remove later) ReferenceGenome ref_genome(ref_mutex); - ref_genome.setFilepath(ref_filepath); + ref_genome.read(ref_filepath); // Get the chromosomes - std::vector chromosomes; - if (input_data.isSingleChr()) { - // Get the chromosome from the user input argument - chromosomes.push_back(input_data.getChromosome()); - } else { - // Get the chromosomes from the input BAM file - chromosomes = this->getChromosomes(input_data.getLongReadBam()); + std::vector chromosomes = this->getChromosomes(input_data.getLongReadBam()); + + // Restrict to a single chromosome if requested + const std::string target_chr = input_data.getChromosome(); + if (!target_chr.empty()) { + auto chr_it = std::find(chromosomes.begin(), chromosomes.end(), target_chr); + if (chr_it == chromosomes.end()) { + printError("Requested chromosome " + target_chr + " not found in BAM header"); + return; + } + chromosomes = {target_chr}; } // Read the HMM from the file std::string hmm_filepath = input_data.getHMMFilepath(); - std::cout << "Reading HMM from file: " << hmm_filepath << std::endl; const CHMM& hmm = ReadCHMM(hmm_filepath.c_str()); // Set up the JSON output file for CNV data @@ -791,15 +816,23 @@ void SVCaller::run(const InputData& input_data) int chr_thread_count = input_data.getThreadCount(); // Initialize the chromosome position depth map and mean coverage map + // (skip chromosomes missing from the reference instead of aborting) + std::vector ref_valid_chromosomes; for (const auto& chr : chromosomes) { uint32_t chr_len = ref_genome.getChromosomeLength(chr); if (chr_len == 0) { - printError("Chromosome " + chr + " not found in reference genome"); - return; - // continue; + printError("Chromosome " + chr + " not found in reference genome, skipping"); + continue; } chr_pos_depth_map[chr] = std::vector(chr_len+1, 0); // 1-based index chr_mean_cov_map[chr] = 0.0; + ref_valid_chromosomes.push_back(chr); + } + + chromosomes = std::move(ref_valid_chromosomes); + if (chromosomes.empty()) { + printError("No chromosomes with reference sequence were available for processing"); + return; } cnv_caller.calculateMeanChromosomeCoverage(chromosomes, chr_pos_depth_map, chr_mean_cov_map, bam_filepath, chr_thread_count); @@ -810,8 +843,8 @@ void SVCaller::run(const InputData& input_data) if (chr_mean_cov_map.find(chr) != chr_mean_cov_map.end()) { valid_chr.push_back(chr); } - chromosomes = valid_chr; } + chromosomes = valid_chr; std::unordered_map> whole_genome_sv_calls; int current_chr = 0; int total_chr_count = chromosomes.size(); @@ -819,11 +852,8 @@ void SVCaller::run(const InputData& input_data) if (cigar_svs) { // Use multi-threading across chromosomes. If a single chromosome is // specified, use a single main thread (multi-threading is used for file I/O) - int thread_count = 1; - if (!input_data.isSingleChr()) { - thread_count = input_data.getThreadCount(); - std::cout << "Using " << thread_count << " threads for chr processing..." << std::endl; - } + int thread_count = input_data.getThreadCount(); + std::cout << "Using " << thread_count << " threads for chr processing..." << std::endl; ThreadPool pool(thread_count); auto process_chr = [&](const std::string& chr) { try { @@ -832,7 +862,7 @@ void SVCaller::run(const InputData& input_data) InputData chr_input_data = input_data; // Use a thread-local copy this->processChromosome(chr, sv_calls, chr_input_data, chr_pos_depth_map[chr], chr_mean_cov_map[chr]); { - std::shared_lock lock(this->shared_mutex); + std::unique_lock lock(this->shared_mutex); whole_genome_sv_calls[chr] = std::move(sv_calls); } } catch (const std::exception& e) { @@ -904,7 +934,7 @@ void SVCaller::run(const InputData& input_data) DEBUG_PRINT("Merging split-read SVs..."); for (auto& entry : whole_genome_split_sv_calls) { std::vector& sv_calls = entry.second; - mergeSVs(sv_calls, 0.1, 2, true); + mergeSVs(sv_calls, 0.05, 3, true); // Tightened epsilon/min_pts, keep singletons } } @@ -922,7 +952,7 @@ void SVCaller::run(const InputData& input_data) DEBUG_PRINT("Merging CIGAR and split read SV calls..."); for (auto& entry : whole_genome_sv_calls) { std::vector& sv_calls = entry.second; - mergeSVs(sv_calls, 0.1, 2, true); + mergeSVs(sv_calls, 0.05, 3, true); // Tightened epsilon/min_pts, keep singletons } } @@ -936,7 +966,6 @@ void SVCaller::run(const InputData& input_data) std::string chr = entry.first; int sv_count = getSVCount(entry.second); total_sv_count += sv_count; - printMessage("Total SVs detected for " + chr + ": " + std::to_string(sv_count)); } printMessage("Total SVs detected: " + std::to_string(total_sv_count)); @@ -957,8 +986,9 @@ void SVCaller::findOverlaps(const std::unique_ptr &root, const Pri if (root->left && root->left->max_end >= query.start) findOverlaps(root->left, query, result); - // Always check the right subtree - findOverlaps(root->right, query, result); + // Check right subtree only when the query can overlap intervals there + if (root->right && query.end >= root->region.start) + findOverlaps(root->right, query, result); } void SVCaller::insert(std::unique_ptr &root, const PrimaryAlignment ®ion, std::string qname) @@ -984,6 +1014,12 @@ void SVCaller::runSplitReadCopyNumberPredictions(const std::string& chr, std::ve { std::vector additional_calls; for (auto& sv_candidate : split_sv_calls) { + const uint32_t MAX_INV_HMM_LENGTH = 1000000; // Avoid expensive CNV/HMM over very large inversion spans + uint32_t sv_length = sv_candidate.end - sv_candidate.start + 1; + if (sv_candidate.sv_type == SVType::INV && sv_length > MAX_INV_HMM_LENGTH) { + // Keep split-read inversion call as-is; skip CNV/HMM refinement for very large regions. + continue; + } std::tuple result = cnv_caller.runCopyNumberPrediction(chr, hmm, sv_candidate.start, sv_candidate.end, mean_chr_cov, pos_depth_map, input_data); double supp_lh = std::get<0>(result); @@ -1023,12 +1059,21 @@ void SVCaller::runSplitReadCopyNumberPredictions(const std::string& chr, std::ve sv_candidate.cn_state = cn_state; // For insertions predicted as duplications, update all information } else if (sv_candidate.sv_type == SVType::INS && supp_type == SVType::DUP) { - sv_candidate.sv_type = supp_type; - sv_candidate.alt_allele = getSVTypeSymbol(supp_type); // Update the ALT allele format - sv_candidate.aln_type.set(static_cast(SVDataType::HMM)); - sv_candidate.hmm_likelihood = supp_lh; - sv_candidate.genotype = genotype; - sv_candidate.cn_state = cn_state; + // Only reclassify INS to DUP if it's larger than the minimum DUP threshold + // This reduces false positives from small/mid-sized insertions being + // misclassified as duplications in the 10-50kb range where depth signal is weak + const uint32_t DUP_MIN_SIZE = 10000; // 10kb minimum for DUP reclassification + uint32_t sv_size = sv_candidate.end - sv_candidate.start + 1; + + if (sv_size >= DUP_MIN_SIZE) { + sv_candidate.sv_type = supp_type; + sv_candidate.alt_allele = ""; // Explicitly set to + sv_candidate.aln_type.set(static_cast(SVDataType::HMM)); + sv_candidate.hmm_likelihood = supp_lh; + sv_candidate.genotype = genotype; + sv_candidate.cn_state = cn_state; + } + // Otherwise, keep as INS } else { // Add a new SV call with the conflicting type SVCall new_sv_call = sv_candidate; // Copy the original SV call @@ -1099,7 +1144,6 @@ void SVCaller::saveToVCF(const std::unordered_map header_lines = { std::string("##reference=") + ref_genome.getFilepath(), contig_header, @@ -1145,8 +1185,6 @@ void SVCaller::saveToVCF(const std::unordered_map", }; - std::cout << "Writing VCF header..." << std::endl; - // Add the file format std::string file_format = "##fileformat=VCFv4.2"; vcf_stream << file_format << std::endl; @@ -1173,7 +1211,6 @@ void SVCaller::saveToVCF(const std::unordered_map& sv_calls = pair.second; - std::cout << "Saving SV calls for " << chr << "..." << std::endl; for (const auto& sv_call : sv_calls) { uint32_t start = sv_call.start; uint32_t end = sv_call.end; @@ -1238,6 +1274,37 @@ void SVCaller::saveToVCF(const std::unordered_mapDUP conversion) + bool has_conflict = false; + if (sv_type != SVType::UNKNOWN && cnv_type != SVType::UNKNOWN && cnv_type != SVType::NEUTRAL) { + if ((sv_type == SVType::INS && cnv_type == SVType::DEL) || + (sv_type == SVType::DEL && cnv_type == SVType::DUP) || + (sv_type == SVType::DUP && cnv_type == SVType::DEL)) { + has_conflict = true; + } + } + + // Check cluster support for inversions (unreliable with low support, except for large events where depth evidence may be weak) + bool low_cluster_support = (sv_type == SVType::INV && cluster_size < 5 && sv_length < 100000); + + if (has_conflict || low_cluster_support) { + filter = "LowQual"; + filtered_svs += 1; + } + } + // Deletion if (sv_type == SVType::DEL) { // Get the deleted sequence from the reference genome, also including the preceding base @@ -1280,14 +1347,17 @@ void SVCaller::saveToVCF(const std::unordered_map 0) { - std::cout << "Total unclassified SVs: " << unclassified_svs << std::endl; + std::cout << " Unclassified SVs: " << unclassified_svs << std::endl; } printMessage("Total PASS filtered SVs: " + std::to_string(filtered_svs)); printMessage("Total filtered assembly gaps: " + std::to_string(assembly_gap_filtered_svs)); diff --git a/src/sv_object.cpp b/src/sv_object.cpp index b68ef3c7..79ccea85 100644 --- a/src/sv_object.cpp +++ b/src/sv_object.cpp @@ -27,9 +27,8 @@ void addSVCall(std::vector& sv_calls, SVCall& sv_call) return; } - // Insert the SV call in sorted order - auto it = std::lower_bound(sv_calls.begin(), sv_calls.end(), sv_call); - sv_calls.insert(it, sv_call); + // Append and defer sorting/merging to downstream steps + sv_calls.push_back(sv_call); } uint32_t getSVCount(const std::vector& sv_calls) @@ -43,18 +42,11 @@ void concatenateSVCalls(std::vector &target, const std::vector& } void mergeSVs(std::vector& sv_calls, double epsilon, int min_pts, bool keep_noise, const std::string& json_filepath) -{ - printMessage("Merging SVs with DBSCAN, eps=" + std::to_string(epsilon) + ", min_pts=" + std::to_string(min_pts)); - +{ if (sv_calls.size() < 2) { return; } - // Set this to print cluster information for a specific SV call for debugging - // This is useful for debugging purposes to see how the SVs are merged - bool debug_mode = false; - SVType debug_sv_type = SVType::INV; - // Cluster SVs using DBSCAN for each SV type int initial_size = sv_calls.size(); std::vector merged_sv_calls; @@ -67,12 +59,6 @@ void mergeSVs(std::vector& sv_calls, double epsilon, int min_pts, bool k SVType::BND, }) { - // Skip if not the debug SV type - if (debug_mode && (sv_type != debug_sv_type)) { - DEBUG_PRINT("DEBUG: Skipping SV type " + getSVTypeString(sv_type) + " for debug mode"); - continue; - } - std::vector merged_sv_type_calls; // Create a vector of SV calls for the current SV type and size interval @@ -134,15 +120,6 @@ void mergeSVs(std::vector& sv_calls, double epsilon, int min_pts, bool k for (const auto& sv_call : cluster_sv_calls) { SVCall noise_sv_call = sv_call; merged_sv_type_calls.push_back(noise_sv_call); - - // Print the added SV calls if >10 kb and the debug SV type - if (debug_mode && noise_sv_call.sv_type == debug_sv_type && (noise_sv_call.end - noise_sv_call.start) > 10000) { - DEBUG_PRINT("DEBUG: Adding noise SV call at " + std::to_string(noise_sv_call.start) + "-" + std::to_string(noise_sv_call.end) + - ", type: " + getSVTypeString(noise_sv_call.sv_type) + - ", length: " + std::to_string(noise_sv_call.end - noise_sv_call.start) + - ", cluster size: " + std::to_string(noise_sv_call.cluster_size) + - ", likelihood: " + std::to_string(noise_sv_call.hmm_likelihood)); - } } // Merge clustered SV calls @@ -191,53 +168,15 @@ void mergeSVs(std::vector& sv_calls, double epsilon, int min_pts, bool k return (a.end - a.start) > (b.end - b.start); }); - // Print the added SV calls if >10 kb and the debug SV type - if (debug_mode && sv_type == debug_sv_type) { - DEBUG_PRINT("DEBUG: Cluster " + std::to_string(cluster_id) + " with " + std::to_string(cluster_sv_calls.size()) + " SV calls (length sorted):"); - for (const auto& sv_call : cluster_sv_calls) { - if ((sv_call.end - sv_call.start) > 10000) { - DEBUG_PRINT("DEBUG: SV call at " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) + - ", type: " + getSVTypeString(sv_call.sv_type) + - ", length: " + std::to_string(sv_call.end - sv_call.start) + - ", cluster size: " + std::to_string(sv_call.cluster_size) + - ", likelihood: " + std::to_string(sv_call.hmm_likelihood)); - } - } - } - // Get the top % of the cluster double top_pct = 0.2; size_t top_pct_size = std::max(1, (int) (cluster_sv_calls.size() * top_pct)); std::vector top_pct_calls(cluster_sv_calls.begin(), cluster_sv_calls.begin() + top_pct_size); - // Print the added SV calls if >10 kb and the debug SV type - if (debug_mode && sv_type == debug_sv_type) { - DEBUG_PRINT("DEBUG: Top " + std::to_string((int)(top_pct * 100)) + "% of cluster " + std::to_string(cluster_id) + " with " + - std::to_string(top_pct_calls.size()) + " SV calls (length sorted):"); - for (const auto& sv_call : top_pct_calls) { - if ((sv_call.end - sv_call.start) > 10000) { - DEBUG_PRINT("DEBUG: SV call at " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) + - ", type: " + getSVTypeString(sv_call.sv_type) + - ", length: " + std::to_string(sv_call.end - sv_call.start) + - ", cluster size: " + std::to_string(sv_call.cluster_size) + - ", likelihood: " + std::to_string(sv_call.hmm_likelihood)); - } - } - } - // Get the median SV for the top % of the cluster size_t median_index = top_pct_calls.size() / 2; merged_sv_call = top_pct_calls[median_index]; - // Print the merged SV call - if (debug_mode && sv_type == debug_sv_type) { - DEBUG_PRINT("DEBUG: Merged SV call at " + std::to_string(merged_sv_call.start) + "-" + std::to_string(merged_sv_call.end) + - ", type: " + getSVTypeString(merged_sv_call.sv_type) + - ", length: " + std::to_string(merged_sv_call.end - merged_sv_call.start) + - ", cluster size: " + std::to_string(merged_sv_call.cluster_size) + - ", likelihood: " + std::to_string(merged_sv_call.hmm_likelihood)); - } - // Add SV call merged_sv_call.cluster_size = (int) cluster_sv_calls.size(); merged_sv_type_calls.push_back(merged_sv_call); @@ -247,25 +186,10 @@ void mergeSVs(std::vector& sv_calls, double epsilon, int min_pts, bool k } DEBUG_PRINT("Merged " + std::to_string(cluster_count) + " clusters of " + getSVTypeString(sv_type) + ", found " + std::to_string(merged_sv_type_calls.size()) + " merged SV calls"); - // Print SV call start, end, type, and length for debugging if > 10 kb - if (debug_mode && sv_type == debug_sv_type) { - DEBUG_PRINT("DEBUG: Merged SV calls for " + getSVTypeString(sv_type) + ":"); - for (const auto& sv_call : merged_sv_type_calls) { - if ((sv_call.end - sv_call.start) > 10000) { - DEBUG_PRINT("DEBUG: SV call at " + std::to_string(sv_call.start) + "-" + std::to_string(sv_call.end) + - ", type: " + getSVTypeString(sv_call.sv_type) + - ", length: " + std::to_string(sv_call.end - sv_call.start) + - ", cluster size: " + std::to_string(sv_call.cluster_size) + - ", likelihood: " + std::to_string(sv_call.hmm_likelihood)); - } - } - } merged_sv_calls.insert(merged_sv_calls.end(), merged_sv_type_calls.begin(), merged_sv_type_calls.end()); } sv_calls = std::move(merged_sv_calls); // Replace with filtered list - int updated_size = sv_calls.size(); - printMessage("Merged " + std::to_string(initial_size) + " SV calls into " + std::to_string(updated_size) + " SV calls"); } void saveClustersToJSON(const std::string &filename, const std::map> &clusters) @@ -312,18 +236,15 @@ void saveClustersToJSON(const std::string &filename, const std::map &sv_calls) { - int initial_size = sv_calls.size(); std::vector combined_sv_calls; // Sort first by start position, then by SV type @@ -342,9 +263,6 @@ void mergeDuplicateSVs(std::vector &sv_calls) combined_sv_calls.push_back(sv_call); } } - int merge_count = initial_size - combined_sv_calls.size(); sv_calls = std::move(combined_sv_calls); // Replace with filtered list - if (merge_count > 0) { - printMessage("Merged " + std::to_string(merge_count) + " SV candidates with identical start and end positions"); - } + } diff --git a/src/swig_interface.cpp b/src/swig_interface.cpp deleted file mode 100644 index 76eb2151..00000000 --- a/src/swig_interface.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "swig_interface.h" -#include "contextsv.h" - -/// @cond -#include -/// @endcond - - -// Run the CLI with the given parameters -int run(const InputData& input_data) -{ - // Run ContextSV - ContextSV contextsv; - try - { - contextsv.run(input_data); - } - - catch (std::exception& e) - { - std::cerr << e.what() << std::endl; - return -1; - } - - return 0; -} diff --git a/src/swig_wrapper.i b/src/swig_wrapper.i deleted file mode 100644 index 62903afe..00000000 --- a/src/swig_wrapper.i +++ /dev/null @@ -1,22 +0,0 @@ -/* -SWIG wrapper for C++ code. -*/ - -%module contextsv - -// Include header -%{ -#include "swig_interface.h" -#include "input_data.h" -%} - -// Set up types -%include "std_string.i" -%include "stdint.i" - -// Set up the namespace -%include "input_data.h" - -// Include functions -int run(InputData input_data); - diff --git a/tests/test_general.py b/tests/test_general.py index c1c59da7..8f8a1223 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -90,12 +90,7 @@ def test_run_basic(): "--hmm", HMM_FILE, "--eth", "nfe", "--pfb", PFB_FILE, - "--sample-size", "20", - "--min-cnv", "2000", - "--eps", "0.1", - "--min-pts-pct", "0.1", "--assembly-gaps", GAP_FILE, - "--chr", "chr3", "--save-cnv", "--debug" ],