Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/rsrs/rsrs_cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use rustc_hash::FxHashSet;
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
time::{Duration, Instant},
};

Expand All @@ -64,6 +65,7 @@ pub struct Rsrs<Item: RlstScalar> {
pub stats: Stats,
options: RsrsOptions<Item>,
anticipated_fixed_rank_samples: Option<FixedRankSampleBudget>,
thread_pool: Arc<rayon::ThreadPool>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -837,6 +839,7 @@ where
};

println!("Configured number of threads = {}", options.num_threads);
let thread_pool = Arc::new(build_rsrs_thread_pool(options.num_threads));
let stats = Stats {
sampling_time: Vec::new(),
sample_loading_time: 0_u128,
Expand Down Expand Up @@ -883,6 +886,7 @@ where
active_samples: 0,
options,
anticipated_fixed_rank_samples,
thread_pool,
}
}

Expand Down Expand Up @@ -1411,14 +1415,14 @@ where
level: usize,
update_type: &UpdateType<Item>,
) -> (u128, u128, Option<u128>) {
let thread_pool = build_rsrs_thread_pool(self.options.num_threads);
let thread_pool = &self.thread_pool;
let (mut tot_id_update, mut tot_lu_update) = self.y_data.update_samples(
update_start,
samples_to_update,
level,
update_type,
&self.options.fact_type,
&thread_pool,
thread_pool,
self.options.num_threads,
);

Expand All @@ -1435,7 +1439,7 @@ where
level,
update_type,
&self.options.fact_type,
&thread_pool,
thread_pool,
self.options.num_threads,
);
tot_id_update += tot_z_id_update;
Expand Down Expand Up @@ -1524,7 +1528,7 @@ where
let current_level_keys: Vec<MortonKey> =
self.level_indexing.level_keys.iter().copied().collect();
// Process all batches *sequentially*, each batch using Rayon internally
let thread_pool = build_rsrs_thread_pool(self.options.num_threads);
let thread_pool = Arc::clone(&self.thread_pool);
let batches_res: Vec<_> = thread_pool.install(|| {
independent_near_fields
.into_iter()
Expand Down Expand Up @@ -1689,7 +1693,6 @@ where
self.options.sketching.fixed_rank_sampling_mode,
&self.options,
);

<Item as Skel<Item, Space>>::lu_step(
&skel_box,
scratch,
Expand Down Expand Up @@ -1811,7 +1814,7 @@ where
)
});
let start = Instant::now();
let thread_pool = build_rsrs_thread_pool(self.options.num_threads);
let thread_pool = Arc::clone(&self.thread_pool);
let id_level_iteration_res: Vec<_> = thread_pool.install(|| {
current_box_indices
.par_iter()
Expand Down Expand Up @@ -1961,7 +1964,7 @@ where
let lu_step_start: Instant = Instant::now();

let mut inactive_inds = Vec::new();
let thread_pool = build_rsrs_thread_pool(self.options.num_threads);
let thread_pool = Arc::clone(&self.thread_pool);
let batches_res: Vec<_> = independent_near_fields
.into_iter()
.map(|batch| {
Expand Down
Loading