Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "biosynth"
version = "0.1.27"
version = "0.1.28"
edition = "2021"
rust-version = "1.91"
authors = ["Madhava Jay <madhava@openmined.org>"]
Expand Down
323 changes: 323 additions & 0 deletions cli/src/commands/fast_allele_freq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;

use anyhow::{bail, Context, Result};
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

use crate::commands::long_emit::{
collect_merged_long_rows, default_participant, summarize_warning_counts, MissingRefLogger,
ReferenceResolver, SharedReference,
};
use crate::download::ensure_reference_db;
use crate::long_rows::LongRow;
use crate::stats::StatsStore;
use crate::FastAlleleFreqArgs;

/// Per-locus accumulator. Mirrors the running counters in
/// `long_aggregate::merge_chunks` exactly so the emitted TSV is byte-for-byte
/// identical to `emit-long -> aggregate-long --threads 1`.
///
/// One value exists per locus key; `fold_row` updates it row by row and
/// `merge_loci` combines two partial accumulators for the same locus.
struct Accum {
/// rsid currently kept for this locus. Starts empty and is only replaced by
/// a non-empty rsid coming from a lower-ranked participant.
rsid: String,
/// Rank (participant sort order) of the participant whose non-empty rsid is
/// currently kept. `u32::MAX` until a non-empty rsid is seen. The smallest
/// rank wins, reproducing aggregate's "first non-empty rsid by participant".
rsid_rank: u32,
/// Sum of the dosages of all observed rows (rows whose dosage is not `-1`).
allele_count: i64,
/// Number of observed rows (dosage not `-1`); the writer reports
/// `2 * n_obs` as `allele_number`.
n_obs: i64,
/// Number of observed rows with dosage `2`.
num_homo: i64,
/// Number of observed rows with dosage `1`.
num_hetero: i64,
}

/// Per-locus accumulators keyed by locus key. A `BTreeMap` iterates in key
/// order, so the TSV rows are written sorted by locus key.
type LociMap = BTreeMap<String, Accum>;
/// Warning counters keyed by a code string, in the shape returned by
/// `MissingRefLogger::counts` and consumed by `summarize_warning_counts`.
type Counts = BTreeMap<String, u64>;

/// Runs the `fast-allele-freq` command end to end.
///
/// Steps: make sure the reference database is available, preload it into
/// memory, collect the input genotype files, rank them by participant id, fold
/// every file's rows into per-locus counters (sequentially, or on a dedicated
/// rayon pool), and write the allele-frequency TSV to `args.allele_freq_tsv`.
///
/// When `args.missing_ref_log` is set the run is forced onto a single thread,
/// because the per-row log cannot be written safely from several threads.
///
/// # Errors
///
/// Returns an error when the reference database cannot be obtained, opened or
/// loaded, when no input files are found, when there are more input files than
/// available ranks, when the thread pool cannot be built, or when the output
/// TSV cannot be written. A file that fails to parse is skipped with a warning
/// by the workers and does not abort the run.
pub fn run_fast_allele_freq(args: FastAlleleFreqArgs) -> Result<()> {
    let overall_start = Instant::now();

    let sqlite_path = ensure_reference_db(Some(&args.sqlite), args.force_download)?;
    let store = StatsStore::connect_read_only(&sqlite_path)?;
    // Preload the whole reference into memory once. Lets worker threads resolve
    // rsid/position with lock-free reads instead of contending on SQLite.
    let preload_start = Instant::now();
    let shared = Arc::new(SharedReference::load(&store)?);
    eprintln!(
        "📚 fast-allele-freq: reference preloaded in {:.2}s",
        preload_start.elapsed().as_secs_f64()
    );

    let files = collect_input_files(&args.inputs)?;
    if files.is_empty() {
        bail!("No input genotype files found under the provided --input paths");
    }
    // Ranks are `u32`, and `u32::MAX` is reserved by `Accum::rsid_rank` as the
    // "no rsid seen yet" sentinel. If the file count itself fits in a `u32`,
    // the largest rank (`len - 1`) is strictly below the sentinel, so ranks can
    // neither wrap around nor collide with it.
    if u32::try_from(files.len()).is_err() {
        bail!(
            "Too many input files ({}): fast-allele-freq supports at most {} input files",
            files.len(),
            u32::MAX
        );
    }
    // Sort by participant id. aggregate-long --threads 1 merges rows sorted by
    // (locus_key, participant_id), so the rsid kept per locus is from the
    // alphabetically-first participant with a non-empty rsid. The rank we assign
    // here = that participant order, and the rank-min rule in `fold_row`
    // reproduces the choice regardless of parallel parse order.
    let mut tasks: Vec<(String, PathBuf)> = files
        .into_iter()
        .map(|path| (default_participant(&path), path))
        .collect();
    // Stable sort: files with the same participant id keep their path order.
    tasks.sort_by(|a, b| a.0.cmp(&b.0));
    // Counting with a `u32` range avoids a truncating `usize as u32` cast; the
    // bound check above guarantees the counter never overflows.
    let tasks: Vec<(u32, String, PathBuf)> = tasks
        .into_iter()
        .zip(0u32..)
        .map(|((pid, path), rank)| (rank, pid, path))
        .collect();

    // Per-row TSV log can't be written safely from many threads -> single thread.
    let force_single = args.missing_ref_log.is_some();
    let threads = if force_single {
        1
    } else {
        resolve_threads(args.threads)
    };
    eprintln!(
        "▶️ fast-allele-freq: {} input file(s), threads={}",
        tasks.len(),
        threads
    );

    let (loci, counts) = if threads <= 1 {
        run_sequential(&tasks, &shared, &args)?
    } else {
        // A dedicated pool keeps the requested thread count local to this
        // command instead of depending on rayon's global pool.
        let pool = ThreadPoolBuilder::new()
            .num_threads(threads)
            .build()
            .context("build fast-allele-freq thread pool")?;
        pool.install(|| run_parallel(&tasks, &shared))?
    };

    write_allele_freq(&args.allele_freq_tsv, &loci)?;

    if args.warn_detail != crate::WarnDetail::None {
        let summary = summarize_warning_counts(&counts);
        if !summary.is_empty() {
            eprintln!("⚠️ fast-allele-freq warnings: {summary}");
        }
    }
    eprintln!(
        "✅ fast-allele-freq: {} files, {} loci, {:.2}s",
        tasks.len(),
        loci.len(),
        overall_start.elapsed().as_secs_f64()
    );
    println!(
        "✅ Wrote allele frequencies to {}",
        args.allele_freq_tsv.display()
    );
    Ok(())
}

