diff --git a/src/rsrs/rsrs_cycle.rs b/src/rsrs/rsrs_cycle.rs index 37fee5e..e0be718 100644 --- a/src/rsrs/rsrs_cycle.rs +++ b/src/rsrs/rsrs_cycle.rs @@ -43,6 +43,7 @@ use rustc_hash::FxHashSet; use std::{ collections::HashMap, path::{Path, PathBuf}, + sync::Arc, time::{Duration, Instant}, }; @@ -64,6 +65,7 @@ pub struct Rsrs { pub stats: Stats, options: RsrsOptions, anticipated_fixed_rank_samples: Option, + thread_pool: Arc, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -837,6 +839,7 @@ where }; println!("Configured number of threads = {}", options.num_threads); + let thread_pool = Arc::new(build_rsrs_thread_pool(options.num_threads)); let stats = Stats { sampling_time: Vec::new(), sample_loading_time: 0_u128, @@ -883,6 +886,7 @@ where active_samples: 0, options, anticipated_fixed_rank_samples, + thread_pool, } } @@ -1411,14 +1415,14 @@ where level: usize, update_type: &UpdateType, ) -> (u128, u128, Option) { - let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let thread_pool = &self.thread_pool; let (mut tot_id_update, mut tot_lu_update) = self.y_data.update_samples( update_start, samples_to_update, level, update_type, &self.options.fact_type, - &thread_pool, + thread_pool, self.options.num_threads, ); @@ -1435,7 +1439,7 @@ where level, update_type, &self.options.fact_type, - &thread_pool, + thread_pool, self.options.num_threads, ); tot_id_update += tot_z_id_update; @@ -1524,7 +1528,7 @@ where let current_level_keys: Vec = self.level_indexing.level_keys.iter().copied().collect(); // Process all batches *sequentially*, each batch using Rayon internally - let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let thread_pool = Arc::clone(&self.thread_pool); let batches_res: Vec<_> = thread_pool.install(|| { independent_near_fields .into_iter() @@ -1689,7 +1693,6 @@ where self.options.sketching.fixed_rank_sampling_mode, &self.options, ); - >::lu_step( &skel_box, scratch, @@ -1811,7 +1814,7 @@ where ) }); let start = Instant::now(); - let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let thread_pool = Arc::clone(&self.thread_pool); let id_level_iteration_res: Vec<_> = thread_pool.install(|| { current_box_indices .par_iter() @@ -1961,7 +1964,7 @@ where let lu_step_start: Instant = Instant::now(); let mut inactive_inds = Vec::new(); - let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let thread_pool = Arc::clone(&self.thread_pool); let batches_res: Vec<_> = independent_near_fields .into_iter() .map(|batch| {