fn run_sequential(
tasks: &[(u32, String, PathBuf)],
shared: &Arc<SharedReference>,
args: &FastAlleleFreqArgs,
) -> Result<(LociMap, Counts)> {
let mut resolver = ReferenceResolver::shared(shared.clone());
let mut logger = MissingRefLogger::new(args.missing_ref_log.as_deref(), args.warn_detail)?;
let mut loci: LociMap = BTreeMap::new();
for (idx, (rank, participant, path)) in tasks.iter().enumerate() {
match collect_merged_long_rows(path, participant, &mut resolver, &mut logger) {
Ok((rows, _stats)) => {
for row in rows {
fold_row(&mut loci, row, *rank);
}
}
Err(err) => eprintln!("⚠️ skipping {}: {err:#}", path.display()),
}
if (idx + 1) % 100 == 0 {
eprintln!(
"🧮 fast-allele-freq: {} files, {} loci so far",
idx + 1,
loci.len()
);
}
}
Ok((loci, logger.counts().clone()))
}

fn run_parallel(
tasks: &[(u32, String, PathBuf)],
shared: &Arc<SharedReference>,
) -> Result<(LociMap, Counts)> {
// Each thread folds its share of files into a local map + counts, then maps
// are merged. Counts are order-independent; rsid uses the rank-min rule so the
// merge is deterministic and matches the single-threaded result.
let result = tasks
.par_iter()
.fold(
|| ThreadState::new(shared.clone()),
|mut st, (rank, participant, path)| {
match collect_merged_long_rows(path, participant, &mut st.resolver, &mut st.logger)
{
Ok((rows, _stats)) => {
for row in rows {
fold_row(&mut st.loci, row, *rank);
}
}
Err(err) => eprintln!("⚠️ skipping {}: {err:#}", path.display()),
}
st
},
)
.map(|st| (st.loci, st.logger.counts().clone()))
.reduce(
|| (BTreeMap::new(), BTreeMap::new()),
|(mut m1, mut c1), (m2, c2)| {
merge_loci(&mut m1, m2);
merge_counts(&mut c1, c2);
(m1, c1)
},
);
Ok(result)
}

/// Working state for one partial fold in `run_parallel`: its own resolver and
/// logger plus the locus map built from the files folded into it.
struct ThreadState {
/// Reference resolver backed by the shared in-memory reference.
resolver: ReferenceResolver,
/// Count-only logger; its counts are merged after the fold.
logger: MissingRefLogger,
/// Per-locus accumulators gathered by this state so far.
loci: LociMap,
}

impl ThreadState {
fn new(shared: Arc<SharedReference>) -> Self {
// Lock-free in-memory resolver per thread (cheap Arc clone). No per-row
// TSV in parallel mode (counts only); WarnDetail::None keeps worker
// threads from interleaving per-row stderr. Construction cannot fail.
Self {
resolver: ReferenceResolver::shared(shared),
logger: MissingRefLogger::new(None, crate::WarnDetail::None)
.expect("count-only logger has no file handle and cannot fail"),
loci: BTreeMap::new(),
}
}
}

/// Folds one long-format row into the accumulator of its locus.
///
/// * rsid: a non-empty rsid replaces the kept one only when `rank` is lower
///   than the rank of the rsid currently kept (rank-min rule).
/// * counters: a dosage of `-1` marks the row as unobserved and leaves the
///   counters alone; any other dosage is added to `allele_count`, bumps
///   `n_obs`, and additionally bumps `num_homo` (dosage 2) or `num_hetero`
///   (dosage 1).
fn fold_row(loci: &mut LociMap, row: LongRow, rank: u32) {
    let acc = loci.entry(row.locus_key).or_insert_with(|| Accum {
        rsid: String::new(),
        rsid_rank: u32::MAX,
        allele_count: 0,
        n_obs: 0,
        num_homo: 0,
        num_hetero: 0,
    });

    let wins_rsid = rank < acc.rsid_rank && !row.rsid.is_empty();
    if wins_rsid {
        acc.rsid_rank = rank;
        acc.rsid = row.rsid;
    }

    // Unobserved genotype: nothing to count.
    if row.dosage == -1 {
        return;
    }

    acc.n_obs += 1;
    acc.allele_count += row.dosage as i64;
    match row.dosage {
        2 => acc.num_homo += 1,
        1 => acc.num_hetero += 1,
        _ => {}
    }
}

fn merge_loci(into: &mut LociMap, from: LociMap) {
for (locus, b) in from {
match into.get_mut(&locus) {
Some(a) => {
a.allele_count += b.allele_count;
a.n_obs += b.n_obs;
a.num_homo += b.num_homo;
a.num_hetero += b.num_hetero;
if b.rsid_rank < a.rsid_rank {
a.rsid = b.rsid;
a.rsid_rank = b.rsid_rank;
}
}
None => {
into.insert(locus, b);
}
}
}
}

/// Adds every counter of `from` onto the counter with the same code in `into`,
/// creating the entry when `into` does not have that code yet.
fn merge_counts(into: &mut Counts, from: Counts) {
    from.into_iter().for_each(|(code, extra)| {
        into.entry(code)
            .and_modify(|total| *total += extra)
            .or_insert(extra);
    });
}

/// Picks the worker-thread count.
///
/// A positive `requested` value is returned unchanged. `0` means "auto": use
/// the parallelism reported by the standard library, falling back to `1` when
/// it cannot be determined. The result is always at least `1`.
fn resolve_threads(requested: usize) -> usize {
    match requested {
        // `available_parallelism` yields a `NonZeroUsize`, so `get()` is >= 1.
        0 => std::thread::available_parallelism().map_or(1, |cores| cores.get()),
        explicit => explicit,
    }
}

fn write_allele_freq(path: &Path, loci: &LociMap) -> Result<()> {
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent)
.with_context(|| format!("Create directory {:?}", parent))?;
}
}
let mut writer =
BufWriter::new(File::create(path).with_context(|| format!("Create {:?}", path))?);
writeln!(
writer,
"locus_key\tallele_count\tallele_number\tnum_homo\tnum_hetero\tallele_freq\trsid"
)?;
for (locus, acc) in loci {
let allele_number = 2 * acc.n_obs;
let allele_freq = if allele_number > 0 {
acc.allele_count as f64 / allele_number as f64
} else {
0.0
};
writeln!(
writer,
"{locus}\t{}\t{allele_number}\t{}\t{}\t{allele_freq:.6}\t{}",
acc.allele_count, acc.num_homo, acc.num_hetero, acc.rsid
)?;
}
writer.flush()?;
Ok(())
}

/// Expands the `--input` paths into a sorted, de-duplicated list of genotype
/// files.
///
/// * A path that is a regular file is taken as-is, without name filtering.
/// * A directory is walked recursively with symlinks followed; hidden files
///   (name starting with `.`) and files with the `bvlr` extension are skipped.
/// * Any other path produces a warning on stderr and is ignored.
///
/// Directory entries that cannot be read (for example permission errors, or
/// symlink loops detected by `walkdir`) are reported on stderr and skipped, so
/// an input that silently went missing is visible to the user instead of just
/// being absent from the result.
fn collect_input_files(inputs: &[PathBuf]) -> Result<Vec<PathBuf>> {
    let mut files = Vec::new();
    for input in inputs {
        if input.is_file() {
            files.push(input.clone());
        } else if input.is_dir() {
            // follow_links: flow 04 stages genotype files as symlinks; without
            // this they'd be reported as symlink (not file) and skipped.
            for entry in walkdir::WalkDir::new(input).follow_links(true) {
                let entry = match entry {
                    Ok(entry) => entry,
                    Err(err) => {
                        // Best effort: keep walking, but say what was lost.
                        eprintln!("⚠️ skipping unreadable entry under {:?}: {err}", input);
                        continue;
                    }
                };
                if !entry.file_type().is_file() {
                    continue;
                }
                let path = entry.path();
                // A name that is not valid UTF-8 is treated as not hidden.
                let hidden = path
                    .file_name()
                    .and_then(|n| n.to_str())
                    .is_some_and(|n| n.starts_with('.'));
                if hidden {
                    continue;
                }
                if path.extension().and_then(|e| e.to_str()) == Some("bvlr") {
                    continue;
                }
                files.push(path.to_path_buf());
            }
        } else {
            eprintln!("⚠️ Input path {:?} is not a file or directory", input);
        }
    }
    // The same file can be reached through several inputs; keep one copy.
    files.sort();
    files.dedup();
    Ok(files)
}
Loading
Loading