diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 4a5818d..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 7e8d20f..3e9a2df 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -11,6 +11,8 @@ jobs: run-tests-rust: name: Run Rust tests runs-on: ubuntu-latest + env: + OPENBLAS_NUM_THREADS: 1 strategy: matrix: rust-version: ["stable"] @@ -28,32 +30,116 @@ jobs: - name: Install cargo-mpirun and cargo-templated-examples run: cargo install cargo-mpirun cargo-templated-examples - uses: actions/checkout@v4 - - name: Install LAPACK, OpenBLAS + - name: Install system libraries run: - sudo apt-get install -y libopenblas-dev liblapack-dev + sudo apt-get install -y libopenblas-dev liblapack-dev libhdf5-dev pkg-config - name: Run unit tests run: cargo test ${{ matrix.feature-flags }} + - name: Run operator matvec regression test + run: cargo test ${{ matrix.feature-flags }} --test rsrs_operator_mv rsrs_operator_matvec_diagnostic -- --nocapture - name: Run unit tests in release mode run: cargo test --release ${{ matrix.feature-flags }} + - name: Run operator matvec regression test in release mode + run: cargo test --release ${{ matrix.feature-flags }} --test rsrs_operator_mv rsrs_operator_matvec_diagnostic -- --nocapture - name: Run tests run: cargo test --examples --release ${{ matrix.feature-flags }} - name: Run examples - run: OPENBLAS_NUM_THREADS=1 cargo templated-examples + run: cargo templated-examples + + perturbed-biegrid-regressions: + name: Perturbed BIEGrid regressions + runs-on: ubuntu-22.04 + timeout-minutes: 120 + steps: + - name: Set up Rust + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: stable + - name: Set up MPI + uses: mpi4py/setup-mpi@v1 + with: + mpi: openmpi + - uses: actions/checkout@v4 + - name: Check out public rsrs-exps harness + uses: actions/checkout@v4 + with: + repository: ignacia-fp/rsrs-exps + path: rsrs-exps + - name: Install public rsrs-exps system dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + git \ + pkg-config \ + build-essential \ + cmake \ + clang \ + lld \ + gfortran \ + libssl-dev \ + openmpi-bin \ + libopenmpi-dev \ + libhdf5-dev \ + libopenblas-dev \ + liblapack-dev \ + libfftw3-dev \ + python3.10 \ + python3.10-dev \ + python3.10-venv \ + python3-pip \ + patchelf \ + libfontconfig1-dev \ + libfreetype6-dev \ + gmsh \ + libgl1 \ + libglu1-mesa \ + libx11-6 \ + libxext6 \ + libxi6 \ + libxrender1 \ + libxfixes3 \ + libxrandr2 \ + libxcursor1 \ + libxinerama1 \ + libsm6 \ + libice6 + - name: Patch rsrs-exps to current bempp-rsrs checkout + run: | + mkdir -p rsrs-exps/.cargo + printf '%s\n%s\n' \ + '[patch."https://github.com/bempp/rsrs.git"]' \ + "bempp-rsrs = { path = \"${GITHUB_WORKSPACE}\" }" \ + > rsrs-exps/.cargo/config.toml + - name: Bootstrap public rsrs-exps harness + run: | + cd rsrs-exps + DEPS_DIR="$PWD/.deps" WORKSPACE="$PWD" bash scripts/setup_deps.sh true + - name: Run perturbed BIEGrid regression suite + run: | + cd rsrs-exps + . .venv/bin/activate + python scripts/run_perturbed_biegrid_suite.py check-dependencies: name: Check dependencies + if: false runs-on: ubuntu-latest + continue-on-error: true strategy: matrix: rust-version: ["stable"] + env: + CARGO_REGISTRIES_CRATES_IO_PROTOCOL: git steps: - name: Set up Rust uses: actions-rust-lang/setup-rust-toolchain@v1 with: toolchain: ${{ matrix.rust-version }} - name: Install cargo-upgrades - run: cargo install cargo-upgrades + run: cargo install cargo-upgrades --version 2.1.2 - uses: actions/checkout@v4 - name: Check that dependencies are up to date run: diff --git a/.github/workflows/style-checks.yml b/.github/workflows/style-checks.yml index 605a563..745ea39 100644 --- a/.github/workflows/style-checks.yml +++ b/.github/workflows/style-checks.yml @@ -11,6 +11,8 @@ jobs: style-checks: name: Rust style checks runs-on: ubuntu-latest + env: + OPENBLAS_NUM_THREADS: 1 strategy: matrix: feature-flags: [''] @@ -25,13 +27,11 @@ jobs: with: mpi: mpich - uses: actions/checkout@v4 - - name: Install LAPACK, OpenBLAS + - name: Install system libraries run: - sudo apt-get install -y libopenblas-dev liblapack-dev + sudo apt-get install -y libopenblas-dev liblapack-dev libhdf5-dev pkg-config - name: Rust style checks run: | cargo fmt -- --check - #cargo clippy ${{ matrix.feature-flags }} -- -D warnings - #cargo clippy --tests ${{ matrix.feature-flags }} -- -D warnings - #cargo clippy --examples ${{ matrix.feature-flags }} -- -D warnings + cargo clippy --all-targets --all-features --no-deps -- -D warnings diff --git a/.gitignore b/.gitignore index c54c04a..15b9f9f 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,9 @@ Cargo.lock #.idea/ examples.sh + +.DS_Store +*/.DS_Store +*/*/.DS_Store +**/.DS_Store +src.zip diff --git a/Cargo.toml b/Cargo.toml index a95d7ed..f258460 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,10 @@ itertools = "0.14.*" rand = { version = "0.8.5", features = ["alloc"] } rand_chacha = "0.3.*" num = "0.4.*" -rlst = {git = "https://github.com/linalg-rs/rlst.git", branch = "extend_abstract_operator", features = ["mpi"]} +rlst = { git = "https://github.com/ignacia-fp/rlst.git", branch = "extend_abstract_operator", features = [ + "enable_tracing", + "mpi", +] } bempp-octree = { git = "https://github.com/bempp/octree.git" } mpi = { version = "0.8.*", features = ["derive", "user-operations"] } ndarray = "0.16.1" @@ -37,6 +40,9 @@ paste = "1" serde = { version = "1.0", features = ["derive"] } rayon = "1.8" rustc-hash = "2.1" +hdf5 = { git = "https://github.com/aldanor/hdf5-rust.git", branch = "master" } +num-complex = "0.4" +num_cpus = "1.16" [profile.release] debug = 1 diff --git a/examples/.DS_Store b/examples/.DS_Store deleted file mode 100644 index 801be6c..0000000 Binary files a/examples/.DS_Store and /dev/null differ diff --git a/examples/rsrs_errors.rs b/examples/rsrs_errors.rs index 7458b11..b4ccb36 100644 --- a/examples/rsrs_errors.rs +++ b/examples/rsrs_errors.rs @@ -1,32 +1,153 @@ use bempp_octree::{generate_random_points, Octree}; use bempp_rsrs::{ rsrs::{ - rsrs_cycle::{RankPicking, Rsrs, RsrsArgs, RsrsOptions}, + args::{RankPicking, RsrsArgs, RsrsOptions}, + rsrs_cycle::Rsrs, rsrs_factors::{ - CommutativeFactors, Factor, FactorOperations, FactorType, IdFactor, LuFactor, - MulOptions, PivotMethod, RsrsFactors, RsrsFactorsImpl, RsrsSide, + base_factors::BaseFactorOptions, + commutative_factors::{ + CommutativeFactors, Factor, FactorOperations, FactorType, IdFactor, LuFactor, + MulOptions, MultiLevelIdFactors, RsrsFactors, + }, + null_and_extract::PivotMethod, + rsrs_operator::{FactType, RsrsApply, RsrsFactorsImpl}, }, + sketch::Shift, }, utils::{ data_ins_ext::{ExtInsType, Extraction, MatrixExtraction}, - least_squares_and_null::{BlockExtractionMethod, NullMethod}, + linear_algebra::{BlockExtractionMethod, NullMethod}, }, }; use mpi::{topology::SimpleCommunicator, traits::CommunicatorCollectives}; use num::{Complex, NumCast}; +use rand::rngs::StdRng; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use rand_distr::{Distribution, Standard, StandardNormal}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rlst::{ - dense::{linalg::lu::MatrixLu, tools::RandScalar}, + dense::{ + linalg::{interpolative_decomposition::MatrixIdNoSkel, lu::MatrixLu}, + tools::RandScalar, + }, prelude::*, }; -use std::sync::{Arc, Mutex}; +use std::env; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, +}; type Errors = (f64, f64); type ErrorStats = (f64, f64, f64, f64); type Real = ::Real; +const DEFAULT_ALGO_SEED: u64 = 0; +const PROBE_SEED_XOR_TAG: u64 = 0x5052_4F42_455F_5345; +static PROBE_SEED: AtomicU64 = AtomicU64::new(0); + +fn mix_seed(mut seed: u64) -> u64 { + seed = seed.wrapping_add(0x9E3779B97F4A7C15); + seed = (seed ^ (seed >> 30)).wrapping_mul(0xBF58476D1CE4E5B9); + seed = (seed ^ (seed >> 27)).wrapping_mul(0x94D049BB133111EB); + seed ^ (seed >> 31) +} + +fn set_probe_seed(seed: u64) { + PROBE_SEED.store(seed, Ordering::Relaxed); +} + +fn current_probe_seed() -> u64 { + PROBE_SEED.load(Ordering::Relaxed) +} + +fn seeded_rng(tag: u64) -> StdRng { + StdRng::seed_from_u64(mix_seed(current_probe_seed() ^ tag)) +} + +#[derive(Clone, Copy)] +struct BenchmarkConfig { + algo_seed: u64, + probe_seed: u64, + seed_runs: usize, + include_box_errors: bool, +} + +impl BenchmarkConfig { + fn algo_seed_for_run(&self, run_index: usize) -> u64 { + if run_index == 0 { + self.algo_seed + } else { + mix_seed( + self.algo_seed + ^ (run_index as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15) + ^ 0xA11A_6000_0000_0001, + ) + } + } + + fn probe_seed_for_run(&self, run_index: usize) -> u64 { + if run_index == 0 { + self.probe_seed + } else { + mix_seed( + self.probe_seed + ^ (run_index as u64).wrapping_mul(0xBF58_476D_1CE4_E5B9) + ^ 0xBEE5_0000_0000_0002, + ) + } + } +} + +fn factor_type_tag(factor_type: &FactorType) -> u64 { + match factor_type { + FactorType::F => 0xF0, + FactorType::S => 0x5F, + } +} + +fn apply_tag(side: &RsrsApply) -> u64 { + match side { + RsrsApply::Left(factor_type) => 0x1A00 ^ factor_type_tag(factor_type), + RsrsApply::Right(factor_type) => 0x2B00 ^ factor_type_tag(factor_type), + RsrsApply::Sandwich => 0x3C00, + } +} + +fn env_u64(name: &str) -> Option { + env::var(name) + .ok() + .and_then(|value| value.parse::().ok()) +} + +fn env_usize(name: &str) -> Option { + env::var(name) + .ok() + .and_then(|value| value.parse::().ok()) +} + +fn env_flag(name: &str) -> Option { + env::var(name).ok().map(|value| { + matches!( + value.as_str(), + "1" | "true" | "TRUE" | "yes" | "YES" | "on" | "ON" + ) + }) +} + +fn max_mul_error(errors: ErrorStats) -> f64 { + errors.0.max(errors.1).max(errors.2).max(errors.3) +} + +fn median(values: &mut [f64]) -> f64 { + values.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mid = values.len() / 2; + if values.len().is_multiple_of(2) { + 0.5 * (values[mid - 1] + values[mid]) + } else { + values[mid] + } +} // Error functions pub fn spectral_norm_estimator( @@ -41,9 +162,10 @@ where let dim = arr.shape()[1]; let max_err = (0..sample_size) - .map(|_sample_ind| { + .map(|sample_ind| { let mut test_vec = rlst_dynamic_array1!(Item, [dim]); - let mut local_rng: rand::rngs::StdRng = rand::SeedableRng::from_entropy(); + let mut local_rng = + seeded_rng(0x5350_4543_0000 ^ (dim as u64).rotate_left(7) ^ sample_ind as u64); test_vec.fill_from_standard_normal(&mut local_rng); let mut res_vec = empty_array(); res_vec @@ -59,12 +181,19 @@ where } pub fn app_inv_error< - Item: RlstScalar + RandScalar + MatrixInverse + MatrixId + MatrixPseudoInverse + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, >( target_arr: &DynamicArray, rsrs_factors: &mut RsrsFactors, sample_size: usize, - side: RsrsSide, + side: RsrsApply, ) -> f64 where StandardNormal: Distribution, @@ -77,44 +206,48 @@ where let dim = target_arr.shape()[1]; let mut sample_mat_1 = empty_array(); let mut sample_mat_2 = empty_array(); - let mut local_rng: rand::rngs::StdRng = rand::SeedableRng::from_entropy(); - let factor_options = MulOptions { + let mut local_rng = seeded_rng( + 0x1A11_0001 ^ apply_tag(&side) ^ (dim as u64).rotate_left(13) ^ sample_size as u64, + ); + let base_factor_options = BaseFactorOptions { inv: true, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, + trans: TransMode::NoTrans, + trans_target: false, }; + //num_cpus::get() let view_shape; let view_offset = match side { - RsrsSide::Left => |ind| [0, ind], - RsrsSide::Right => |ind| [ind, 0], - RsrsSide::Squeeze => |_ind| [0, 0], + RsrsApply::Left(_) => |ind| [0, ind], + RsrsApply::Right(_) => |ind| [ind, 0], + RsrsApply::Sandwich => |_ind| [0, 0], }; - match side { - RsrsSide::Left => { + let aux_side = match side { + RsrsApply::Left(_) => { sample_mat_1.resize_in_place([dim, sample_size]); sample_mat_1.fill_from_standard_normal(&mut local_rng); sample_mat_2 .r_mut() .simple_mult_into_resize(target_arr.r(), sample_mat_1.r()); view_shape = [dim, 1]; + Side::Left } - RsrsSide::Right => { + RsrsApply::Right(_) => { sample_mat_1.resize_in_place([sample_size, dim]); sample_mat_1.fill_from_standard_normal(&mut local_rng); sample_mat_2 .r_mut() .simple_mult_into_resize(sample_mat_1.r(), target_arr.r()); view_shape = [1, dim]; + Side::Right } - RsrsSide::Squeeze => { + RsrsApply::Sandwich => { view_shape = [0, 0]; + Side::Left //CHECK } - } + }; - rsrs_factors.matmul(&mut sample_mat_2, side, &factor_options); + rsrs_factors.matmul(&mut sample_mat_2, aux_side, &base_factor_options); let mut res = empty_array(); res.fill_from_resize(sample_mat_2.r() - sample_mat_1.r()); @@ -135,12 +268,19 @@ where } pub fn app_error< - Item: RlstScalar + RandScalar + MatrixInverse + MatrixId + MatrixPseudoInverse + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, >( target_arr: &DynamicArray, rsrs_factors: &mut RsrsFactors, sample_size: usize, - side: RsrsSide, + side: RsrsApply, ) -> f64 where StandardNormal: Distribution, @@ -154,45 +294,49 @@ where let mut sample_mat_1 = empty_array(); let mut sample_mat_2 = empty_array(); - let mut local_rng: rand::rngs::StdRng = rand::SeedableRng::from_entropy(); - let factor_options = MulOptions { + let mut local_rng = seeded_rng( + 0x2A22_0002 ^ apply_tag(&side) ^ (dim as u64).rotate_left(17) ^ sample_size as u64, + ); + + let base_factor_options = BaseFactorOptions { inv: false, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, + trans: TransMode::NoTrans, + trans_target: false, }; let view_shape; let view_offset = match side { - RsrsSide::Left => |ind| [0, ind], - RsrsSide::Right => |ind| [ind, 0], - RsrsSide::Squeeze => |_ind| [0, 0], + RsrsApply::Left(_) => |ind| [0, ind], + RsrsApply::Right(_) => |ind| [ind, 0], + RsrsApply::Sandwich => |_ind| [0, 0], }; - match side { - RsrsSide::Left => { + let aux_side = match side { + RsrsApply::Left(_) => { sample_mat_1.resize_in_place([dim, sample_size]); sample_mat_1.fill_from_standard_normal(&mut local_rng); sample_mat_2 .r_mut() .simple_mult_into_resize(target_arr.r(), sample_mat_1.r()); view_shape = [dim, 1]; + Side::Left } - RsrsSide::Right => { + RsrsApply::Right(_) => { sample_mat_1.resize_in_place([sample_size, dim]); sample_mat_1.fill_from_standard_normal(&mut local_rng); sample_mat_2 .r_mut() .simple_mult_into_resize(sample_mat_1.r(), target_arr.r()); view_shape = [1, dim]; + Side::Right } - RsrsSide::Squeeze => { + RsrsApply::Sandwich => { view_shape = [0, 0]; + Side::Left } - } + }; - rsrs_factors.matmul(&mut sample_mat_1, side, &factor_options); + rsrs_factors.matmul(&mut sample_mat_1, aux_side, &base_factor_options); let mut res = empty_array(); res.fill_from_resize(sample_mat_2.r() - sample_mat_1.r()); @@ -215,7 +359,14 @@ where } pub fn rsrs_error_estimator< - Item: RlstScalar + RandScalar + MatrixInverse + MatrixId + MatrixPseudoInverse + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, >( target_arr: &DynamicArray, rsrs_factors: &mut RsrsFactors, @@ -229,10 +380,30 @@ where TriangularMatrix: TriangularOperations, ::Real: RandScalar, { - let app_inv_err_left = app_inv_error(target_arr, rsrs_factors, sample_size, RsrsSide::Left); - let app_inv_err_right = app_inv_error(target_arr, rsrs_factors, sample_size, RsrsSide::Right); - let app_err_left = app_error(target_arr, rsrs_factors, sample_size, RsrsSide::Left); - let app_err_right = app_error(target_arr, rsrs_factors, sample_size, RsrsSide::Right); + let app_inv_err_left = app_inv_error( + target_arr, + rsrs_factors, + sample_size, + RsrsApply::Left(FactorType::F), + ); + let app_inv_err_right = app_inv_error( + target_arr, + rsrs_factors, + sample_size, + RsrsApply::Right(FactorType::F), + ); + let app_err_left = app_error( + target_arr, + rsrs_factors, + sample_size, + RsrsApply::Left(FactorType::F), + ); + let app_err_right = app_error( + target_arr, + rsrs_factors, + sample_size, + RsrsApply::Right(FactorType::F), + ); ( app_inv_err_left, @@ -305,7 +476,14 @@ where } fn commutative_factors_errors< - Item: RlstScalar + RandScalar + MatrixInverse + MatrixPseudoInverse + MatrixLu + MatrixId + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixId + + MatrixIdNoSkel + + MatrixQr, >( factors: &CommutativeFactors, target_arr: &mut DynamicArray, @@ -319,21 +497,21 @@ where ::Real: RandScalar, { let target_arr = Arc::new(Mutex::new(target_arr)); - - let factor_options_left = MulOptions { + let base_options = BaseFactorOptions { inv: true, - trans: false, + trans: TransMode::NoTrans, + trans_target: false, + }; + let factor_options_left = MulOptions { + base_options: base_options.clone(), side: Side::Left, factor_type: FactorType::F, - t_trans: false, }; let factor_options_right = MulOptions { - inv: true, - trans: false, + base_options: base_options.clone(), side: Side::Right, factor_type: FactorType::S, - t_trans: false, }; let errors: Vec<_> = factors @@ -359,7 +537,7 @@ where } Factor::Diag(diag_box_factor) => { let mut exact_diag_box = as MatrixExtraction>::new( - &mut target_arr, + &target_arr, ExtInsType::Cross( diag_box_factor.inds.clone(), diag_box_factor.inds.clone(), @@ -373,14 +551,14 @@ where let mut app_dbox = rlst_dynamic_array2!(Item, shape); app_dbox.set_identity(); - let options = MulOptions { + let base_options = BaseFactorOptions { inv: false, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, + trans: TransMode::NoTrans, + trans_target: false, }; - diag_box_factor.arr.mul(&mut app_dbox, Side::Left, &options); + diag_box_factor + .arr + .mul(&mut app_dbox, &Side::Left, &base_options); let mut res: DynamicArray = empty_array(); res.fill_from_resize(exact_diag_box.r() - app_dbox.r()); @@ -388,28 +566,23 @@ where let err_diag = spectral_norm_estimator(&res, 10).unwrap() / spectral_norm_estimator(&exact_diag_box, 10).unwrap(); - let mut app_inv_dbox = rlst_dynamic_array2!(Item, shape); - app_inv_dbox.set_identity(); + let mut identity = rlst_dynamic_array2!(Item, shape); + identity.set_identity(); - let options = MulOptions { + let base_options = BaseFactorOptions { inv: true, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, + trans: TransMode::NoTrans, + trans_target: false, }; diag_box_factor .arr - .mul(&mut app_inv_dbox, Side::Left, &options); - - exact_diag_box.r_mut().into_inverse_alloc().unwrap(); + .mul(&mut exact_diag_box, &Side::Left, &base_options); let mut res: DynamicArray = empty_array(); - res.fill_from_resize(exact_diag_box.r() - app_inv_dbox.r()); + res.fill_from_resize(exact_diag_box.r() - identity.r()); - let err_inv_diag = spectral_norm_estimator(&res, 10).unwrap() - / spectral_norm_estimator(&exact_diag_box, 10).unwrap(); + let err_inv_diag = spectral_norm_estimator(&res, 10).unwrap(); let errors: Errors = (err_diag, err_inv_diag); errors @@ -422,7 +595,14 @@ where } fn el_factors_inv_mul_errors< - Item: RlstScalar + RandScalar + MatrixInverse + MatrixId + MatrixPseudoInverse + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, >( rsrs_factors: &RsrsFactors, target_arr: &mut DynamicArray, @@ -437,8 +617,17 @@ where { let errors: Vec<(Vec, Vec)> = (0..rsrs_factors.num_levels) .map(|level_it| { - let factors = &rsrs_factors.id_factors[level_it]; - let id_errors = commutative_factors_errors(factors, target_arr); + let id_errors = match &rsrs_factors.id_factors { + MultiLevelIdFactors::Single(id_factors) => { + let factors = &id_factors[level_it]; + commutative_factors_errors(factors, target_arr) + } + MultiLevelIdFactors::Batched(id_factors) => id_factors[level_it] + .iter() + .flat_map(|id_batch| commutative_factors_errors(id_batch, target_arr)) + .collect(), + }; + let lu_errors = rsrs_factors.lu_factors[level_it] .iter() .flat_map(|lu_batch| commutative_factors_errors(lu_batch, target_arr)) @@ -475,6 +664,7 @@ where let mut id_stats = Vec::new(); let mut lu_stats = Vec::new(); + errors .iter() .for_each(|(id_level_errors, lu_level_errors)| { @@ -487,11 +677,18 @@ where } fn get_boxes_errors< - Item: RlstScalar + RandScalar + MatrixInverse + MatrixPseudoInverse + MatrixId + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixInverse + + MatrixPseudoInverse + + MatrixId + + MatrixIdNoSkel + + MatrixLu + + MatrixQr, >( kernel_mat: &mut DynamicArray, rsrs_factors: &mut RsrsFactors, - tol: f64, + _tol: f64, ) where StandardNormal: Distribution>, Standard: Distribution>, @@ -516,7 +713,7 @@ fn get_boxes_errors< .for_each(|(level, stats)| { let (mu_1, mu_2, std_dev_1, std_dev_2) = stats; println!("Errors LU, level {level} : ({mu_1} +/- {std_dev_1}, {mu_2} +/- {std_dev_2})"); - assert!(*mu_1 <= tol && *mu_2 <= tol); + //assert!(*mu_1 <= tol && *mu_2 <= tol); }); println!("\n"); @@ -545,12 +742,12 @@ fn get_boxes_errors< "Mean residual diagonal blocks errors : {diag_re_r_mean:?}, sketch block error: {diag_re_s:?}" ); - assert!( + /*assert!( diag_re_r_mean.0 <= tol && diag_re_r_mean.1 <= tol && diag_re_s.0 <= tol && diag_re_s.1 <= tol - ); + );*/ } //Function that creates a low rank matrix by calculating a kernel given a random point distribution on an unit sphere. @@ -642,6 +839,7 @@ pub fn sphere_surface( //Matrix building +#[allow(dead_code)] fn laplace_kernel(dist: f64, npoints: usize) -> f64 { let pi = std::f64::consts::PI; let n: f64 = num::NumCast::from(npoints).unwrap(); @@ -656,6 +854,7 @@ fn helmholtz_kernel(dist: f64, npoints: usize, kappa: f64) -> Complex { (i * kappa * d).exp() / (4.0 * pi * n * d) } +#[allow(dead_code)] fn get_laplace_matrix(points_x: &[bempp_octree::Point]) -> DynamicArray { let n: usize = points_x.len(); let mut arr: DynamicArray = rlst_dynamic_array2!(f64, [n, n]); @@ -686,7 +885,7 @@ fn get_helmholtz_matrix(points_x: &[bempp_octree::Point]) -> DynamicArray, 2> = rlst_dynamic_array2!(Complex, [n, n]); let mut view = arr.r_mut(); - let pi = 0.0; + let pi = std::f64::consts::PI; for (i, point_x) in points_x.iter().enumerate() { for (j, point_y) in points_x.iter().enumerate() { let coords_x: [f64; 3] = point_x.coords(); @@ -709,13 +908,18 @@ fn get_helmholtz_matrix(points_x: &[bempp_octree::Point]) -> DynamicArray, id_tols: Vec, max_level: usize, max_leaf_points: usize, + benchmark_config: BenchmarkConfig, comm: &SimpleCommunicator, ) { + let configured_num_threads = env_usize("RSRS_EXAMPLE_NUM_THREADS") + .unwrap_or_else(num_cpus::get) + .max(1); for npts in npoints_vec { for &id_tol in id_tols.iter() { let points: Vec = sphere_surface(npts, comm); @@ -723,42 +927,83 @@ fn laplace_test( Octree::new(&points, max_level, max_leaf_points, comm); println!("Test: {npts} points, tol:{id_tol}"); let mut kernel_mat: DynamicArray = get_laplace_matrix(&points); - let operator = Operator::from(&kernel_mat); - let args = RsrsArgs::new( - 8, - 16, - 420, - NullMethod::Projection, - BlockExtractionMethod::LuLstSq, - BlockExtractionMethod::LuLstSq, - PivotMethod::Lu, - PivotMethod::Lu, - 1e-10, - id_tol, - 1e-10, - 1e-10, - 4, - 1, - true, - RankPicking::Min, - ); - - let options = RsrsOptions::::new(Some(args)); - let mut rsrs_algo = Rsrs::new(&tree, options, operator.domain().dimension()); - - let mut rsrs_factors = rsrs_algo.run(operator.r()); - let mul_errors = rsrs_error_estimator(&kernel_mat, &mut rsrs_factors, 10); - - println!("Multiplication errors: {mul_errors:?}\n"); - - assert!( - mul_errors.0 <= id_tol - && mul_errors.1 <= id_tol - && mul_errors.2 <= id_tol - && mul_errors.3 <= id_tol - ); - - get_boxes_errors(&mut kernel_mat, &mut rsrs_factors, id_tol); + let mut seed_run_max_errors = Vec::new(); + + for run_index in 0..benchmark_config.seed_runs { + let algo_seed = benchmark_config.algo_seed_for_run(run_index); + let probe_seed = benchmark_config.probe_seed_for_run(run_index); + set_probe_seed(probe_seed); + + println!( + "Seed run {}/{}: algo_seed={}, probe_seed={}", + run_index + 1, + benchmark_config.seed_runs, + algo_seed, + probe_seed + ); + + let args = RsrsArgs::new( + 8, + 16, + 0, + 420, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::SRRQR(1.01), + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::Lu(1e-10), + PivotMethod::Lu(0.0), + 1e-10, + id_tol, + 1e-10, + 1e-10, + 4, + 1, + bempp_rsrs::rsrs::args::Symmetry::Symmetric, + RankPicking::Min, + FactType::Joint, + false, + configured_num_threads, + false, + true, + ); + let options = RsrsOptions::::new(Some(args)); + let (mut rsrs_factors, mul_errors) = { + let operator = Operator::from(&kernel_mat); + let mut rsrs_algo = Rsrs::new(&tree, options, operator.domain().dimension()); + let mut rsrs_factors = rsrs_algo.run_with_seed(operator.r(), algo_seed); + let mul_errors = rsrs_error_estimator(&kernel_mat, &mut rsrs_factors, 10); + (rsrs_factors, mul_errors) + }; + + println!("Multiplication errors: {mul_errors:?}\n"); + seed_run_max_errors.push(max_mul_error(mul_errors)); + + if benchmark_config.include_box_errors { + get_boxes_errors(&mut kernel_mat, &mut rsrs_factors, id_tol); + } + } + + if benchmark_config.seed_runs > 1 { + let mut median_errors = seed_run_max_errors.clone(); + let median_max_error = median(&mut median_errors); + let min_max_error = seed_run_max_errors + .iter() + .copied() + .min_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + let max_max_error = seed_run_max_errors + .iter() + .copied() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + + println!( + "Seed sweep summary: median max multiplication error = {}, min = {}, max = {}\n", + median_max_error, min_max_error, max_max_error + ); + } } } } @@ -768,8 +1013,12 @@ fn helmholtz_test( id_tols: Vec, max_level: usize, max_leaf_points: usize, + benchmark_config: BenchmarkConfig, comm: &SimpleCommunicator, ) { + let configured_num_threads = env_usize("RSRS_EXAMPLE_NUM_THREADS") + .unwrap_or_else(num_cpus::get) + .max(1); for npts in npoints_vec { for &id_tol in id_tols.iter() { let points: Vec = sphere_surface(npts, comm); @@ -777,22 +1026,87 @@ fn helmholtz_test( Octree::new(&points, max_level, max_leaf_points, comm); println!("Test: {npts} points, tol:{id_tol}"); let mut kernel_mat: DynamicArray, 2> = get_helmholtz_matrix(&points); - let operator = Operator::from(&kernel_mat); - let options = RsrsOptions::new(None); - let mut rsrs_algo = Rsrs::new(&tree, options, operator.domain().dimension()); - let mut rsrs_factors = rsrs_algo.run(operator.r()); - let mul_errors = rsrs_error_estimator(&kernel_mat, &mut rsrs_factors, 10); - - println!("Multiplication errors: {mul_errors:?}\n"); - - assert!( - mul_errors.0 <= id_tol - && mul_errors.1 <= id_tol - && mul_errors.2 <= id_tol - && mul_errors.3 <= id_tol - ); - - get_boxes_errors(&mut kernel_mat, &mut rsrs_factors, id_tol); + let mut seed_run_max_errors = Vec::new(); + + for run_index in 0..benchmark_config.seed_runs { + let algo_seed = benchmark_config.algo_seed_for_run(run_index); + let probe_seed = benchmark_config.probe_seed_for_run(run_index); + set_probe_seed(probe_seed); + + println!( + "Seed run {}/{}: algo_seed={}, probe_seed={}", + run_index + 1, + benchmark_config.seed_runs, + algo_seed, + probe_seed + ); + + let options = if env::var("RSRS_EXAMPLE_NUM_THREADS").is_ok() { + let args = RsrsArgs::new( + 8, + 16, + 0, + 420, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::Lu(1e-10), + PivotMethod::Lu(0.0), + 1e-10, + 1e-2, + 1e-10, + 1e-10, + 4, + 1, + bempp_rsrs::rsrs::args::Symmetry::NoSymm, + RankPicking::Min, + FactType::Joint, + false, + configured_num_threads, + false, + true, + ); + RsrsOptions::new(Some(args)) + } else { + RsrsOptions::new(None) + }; + let (mut rsrs_factors, mul_errors) = { + let operator = Operator::from(&kernel_mat); + let mut rsrs_algo = Rsrs::new(&tree, options, operator.domain().dimension()); + let mut rsrs_factors = rsrs_algo.run_with_seed(operator.r(), algo_seed); + let mul_errors = rsrs_error_estimator(&kernel_mat, &mut rsrs_factors, 10); + (rsrs_factors, mul_errors) + }; + + println!("Multiplication errors: {mul_errors:?}\n"); + seed_run_max_errors.push(max_mul_error(mul_errors)); + + if benchmark_config.include_box_errors { + get_boxes_errors(&mut kernel_mat, &mut rsrs_factors, id_tol); + } + } + + if benchmark_config.seed_runs > 1 { + let mut median_errors = seed_run_max_errors.clone(); + let median_max_error = median(&mut median_errors); + let min_max_error = seed_run_max_errors + .iter() + .copied() + .min_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + let max_max_error = seed_run_max_errors + .iter() + .copied() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + + println!( + "Seed sweep summary: median max multiplication error = {}, min = {}, max = {}\n", + median_max_error, min_max_error, max_max_error + ); + } } } } @@ -801,25 +1115,60 @@ pub fn main() { let universe: mpi::environment::Universe = mpi::initialize().unwrap(); let comm: SimpleCommunicator = universe.world(); //Error testing - let max_level: usize = 16; - let max_leaf_points: usize = 30; - - let id_tols = [4.0]; - let npoints_vec = [1000]; - - laplace_test( - npoints_vec.to_vec(), - id_tols.to_vec(), - max_level, - max_leaf_points, - &comm, + let max_level = env::var("RSRS_EXAMPLE_MAX_LEVEL") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(2); + let max_leaf_points = env::var("RSRS_EXAMPLE_MAX_LEAF_POINTS") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(30); + let algo_seed = env_u64("RSRS_ALGO_SEED").unwrap_or(DEFAULT_ALGO_SEED); + let probe_seed = + env_u64("RSRS_PROBE_SEED").unwrap_or_else(|| mix_seed(algo_seed ^ PROBE_SEED_XOR_TAG)); + let seed_runs = env_usize("RSRS_EXAMPLE_NUM_SEEDS").unwrap_or(1).max(1); + let include_box_errors = env_flag("RSRS_EXAMPLE_INCLUDE_BOX_ERRORS").unwrap_or(seed_runs == 1); + let benchmark_config = BenchmarkConfig { + algo_seed, + probe_seed, + seed_runs, + include_box_errors, + }; + set_probe_seed(probe_seed); + println!( + "Benchmark seeds: algo_seed={}, probe_seed={}, seed_runs={}, include_box_errors={}", + algo_seed, probe_seed, seed_runs, include_box_errors ); - helmholtz_test( - npoints_vec.to_vec(), - id_tols.to_vec(), - max_level, - max_leaf_points, - &comm, - ); + let id_tols = env::var("RSRS_EXAMPLE_ID_TOL") + .ok() + .and_then(|value| value.parse::().ok()) + .map(|value| vec![value]) + .unwrap_or_else(|| vec![1e-2]); + let npoints_vec = env::var("RSRS_EXAMPLE_NPOINTS") + .ok() + .and_then(|value| value.parse::().ok()) + .map(|value| vec![value]) + .unwrap_or_else(|| vec![5000]); + + let benchmark_case = env::var("RSRS_EXAMPLE_CASE").unwrap_or_else(|_| "helmholtz".to_string()); + + match benchmark_case.as_str() { + "laplace" => laplace_test( + npoints_vec.clone(), + id_tols.clone(), + max_level, + max_leaf_points, + benchmark_config, + &comm, + ), + _ => helmholtz_test( + npoints_vec, + id_tols, + max_level, + max_leaf_points, + benchmark_config, + &comm, + ), + } } diff --git a/src/.DS_Store b/src/.DS_Store deleted file mode 100644 index 90d405d..0000000 Binary files a/src/.DS_Store and /dev/null differ diff --git a/src/rsrs.rs b/src/rsrs.rs index f1a7a8f..b953638 100644 --- a/src/rsrs.rs +++ b/src/rsrs.rs @@ -1,9 +1,11 @@ //! An implementation of the RSRS algorithm. +pub mod args; pub mod box_skeletonisation; pub mod rsrs_cycle; pub mod rsrs_factors; pub mod sketch; +pub mod statistics; pub mod tree_indexing; #[cfg(test)] diff --git a/src/rsrs/args.rs b/src/rsrs/args.rs new file mode 100644 index 0000000..ee3a732 --- /dev/null +++ b/src/rsrs/args.rs @@ -0,0 +1,378 @@ +use crate::{ + rsrs::{ + rsrs_factors::{ + null_and_extract::{ExtractOptions, IdOptions, PivotMethod}, + rsrs_operator::FactType, + }, + sketch::Shift, + }, + utils::linear_algebra::{BlockExtractionMethod, NullMethod}, +}; +use rlst::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fmt::Write; + +type Real = ::Real; + +fn default_load_samples() -> bool { + true +} + +#[derive(Debug, Clone, Deserialize)] +pub enum RankPicking { + Min, + DoubleMin, + Max, + Avg, + Mid, + Tol, +} + +fn default_fixed_rank_sampling_mode() -> FixedRankSamplingMode { + FixedRankSamplingMode::PerLevel +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] +pub enum FixedRankSamplingMode { + PerLevel, + Constant, +} + +#[derive(Debug, Clone)] +pub struct SketchingOptions { + pub oversampling: usize, + pub oversampling_diag_blocks: usize, + pub initial_num_samples: usize, + pub min_num_samples: usize, + pub fixed_rank_sampling_mode: FixedRankSamplingMode, + pub shift: Shift, + pub save_samples: bool, + pub load_samples: bool, + pub sample_storage_dir: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub enum Symmetry { + NoSymm, + Symmetric, + Hermitian, +} + +impl Symmetry { + /// Returns whether RSRS can avoid sampling a second sketch stream. + /// + /// `Symmetric` and `Hermitian` both sample only `y = AΩ`. Complex + /// symmetric matrices may still opt into split factor storage later, but + /// they do not need a stored `z` sketch. + pub fn symm_val(&self) -> bool { + match self { + Symmetry::NoSymm => false, + Symmetry::Symmetric => true, + Symmetry::Hermitian => true, + } + } + + /// Returns whether factor application can reuse a single symmetric storage + /// family. + /// + /// Real symmetric and Hermitian operators use the one-family storage path. + /// Complex symmetric operators still sample only `y`, but they build split + /// left/right factors from synthetic conjugated data. + pub fn factor_symm_val(&self) -> bool { + match self { + Symmetry::NoSymm => false, + Symmetry::Symmetric => std::mem::size_of::() == std::mem::size_of::(), + Symmetry::Hermitian => true, + } + } + + /// Returns whether this is the special complex symmetric case: one-sketch + /// sampling, but split factor construction from conjugated `y` data. + pub fn complex_symmetric_val(&self) -> bool { + matches!(self, Symmetry::Symmetric) && !self.factor_symm_val::() + } +} + +#[derive(Debug, Clone)] +pub struct RsrsOptions { + pub fact_type: FactType, + pub sketching: SketchingOptions, + pub id_options: IdOptions, + pub lu_options: ExtractOptions, + pub extract_db_options: ExtractOptions, + pub min_rank: usize, + pub symmetry: Symmetry, + pub min_level: usize, + pub rank_picking: RankPicking, + pub num_threads: usize, + pub flush_factors: bool, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(bound = "Real: Deserialize<'de>")] +pub struct RsrsArgs { + oversampling: usize, + oversampling_diag_blocks: usize, + min_num_samples: usize, + initial_num_samples: usize, + #[serde(default = "default_fixed_rank_sampling_mode")] + fixed_rank_sampling_mode: FixedRankSamplingMode, + shift: Shift, + null_method: NullMethod, + qr_method: RankRevealingQrType>, + near_block_extraction_method: BlockExtractionMethod, + diag_block_extraction_method: BlockExtractionMethod, + lu_pivot_method: PivotMethod, + diag_pivot_method: PivotMethod, + tol_null: Real, + tol_id: Real, + tol_ext_near: Real, + tol_diag_ext: Real, + min_rank: usize, + min_level: usize, + symmetry: Symmetry, + rank_picking: RankPicking, + fact_type: FactType, + save_samples: bool, + #[serde(default = "default_load_samples")] + load_samples: bool, + #[serde(default)] + sample_storage_dir: Option, + num_threads: usize, + flush_factors: bool, + store_far: bool, +} + +impl RsrsArgs +where + Item: RlstScalar, +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + oversampling: usize, + oversampling_diag_blocks: usize, + min_num_samples: usize, + initial_num_samples: usize, + shift: Shift, + null_method: NullMethod, + qr_method: RankRevealingQrType>, + near_block_extraction_method: BlockExtractionMethod, + diag_block_extraction_method: BlockExtractionMethod, + lu_pivot_method: PivotMethod, + diag_pivot_method: PivotMethod, + tol_null: Real, + tol_id: Real, + tol_ext_near: Real, + tol_diag_ext: Real, + min_rank: usize, + min_level: usize, + symmetry: Symmetry, + rank_picking: RankPicking, + fact_type: FactType, + save_samples: bool, + num_threads: usize, + flush_factors: bool, + store_far: bool, + ) -> Self { + Self { + oversampling, + oversampling_diag_blocks, + min_num_samples, + initial_num_samples, + fixed_rank_sampling_mode: default_fixed_rank_sampling_mode(), + shift, + null_method, + qr_method, + near_block_extraction_method, + diag_block_extraction_method, + lu_pivot_method, + diag_pivot_method, + tol_null, + tol_id, + tol_ext_near, + tol_diag_ext, + min_rank, + min_level, + symmetry, + rank_picking, + fact_type, + save_samples, + load_samples: true, + sample_storage_dir: None, + num_threads, + flush_factors, + store_far, + } + } +} + +impl RsrsOptions { + pub fn new(args: Option>) -> Self { + let args = match args { + Some(input) => input, + None => RsrsArgs::new( + 8, + 16, + 0, + 420, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::LuHybrid(0.0), + PivotMethod::LuHybrid(0.0), + Item::real(1e-10), + Item::real(1e-2), + Item::real(1e-10), + Item::real(1e-10), + 4, + 1, + Symmetry::NoSymm, + RankPicking::Min, + FactType::Joint, + false, + num_cpus::get(), + false, + true, + ), + }; + + let min_rank = if args.tol_id > num::One::one() { + let k = num::ToPrimitive::to_usize(&args.tol_id).unwrap(); + println!("For tolerances > 1, ID will use this as a fixed rank instead. This fixed rank is: {k}"); + + if k <= args.min_rank { + k + } else { + args.min_rank + } + } else { + args.min_rank + }; + + Self { + sketching: SketchingOptions { + oversampling: args.oversampling, + oversampling_diag_blocks: args.oversampling_diag_blocks, + min_num_samples: args.min_num_samples, + initial_num_samples: args.initial_num_samples, + fixed_rank_sampling_mode: args.fixed_rank_sampling_mode, + shift: args.shift, + save_samples: args.save_samples, + load_samples: args.load_samples, + sample_storage_dir: args.sample_storage_dir, + }, + id_options: IdOptions { + null_method: args.null_method, + qr_method: args.qr_method, + tol_null: args.tol_null, + tol_id: args.tol_id, + store_far: args.store_far, + }, + lu_options: ExtractOptions { + block_extraction_method: args.near_block_extraction_method, + pivot_method: args.lu_pivot_method, + tol_lstsq: args.tol_ext_near, + }, + extract_db_options: ExtractOptions { + block_extraction_method: args.diag_block_extraction_method, + pivot_method: args.diag_pivot_method, + tol_lstsq: args.tol_diag_ext, + }, + fact_type: args.fact_type, + min_rank, + min_level: args.min_level, + symmetry: args.symmetry, + rank_picking: args.rank_picking, + num_threads: args.num_threads, + flush_factors: args.flush_factors, + } + } + + pub fn to_identifier(&self) -> String { + let mut id = String::from("rsrs"); + + write!( + &mut id, + "_null_{:?}_toln_{:e}", + self.id_options.null_method, self.id_options.tol_null, + ) + .unwrap(); + + match self.sketching.shift { + Shift::True(alpha) => write!( + &mut id, + "_os_{os}_osdiag_{osdiag}_initsam_{init}_shiftd_{alpha:.e}", + os = self.sketching.oversampling, + osdiag = self.sketching.oversampling_diag_blocks, + init = self.sketching.initial_num_samples, + // Keep the sampling-budget mode in the identifier so fixed-rank + // runs with different activation policies do not collide. + // The non-shifted branch uses the same token below. + alpha = alpha + ) + .unwrap(), + Shift::False => write!( + &mut id, + "_os_{os}_osdiag_{osdiag}_initsam_{init}", + os = self.sketching.oversampling, + osdiag = self.sketching.oversampling_diag_blocks, + init = self.sketching.initial_num_samples + ) + .unwrap(), + }; + + write!( + &mut id, + "_fsamp_{:?}", + self.sketching.fixed_rank_sampling_mode + ) + .unwrap(); + + match self.id_options.qr_method{ + RankRevealingQrType::RRQR => write!( + &mut id, + "_mrnk_{}_mlvl_{}_{:?}_rpick_{:?}_next_{:?}_tolextn_{:e}_db_ext_{:?}_tol_lstsq_{:e}_rrqr", + self.min_rank, + self.min_level, + self.symmetry, + self.rank_picking, + self.lu_options.block_extraction_method, + self.lu_options.tol_lstsq, + self.extract_db_options.block_extraction_method, + self.extract_db_options.tol_lstsq + ) + .unwrap(), + RankRevealingQrType::SRRQR(f) => write!( + &mut id, + "_mrnk_{}_mlvl_{}_{:?}_rpick_{:?}_next_{:?}_tolextn_{:e}_db_ext_{:?}_tol_lstsq_{:e}_srrqr_{:e}", + self.min_rank, + self.min_level, + self.symmetry, + self.rank_picking, + self.lu_options.block_extraction_method, + self.lu_options.tol_lstsq, + self.extract_db_options.block_extraction_method, + self.extract_db_options.tol_lstsq, + f + ) + .unwrap(), + }; + + id + } +} + +impl RsrsArgs +where + Item: RlstScalar, +{ + pub fn with_fixed_rank_sampling_mode( + mut self, + fixed_rank_sampling_mode: FixedRankSamplingMode, + ) -> Self { + self.fixed_rank_sampling_mode = fixed_rank_sampling_mode; + self + } +} diff --git a/src/rsrs/box_skeletonisation.rs b/src/rsrs/box_skeletonisation.rs index cac6830..b47fbbe 100644 --- a/src/rsrs/box_skeletonisation.rs +++ b/src/rsrs/box_skeletonisation.rs @@ -1,16 +1,18 @@ -use super::{ - rsrs_cycle::{BoxType, RsrsOptions}, - rsrs_factors::{IdTimes, LuTimes, Times}, -}; +use crate::rsrs::args::RsrsOptions; +use crate::rsrs::rsrs_factors::commutative_factors::BoxType; +use crate::rsrs::rsrs_factors::commutative_factors::IdFactor; +use crate::rsrs::rsrs_factors::commutative_factors::LuFactor; +use crate::rsrs::rsrs_factors::null_and_extract::ExtractionScratch; use crate::rsrs::sketch::SamplingSpace; -use crate::rsrs::{ - rsrs_factors::{IdFactor, LuFactor}, - sketch::SketchData, -}; +use crate::rsrs::sketch::SketchData; +use crate::rsrs::statistics::IdTimes; +use crate::rsrs::statistics::Times; use rand_distr::{Distribution, Standard, StandardNormal}; -use rlst::dense::{linalg::lu::MatrixLu, tools::RandScalar}; +use rlst::dense::{ + linalg::{interpolative_decomposition::MatrixIdNoSkel, lu::MatrixLu}, + tools::RandScalar, +}; pub use rlst::prelude::*; -use serde::Serialize; pub struct Tols { pub id: ::Real, pub null: ::Real, @@ -24,45 +26,18 @@ pub struct LowRankResult { pub id_times: Times, } +pub struct FullRankResult { + pub len_near_field_inds: usize, + pub len_target_inds: usize, + pub id_times: Times, +} + #[allow(clippy::large_enum_variant)] pub enum Rank { Low(LowRankResult), - Full(Times), -} - -#[derive(Debug, Serialize, Clone)] -pub struct UpdateTimes { - pub id: u128, - pub lu: u128, -} - -macro_rules! impl_times_operations { - ($struct_name:ident, $trait_name:ident, $arg_1:ident, $arg_2:ident) => { - pub trait $trait_name { - fn new() -> Self; - fn sum(&mut self, $arg_1: u128, $arg_2: u128); - } - - impl $trait_name for $struct_name { - fn new() -> Self { - Self { - $arg_1: 0_u128, - $arg_2: 0_u128, - } - } - - fn sum(&mut self, $arg_1: u128, $arg_2: u128) { - self.$arg_1 += $arg_1; - self.$arg_2 += $arg_2; - } - } - }; + Full(FullRankResult), } -impl_times_operations!(IdTimes, IdTimesOperations, nullification, id); -impl_times_operations!(LuTimes, LuTimesOperations, extraction, lu); -impl_times_operations!(UpdateTimes, UpdateTimesOperations, id, lu); - type Real = ::Real; pub trait Skel> where @@ -72,6 +47,7 @@ where #[allow(clippy::too_many_arguments)] fn id_step( &mut self, + scratch: &mut ExtractionScratch, box_type: &BoxType>, target_inds: &[usize], near_field_inds: &[usize], @@ -80,20 +56,24 @@ where subs_sample_dim: usize, options: &RsrsOptions, ) -> Rank; + #[allow(clippy::too_many_arguments)] fn lu_step( &self, + scratch: &mut ExtractionScratch, y_data: &SketchData, z_data: &SketchData, - ind_r: &mut [usize], - near_field_inds: &mut [usize], + ind_r: &[usize], + near_field_inds: &[usize], + inactive_inds: &[usize], subs_sample_dim: usize, options: &RsrsOptions, - ) -> (LuFactor, Times); + ) -> Option<(LuFactor, Times)>; } impl< T: RlstScalar + MatrixId + + MatrixIdNoSkel + MatrixNull + MatrixInverse + MatrixPseudoInverse @@ -112,6 +92,7 @@ where type Item = T; fn id_step( &mut self, + scratch: &mut ExtractionScratch, box_type: &BoxType>, target_inds: &[usize], near_field_inds: &[usize], @@ -120,63 +101,108 @@ where subs_sample_dim: usize, options: &RsrsOptions, ) -> Rank { - if target_inds.len() <= options.min_rank { + let fixed_rank = if options.id_options.tol_id > num::One::one() { + Some(num::ToPrimitive::to_usize(&options.id_options.tol_id).unwrap()) + } else { + None + }; + if let Some(rank) = fixed_rank { + assert!( + target_inds.len() > rank, + "Fixed-rank ID launched with non-launchable box: |I_B|_act={} <= k={rank}", + target_inds.len(), + ); + } + let skip_id = target_inds.len() <= options.min_rank; + if skip_id { let id_times = IdTimes { nullification: 0_u128, id: 0_u128, }; let times = Times::Id(id_times); - return Rank::Full(times); + let full_rank_result: FullRankResult = FullRankResult { + len_near_field_inds: near_field_inds.to_vec().len(), + id_times: times, + len_target_inds: target_inds.len(), + }; + return Rank::Full(full_rank_result); } let mut local_target_inds = target_inds.to_vec(); let mut local_near_field_inds = near_field_inds.to_vec(); let (id_factor, id_times) = IdFactor::new( + scratch, &mut local_target_inds, &mut local_near_field_inds, y_data, z_data, subs_sample_dim, + options.id_options.tol_id > num::One::one(), box_type, - options, + &options.id_options, + &options.symmetry, ); match id_factor { - Some(low_rank_factor) => { - if !low_rank_factor.ind_r.is_empty() { - let low_rank_result: LowRankResult = LowRankResult { - id_factor: low_rank_factor, - near_field_inds: local_near_field_inds, - id_times, - target_inds: local_target_inds, - }; - Rank::Low(low_rank_result) - } else { - Rank::Full(id_times) - } + Some(low_rank_factor) if !low_rank_factor.ind_r.is_empty() => { + let low_rank_result: LowRankResult = LowRankResult { + id_factor: low_rank_factor, + near_field_inds: local_near_field_inds, + id_times, + target_inds: local_target_inds, + }; + Rank::Low(low_rank_result) + } + None => { + let full_rank_result: FullRankResult = FullRankResult { + len_near_field_inds: near_field_inds.to_vec().len(), + id_times, + len_target_inds: target_inds.len(), + }; + Rank::Full(full_rank_result) + } + Some(_) => { + let full_rank_result: FullRankResult = FullRankResult { + len_near_field_inds: near_field_inds.to_vec().len(), + id_times, + len_target_inds: target_inds.len(), + }; + Rank::Full(full_rank_result) } - None => Rank::Full(id_times), } } fn lu_step( &self, + scratch: &mut ExtractionScratch, y_data: &SketchData, z_data: &SketchData, - ind_r: &mut [usize], - near_field_inds: &mut [usize], + ind_r: &[usize], + near_field_inds: &[usize], + inactive_inds: &[usize], subs_sample_dim: usize, options: &RsrsOptions, - ) -> (LuFactor, Times) { - let (lu_factors, lu_times) = LuFactor::new( - ind_r, - near_field_inds, - y_data, - z_data, - subs_sample_dim, - options, - ); - (lu_factors.unwrap(), lu_times) + ) -> Option<(LuFactor, Times)> { + if ind_r.is_empty() { + return None; + } + if near_field_inds.len() > ind_r.len() { + let (lu_factors, lu_times) = LuFactor::new( + scratch, + ind_r, + near_field_inds, + inactive_inds, + y_data, + z_data, + subs_sample_dim, + options.id_options.tol_id > num::One::one(), + &options.lu_options, + &options.symmetry, + ); + Some((lu_factors.unwrap(), lu_times)) + } else { + None + } } } diff --git a/src/rsrs/rsrs_cycle.rs b/src/rsrs/rsrs_cycle.rs index db5d9e8..1381e4f 100644 --- a/src/rsrs/rsrs_cycle.rs +++ b/src/rsrs/rsrs_cycle.rs @@ -1,79 +1,54 @@ use super::{ - box_skeletonisation::{ - IdTimesOperations, LuTimesOperations, Rank, Skel, UpdateTimes, UpdateTimesOperations, - }, - rsrs_factors::{DiagBoxFactor, LuTimes, PivotMethod, RsrsFactors, RsrsFactorsImpl}, - sketch::SketchData, + box_skeletonisation::{Rank, Skel}, + sketch::{apply_shift_delta, mix_seed, shift_alpha, SketchData}, tree_indexing::{TreeData, TreeIndexing}, }; -use crate::rsrs::{ - rsrs_factors::{LocalFromSpaces, RsrsOperator}, - sketch::SamplingSpace, -}; -use crate::{ - rsrs::rsrs_factors::{IdTimes, Times}, - utils::least_squares_and_null::NullMethod, -}; use crate::{ rsrs::{ - rsrs_factors::{CommutativeFactors, CommutativeFactorsOperations, Factor}, - sketch::UpdateType, + args::{FixedRankSamplingMode, RankPicking, RsrsOptions, Symmetry}, + rsrs_factors::{ + commutative_factors::{ + BoxType, CommutativeFactors, CommutativeFactorsOperations, DiagBoxFactor, + DiagExtractionScratch, Factor, LuFactor, MultiLevelIdFactors, RsrsFactors, + }, + null_and_extract::ExtractionScratch, + rsrs_operator::{FactType, LocalFromSpaces, RsrsFactorsImpl, RsrsOperator}, + }, + sketch::{SamplingSpace, UpdateType}, + statistics::{ + FactorMemoryStats, IdTimes, IdTimesOperations, LevelEffort, LimitingFactors, + LimitingLevel, LuTimes, LuTimesOperations, MemorySnapshot, Stats, Times, UpdateTimes, + UpdateTimesOperations, + }, + }, + utils::{ + io::{resolve_sampling_dir, IOData}, + memory::{format_bytes, process_memory_usage}, }, - utils::least_squares_and_null::BlockExtractionMethod, }; use bempp_octree::{MortonKey, Octree}; use mpi::traits::CommunicatorCollectives; +use rand::{rngs::OsRng, RngCore}; use rand_distr::{Distribution, Standard, StandardNormal}; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -use rlst::dense::{linalg::lu::MatrixLu, tools::RandScalar}; +use rayon::{ + iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, + ThreadPoolBuilder, +}; +use rlst::dense::{ + linalg::{interpolative_decomposition::MatrixIdNoSkel, lu::MatrixLu}, + tools::RandScalar, +}; pub use rlst::prelude::*; use rustc_hash::FxHashSet; -use serde::Deserialize; use std::{ collections::HashMap, - fmt::Write, + path::{Path, PathBuf}, time::{Duration, Instant}, -}; // Ensure IndexableSpace is in scope -type Inds = Vec>; - -#[derive(Debug)] -pub struct LimitingLevel { - pub level: usize, - pub num_boxes: usize, - pub active_points: usize, - pub elapsed_time: u128, -} +}; -#[derive(Debug)] -pub struct LimitingFactors { - pub min_samples: usize, - pub max_level: usize, - pub limiting_level: LimitingLevel, -} +type Inds = Vec>; -#[derive(Debug)] -pub struct Stats { - pub sampling_time: Vec, - pub sampling_extraction_time: u128, - pub id_times: Vec, - pub tot_id_time: u128, - pub lu_times: Vec, - pub tot_lu_time: u128, - pub update_times: Vec, - pub total_elapsed_time: u128, - pub total_elapsed_time_wo_sampling: u128, - pub dim: usize, - pub extraction_time: u128, - pub residual_size: usize, - pub ranks: Vec, - pub box_sizes: Vec, - pub near_field_sizes: Vec, - pub dec_boxes_per_level: Vec, - pub index_calculation: u128, - pub sorting_near_field: u128, - pub residual_calculation: u128, - pub limiting_factors: LimitingFactors, -} +type Real = ::Real; pub struct Rsrs { level_indexing: TreeData, @@ -88,233 +63,541 @@ pub struct Rsrs { pub active_samples: usize, pub stats: Stats, options: RsrsOptions, + anticipated_fixed_rank_samples: Option, } -#[derive(Debug, Clone)] -pub enum BoxType { - Merged(usize), - Full(Real), +#[derive(Debug, Clone, PartialEq, Eq)] +enum FixedRankSampleBudget { + PerLevel { + total_samples: usize, + active_samples_by_level: HashMap, + box_samples_by_level: HashMap>, + }, + Constant { + samples: usize, + }, } -#[derive(Debug, Clone, Deserialize)] -pub enum RankPicking { - Min, - DoubleMin, - Max, - Avg, - Mid, - Tol, +impl FixedRankSampleBudget { + fn total_samples(&self) -> usize { + match self { + Self::PerLevel { total_samples, .. } => *total_samples, + Self::Constant { samples } => *samples, + } + } + + fn with_global_min_samples(self, min_num_samples: usize) -> Self { + if min_num_samples > self.total_samples() { + Self::Constant { + samples: min_num_samples, + } + } else { + self + } + } + + fn per_level_box_samples(&self, level: usize, key: &MortonKey) -> Option { + match self { + Self::PerLevel { + box_samples_by_level, + .. + } => box_samples_by_level + .get(&level) + .and_then(|level_samples| level_samples.get(key)) + .copied(), + Self::Constant { .. } => None, + } + } } -#[derive(Debug, Clone)] -pub struct IdOptions { - pub null_method: NullMethod, - pub tol_null: Real, - pub tol_id: Real, +fn oversample( + samples: usize, + oversampling: usize, + id_tol: ::Real, + ms: usize, +) -> usize { + if id_tol < num::One::one() { + (samples + (samples / 100) * oversampling).max(ms) + } else { + (samples + num::ToPrimitive::to_usize(&id_tol).unwrap()).max(ms) + } } -#[derive(Debug, Clone)] -pub struct ExtractOptions { - pub block_extraction_method: BlockExtractionMethod, - pub pivot_method: PivotMethod, - pub tol_lstsq: Real, +fn per_level_boxwise_fixed_rank_samples( + active_box_size: usize, + active_near_with_self_size: usize, + rank: usize, + p_param: usize, + min_num_samples: usize, +) -> usize { + (active_near_with_self_size + active_box_size.min(rank) + p_param).max(min_num_samples) } -#[derive(Debug, Clone)] -pub struct SketchingOptions { - pub oversampling: usize, - pub oversampling_diag_blocks: usize, - pub initial_num_samples: usize, +fn root_fixed_rank_samples( + root_sketch_size: usize, + p_param: usize, + smax: usize, + min_num_samples: usize, +) -> usize { + root_sketch_size + .saturating_mul(2) + .saturating_add(2 * p_param) + .min(smax) + .max(min_num_samples) } -#[derive(Debug, Clone)] -pub struct RsrsOptions { - pub sketching: SketchingOptions, - pub id_options: IdOptions, - pub lu_options: ExtractOptions, - pub extract_db_options: ExtractOptions, - pub min_rank: usize, - pub hermitian: bool, - pub min_level: usize, - pub rank_picking: RankPicking, +fn local_sample_count( + per_box_samples: Option, + active_near_with_self_size: usize, + active_samples: usize, + fixed_rank: bool, + fixed_rank_sampling_mode: FixedRankSamplingMode, + options: &RsrsOptions, +) -> usize { + if !fixed_rank { + return active_samples; + } + + let rank = num::ToPrimitive::to_usize(&options.id_options.tol_id).unwrap(); + match fixed_rank_sampling_mode { + FixedRankSamplingMode::Constant => active_samples, + FixedRankSamplingMode::PerLevel => { + let local_samples = per_box_samples.unwrap_or(active_samples); + let min_safe_samples = active_near_with_self_size.saturating_add(2); + let local_samples = if local_samples < min_safe_samples { + active_samples + } else { + local_samples + }; + assert!( + local_samples <= active_samples, + "PerLevel local sample request ({local_samples}) exceeds active samples ({active_samples}) for rank={}, p={}", + rank, + options.sketching.oversampling, + ); + local_samples + } + } } -#[derive(Debug, Clone, Deserialize)] -#[serde(bound = "Real: Deserialize<'de>")] -pub struct RsrsArgs { - oversampling: usize, - oversampling_diag_blocks: usize, - initial_num_samples: usize, - null_method: NullMethod, - near_block_extraction_method: BlockExtractionMethod, - diag_block_extraction_method: BlockExtractionMethod, - lu_pivot_method: PivotMethod, - diag_pivot_method: PivotMethod, - tol_null: Real, - tol_id: Real, - tol_ext_near: Real, - tol_diag_ext: Real, - min_rank: usize, - min_level: usize, - hermitian: bool, - rank_picking: RankPicking, +fn assert_post_nullification_headroom( + active_box_size: usize, + active_near_with_self_size: usize, + subs_sample_dim: usize, + options: &RsrsOptions, +) { + let rank = num::ToPrimitive::to_usize(&options.id_options.tol_id).unwrap(); + let p = options.sketching.oversampling; + let required = active_near_with_self_size + active_box_size.min(rank) + p; + assert!( + subs_sample_dim >= required, + "Insufficient fixed-rank samples for post-nullification sketch: samples={}, required={}, active_box_size={}, active_near_with_self_size={}, rank={}, p={}", + subs_sample_dim, + required, + active_box_size, + active_near_with_self_size, + rank, + p, + ); } -impl RsrsArgs -where - Item: RlstScalar, -{ - #[allow(clippy::too_many_arguments)] - pub fn new( - oversampling: usize, - oversampling_diag_blocks: usize, - initial_num_samples: usize, - null_method: NullMethod, - near_block_extraction_method: BlockExtractionMethod, - diag_block_extraction_method: BlockExtractionMethod, - lu_pivot_method: PivotMethod, - diag_pivot_method: PivotMethod, - tol_null: Real, - tol_id: Real, - tol_ext_near: Real, - tol_diag_ext: Real, - min_rank: usize, - min_level: usize, - hermitian: bool, - rank_picking: RankPicking, - ) -> Self { - Self { - oversampling, - oversampling_diag_blocks, - initial_num_samples, - null_method, - near_block_extraction_method, - diag_block_extraction_method, - lu_pivot_method, - diag_pivot_method, - tol_null, - tol_id, - tol_ext_near, - tol_diag_ext, - min_rank, - min_level, - hermitian, - rank_picking, - } +fn should_launch_box( + active_box_size: usize, + fixed_rank: bool, + options: &RsrsOptions, +) -> bool { + if active_box_size == 0 { + return false; + } + + if !fixed_rank { + return true; } + + let rank = num::ToPrimitive::to_usize(&options.id_options.tol_id).unwrap(); + active_box_size > rank } -impl RsrsOptions { - pub fn new(args: Option>) -> Self { - let args = match args { - Some(input) => input, - None => RsrsArgs::new( - 8, - 16, - 420, - NullMethod::Projection, - BlockExtractionMethod::LuLstSq, - BlockExtractionMethod::LuLstSq, - PivotMethod::Lu, - PivotMethod::Lu, - Item::real(1e-10), - Item::real(1e-2), - Item::real(1e-10), - Item::real(1e-10), - 4, - 1, - true, - RankPicking::Min, - ), - }; +fn default_run_seed() -> u64 { + mix_seed(OsRng.next_u64()) +} + +fn scoped_seed(base_seed: u64, level_it: usize, stage_tag: u64) -> u64 { + mix_seed(base_seed ^ (level_it as u64).wrapping_mul(0x9E3779B97F4A7C15) ^ stage_tag) +} + +fn round_up_to_multiple_of_five(value: usize) -> usize { + if value == 0 { + 0 + } else { + 5 * value.div_ceil(5) + } +} - let min_rank = if args.tol_id > num::One::one() { - let k = num::ToPrimitive::to_usize(&args.tol_id).unwrap(); - println!("For tolerances > 1, ID will use this as a fixed rank instead. This fixed rank is: {k}"); +struct FixedRankSampleEstimate { + total_samples: usize, + max_s_vec_k: usize, + max_s_vec_p: usize, + effective_p: usize, +} + +struct FixedRankRuntimeSampleEstimate { + total_samples: usize, + active_samples_by_level: HashMap, + box_samples_by_level: HashMap>, + root_sketch_size: usize, +} - if k <= args.min_rank { - k +fn format_level_samples(active_samples_by_level: &HashMap) -> String { + let mut level_samples = active_samples_by_level.iter().collect::>(); + level_samples.sort_by_key(|(level, _)| *level); + level_samples + .into_iter() + .map(|(level, samples)| format!("{level}->{samples}")) + .collect::>() + .join(", ") +} + +fn fixed_rank_skeleton_upper_size( + key_level: usize, + current_level: usize, + root_level: usize, + box_size: usize, + rank: usize, +) -> usize { + if key_level == current_level && current_level > root_level { + box_size.min(rank) + } else { + box_size + } +} + +fn fixed_rank_sample_estimate( + level_indexing: &TreeData, + rank: usize, + p_param: usize, + root_level: usize, +) -> FixedRankSampleEstimate { + let mut working_indexing = level_indexing.clone(); + let mut bs_map: HashMap = working_indexing + .boxes_map + .iter() + .map(|(key, indices)| (*key, indices.len())) + .collect(); + let mut k_map: HashMap = HashMap::new(); + let mut max_s_vec_k = 0usize; + let mut max_s_vec_p = 0usize; + let mut previous_level_keys = working_indexing + .level_keys + .iter() + .copied() + .collect::>(); + + loop { + let current_level = working_indexing.current_level; + let current_level_keys = working_indexing + .level_keys + .iter() + .copied() + .collect::>(); + let mut current_bs_map: HashMap = HashMap::new(); + + if current_level == working_indexing.max_level { + for key in ¤t_level_keys { + current_bs_map.insert(*key, *bs_map.get(key).unwrap_or(&0)); + } + } else { + for key in ¤t_level_keys { + let box_size = previous_level_keys + .iter() + .filter(|prev_key| prev_key.parent() == *key || **prev_key == *key) + .map(|prev_key| *k_map.get(prev_key).unwrap_or(&0)) + .sum(); + current_bs_map.insert(*key, box_size); + } + } + + for key in ¤t_level_keys { + let box_size = *current_bs_map.get(key).unwrap_or(&0); + let near_size = working_indexing + .get_box_near_field_keys(key, current_level) + .iter() + .map(|near_key| *current_bs_map.get(near_key).unwrap_or(&0)) + .sum::(); + let s_vec_size = box_size + near_size; + + if key.level() > root_level { + let effective_rank = rank.min(box_size); + let s_vec_k = s_vec_size + effective_rank; + let s_vec_p = s_vec_size + rank + p_param; + k_map.insert(*key, effective_rank); + max_s_vec_k = max_s_vec_k.max(s_vec_k); + max_s_vec_p = max_s_vec_p.max(s_vec_p); + } else if *key == MortonKey::root() { + let s_vec_k = s_vec_size; + let s_vec_p = s_vec_size + p_param; + max_s_vec_k = max_s_vec_k.max(s_vec_k); + max_s_vec_p = max_s_vec_p.max(s_vec_p); + k_map.insert(*key, box_size); } else { - args.min_rank + k_map.insert(*key, box_size); + } + } + + if current_level == 0 { + break; + } + + bs_map.retain(|key, _| key.level() < current_level); + for (key, value) in current_bs_map { + bs_map.insert(key, value); + } + previous_level_keys = current_level_keys; + working_indexing.update_level_keys(); + } + + let effective_p = round_up_to_multiple_of_five(max_s_vec_p.saturating_sub(max_s_vec_k)); + FixedRankSampleEstimate { + total_samples: max_s_vec_k + effective_p, + max_s_vec_k, + max_s_vec_p, + effective_p, + } +} + +fn fixed_rank_runtime_sample_estimate( + level_indexing: &TreeData, + rank: usize, + p_param: usize, + min_num_samples: usize, + root_level: usize, +) -> FixedRankRuntimeSampleEstimate { + let mut working_indexing = level_indexing.clone(); + let mut leaf_box_sizes: HashMap = working_indexing + .boxes_map + .iter() + .map(|(key, indices)| (*key, indices.len())) + .collect(); + let mut skeleton_upper_sizes: HashMap = HashMap::new(); + let mut active_samples_by_level = HashMap::new(); + let mut box_samples_by_level = HashMap::new(); + let mut max_active_samples = 0usize; + let mut root_sketch_size = 0usize; + let mut previous_level_keys = working_indexing + .level_keys + .iter() + .copied() + .collect::>(); + + loop { + let current_level = working_indexing.current_level; + let current_level_keys = working_indexing + .level_keys + .iter() + .copied() + .collect::>(); + let mut current_box_sizes: HashMap = HashMap::new(); + + if current_level == working_indexing.max_level { + for key in ¤t_level_keys { + current_box_sizes.insert(*key, *leaf_box_sizes.get(key).unwrap_or(&0)); } } else { - args.min_rank - }; + for key in ¤t_level_keys { + current_box_sizes.insert(*key, 0); + } + for prev_key in &previous_level_keys { + let carried_key = if prev_key.level() == current_level { + *prev_key + } else { + prev_key.parent() + }; + if let Some(total) = current_box_sizes.get_mut(&carried_key) { + *total += *skeleton_upper_sizes.get(prev_key).unwrap_or(&0); + } + } + } - Self { - sketching: SketchingOptions { - oversampling: args.oversampling, - oversampling_diag_blocks: args.oversampling_diag_blocks, - initial_num_samples: args.initial_num_samples, - }, - id_options: IdOptions { - null_method: args.null_method, - tol_null: args.tol_null, - tol_id: args.tol_id, - }, - lu_options: ExtractOptions { - block_extraction_method: args.near_block_extraction_method, - pivot_method: args.lu_pivot_method, - tol_lstsq: args.tol_ext_near, - }, - extract_db_options: ExtractOptions { - block_extraction_method: args.diag_block_extraction_method, - pivot_method: args.diag_pivot_method, - tol_lstsq: args.tol_diag_ext, - }, - min_rank, - min_level: args.min_level, - hermitian: args.hermitian, - rank_picking: args.rank_picking, + let mut level_span_upper = 0usize; + let mut current_box_samples = HashMap::new(); + for key in ¤t_level_keys { + let box_size = *current_box_sizes.get(key).unwrap_or(&0); + let near_sketch_size = level_indexing + .get_box_near_field_keys(key, current_level) + .iter() + .map(|near_key| *current_box_sizes.get(near_key).unwrap_or(&0)) + .sum::(); + let runtime_level_span = per_level_boxwise_fixed_rank_samples( + box_size, + box_size.saturating_add(near_sketch_size), + rank, + p_param, + min_num_samples, + ); + current_box_samples.insert(*key, runtime_level_span); + level_span_upper = level_span_upper.max(runtime_level_span); } + + if current_level > root_level { + let level_samples = level_span_upper.max(min_num_samples); + active_samples_by_level.insert(current_level, level_samples); + box_samples_by_level.insert(current_level, current_box_samples); + max_active_samples = max_active_samples.max(level_samples); + } else if current_level == 0 { + root_sketch_size = current_box_sizes + .get(&MortonKey::root()) + .copied() + .unwrap_or(0); + let level_samples = root_sketch_size + .saturating_add(p_param) + .max(min_num_samples); + active_samples_by_level.insert(current_level, level_samples); + max_active_samples = max_active_samples.max(level_samples); + } + + if current_level == 0 { + break; + } + + leaf_box_sizes.retain(|key, _| key.level() < current_level); + for (key, value) in ¤t_box_sizes { + leaf_box_sizes.insert(*key, *value); + } + skeleton_upper_sizes.clear(); + for (key, value) in ¤t_box_sizes { + let skeleton_upper = fixed_rank_skeleton_upper_size( + key.level(), + current_level, + root_level, + *value, + rank, + ); + skeleton_upper_sizes.insert(*key, skeleton_upper); + } + previous_level_keys = current_level_keys; + working_indexing.update_level_keys(); } - pub fn to_identifier(&self) -> String { - let mut id = String::from("rsrs"); - - write!( - &mut id, - "_null_{:?}_toln_{:e}", - self.id_options.null_method, self.id_options.tol_null, - ) - .unwrap(); - - write!( - &mut id, - "_os_{os}_osdiag_{osdiag}_initsam_{init}", - os = self.sketching.oversampling, - osdiag = self.sketching.oversampling_diag_blocks, - init = self.sketching.initial_num_samples - ) - .unwrap(); - - write!( - &mut id, - "_mrnk_{}_mlvl_{}_herm_{}_rpick_{:?}_next_{:?}_tolextn_{:e}_db_ext_{:?}_tol_lstsq_{:e}", - self.min_rank, - self.min_level, - self.hermitian, - self.rank_picking, - self.lu_options.block_extraction_method, - self.lu_options.tol_lstsq, - self.extract_db_options.block_extraction_method, - self.extract_db_options.tol_lstsq - ) - .unwrap(); - - id + FixedRankRuntimeSampleEstimate { + total_samples: max_active_samples, + active_samples_by_level, + box_samples_by_level, + root_sketch_size, } } -type Real = ::Real; +fn anticipated_fixed_rank_samples_per_level( + level_indexing: &TreeData, + rank: usize, + p_param: usize, + min_num_samples: usize, + root_level: usize, +) -> FixedRankSampleBudget { + let component_estimate = fixed_rank_sample_estimate(level_indexing, rank, p_param, root_level); + let runtime_estimate = fixed_rank_runtime_sample_estimate( + level_indexing, + rank, + component_estimate.effective_p, + min_num_samples, + root_level, + ); + let total_samples = component_estimate + .total_samples + .max(runtime_estimate.total_samples); + let box_samples_by_level: HashMap> = runtime_estimate + .box_samples_by_level + .into_iter() + .map(|(level, level_samples)| { + let level_samples = level_samples + .into_iter() + .map(|(key, samples)| (key, samples.min(total_samples))) + .collect(); + (level, level_samples) + }) + .collect(); + let active_samples_by_level = runtime_estimate + .active_samples_by_level + .into_iter() + .map(|(level, samples)| { + if level == 0 { + return ( + level, + root_fixed_rank_samples( + runtime_estimate.root_sketch_size, + p_param, + total_samples, + min_num_samples, + ), + ); + } + let level_max = box_samples_by_level + .get(&level) + .and_then(|level_samples| level_samples.values().copied().max()) + .unwrap_or(samples); + (level, level_max.max(samples).min(total_samples)) + }) + .collect(); + let level_samples_summary = format_level_samples(&active_samples_by_level); + println!( + "Fixed-rank predicted max active samples: {} (root sketch size = {}, rank = {rank})", + runtime_estimate.total_samples, runtime_estimate.root_sketch_size + ); + println!( + "Fixed-rank sample components: max_s_vec_k = {}, max_s_vec_p = {}, effective_p = {}", + component_estimate.max_s_vec_k, + component_estimate.max_s_vec_p, + component_estimate.effective_p, + ); + println!( + "Fixed-rank per-level active samples: {}", + level_samples_summary + ); + FixedRankSampleBudget::PerLevel { + total_samples, + active_samples_by_level, + box_samples_by_level, + } +} -fn oversample(samples: usize, oversampling: usize) -> usize { - samples + (samples / 100) * oversampling +fn anticipated_fixed_rank_samples_constant( + level_indexing: &TreeData, + rank: usize, + p_param: usize, + root_level: usize, +) -> FixedRankSampleBudget { + let estimate = fixed_rank_sample_estimate(level_indexing, rank, p_param, root_level); + println!( + "Fixed-rank sample components: max_s_vec_k = {}, max_s_vec_p = {}, effective_p = {}", + estimate.max_s_vec_k, estimate.max_s_vec_p, estimate.effective_p + ); + FixedRankSampleBudget::Constant { + samples: estimate.total_samples, + } +} + +fn auto_min_len(batch_len: usize, num_threads: usize) -> usize { + if batch_len <= num_threads { + 1 + } else { + let raw = batch_len / (2 * num_threads); + raw.clamp(1, 32) + } +} + +const RSRS_THREAD_STACK_BYTES: usize = 128 * 1024 * 1024; + +fn build_rsrs_thread_pool(num_threads: usize) -> rayon::ThreadPool { + ThreadPoolBuilder::new() + .num_threads(num_threads) + .stack_size(RSRS_THREAD_STACK_BYTES) + .build() + .unwrap() } impl< Item: RlstScalar + MatrixId + + MatrixIdNoSkel + MatrixNull + MatrixInverse + MatrixPseudoInverse @@ -331,21 +614,176 @@ where MatrixQrDecomposition, TriangularMatrix: TriangularOperations, ::Real: RandScalar, + Item: IOData, + Item: std::convert::From<>::Item>, { + fn sample_buffer_bytes(&self) -> u64 { + let item_size = std::mem::size_of::() as u64; + let samples = self.y_data.test.shape()[0] as u64; + let dim = self.dim as u64; + let num_buffers = if self.options.symmetry.symm_val() { + 2 + } else { + 4 + }; + samples * dim * item_size * (num_buffers as u64) + } + + fn factor_memory_stats(&self, rsrs_factors: Option<&RsrsFactors>) -> FactorMemoryStats { + let Some(rsrs_factors) = rsrs_factors else { + return FactorMemoryStats::default(); + }; + + let breakdown = rsrs_factors.memory_breakdown(); + FactorMemoryStats { + total_bytes: breakdown.total_bytes(), + id_bytes: breakdown.id_bytes, + lu_bytes: breakdown.lu_bytes, + diag_bytes: breakdown.diag_bytes, + perm_bytes: breakdown.perm_bytes, + id_count: breakdown.id_count, + lu_count: breakdown.lu_count, + diag_count: breakdown.diag_count, + } + } + + fn capture_memory_snapshot(&mut self, label: &str, rsrs_factors: Option<&RsrsFactors>) { + let usage = process_memory_usage(); + let sample_buffer_bytes = self.sample_buffer_bytes(); + let factor_memory = self.factor_memory_stats(rsrs_factors); + let accounted_factorization_bytes = sample_buffer_bytes + factor_memory.total_bytes; + + if self.stats.run_start_rss_bytes.is_none() { + self.stats.run_start_rss_bytes = usage.resident_bytes; + } + let baseline_rss_bytes = self.stats.run_start_rss_bytes; + let estimated_temporary_runtime_bytes = match (usage.resident_bytes, baseline_rss_bytes) { + (Some(rss), Some(baseline)) => { + Some(rss.saturating_sub(baseline.saturating_add(accounted_factorization_bytes))) + } + _ => None, + }; + + let resident = usage + .resident_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()); + let peak = usage + .peak_resident_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()); + let temp_runtime = estimated_temporary_runtime_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()); + + println!( + "Memory [{label}]: rss = {resident}, peak = {peak}, sampling ~= {}, factors ~= {}, accounted factorization ~= {}, est temp/runtime ~= {temp_runtime}", + format_bytes(sample_buffer_bytes), + format_bytes(factor_memory.total_bytes), + format_bytes(accounted_factorization_bytes), + ); + println!( + "Factors [{label}]: id = {} ({}), lu = {} ({}), diag = {} ({}), perm = {}", + format_bytes(factor_memory.id_bytes), + factor_memory.id_count, + format_bytes(factor_memory.lu_bytes), + factor_memory.lu_count, + format_bytes(factor_memory.diag_bytes), + factor_memory.diag_count, + format_bytes(factor_memory.perm_bytes) + ); + + self.stats.max_sample_buffer_bytes = + self.stats.max_sample_buffer_bytes.max(sample_buffer_bytes); + self.stats.max_factor_bytes = self.stats.max_factor_bytes.max(factor_memory.total_bytes); + self.stats.max_accounted_factorization_bytes = self + .stats + .max_accounted_factorization_bytes + .max(accounted_factorization_bytes); + self.stats.max_estimated_temporary_runtime_bytes = match ( + self.stats.max_estimated_temporary_runtime_bytes, + estimated_temporary_runtime_bytes, + ) { + (Some(current), Some(candidate)) => Some(current.max(candidate)), + (None, Some(candidate)) => Some(candidate), + (current, None) => current, + }; + + self.stats.memory_snapshots.push(MemorySnapshot { + label: label.to_string(), + rss_bytes: usage.resident_bytes, + peak_rss_bytes: usage.peak_resident_bytes, + baseline_rss_bytes, + sample_buffer_bytes, + factor_memory, + accounted_factorization_bytes, + estimated_temporary_runtime_bytes, + }); + } + pub fn new( octree: &Octree<'_, C>, options: RsrsOptions, dim: usize, ) -> Self { let level_indexing: TreeData = ::new(octree); + let max_leaf_points = level_indexing + .boxes_map + .values() + .map(Vec::len) + .max() + .unwrap_or(0); + println!("Maximum leaf occupancy: {max_leaf_points}"); + let anticipated_fixed_rank_samples = if options.id_options.tol_id > num::One::one() { + let rank = num::ToPrimitive::to_usize(&options.id_options.tol_id).unwrap(); + let estimate_start = Instant::now(); + println!("Estimating fixed-rank sample budget..."); + let estimated_budget = match options.sketching.fixed_rank_sampling_mode { + FixedRankSamplingMode::PerLevel => anticipated_fixed_rank_samples_per_level( + &level_indexing, + rank, + options.sketching.oversampling, + options.sketching.min_num_samples, + 1, + ), + FixedRankSamplingMode::Constant => anticipated_fixed_rank_samples_constant( + &level_indexing, + rank, + options.sketching.oversampling, + 1, + ), + }; + let sample_budget = estimated_budget + .clone() + .with_global_min_samples(options.sketching.min_num_samples); + if sample_budget != estimated_budget { + println!( + "Fixed-rank sample floor override active: min_num_samples = {} exceeds estimated budget = {}. Using a constant fixed budget.", + options.sketching.min_num_samples, + estimated_budget.total_samples() + ); + } + let samples = sample_budget.total_samples(); + println!( + "Fixed-rank sample budget estimated in {:.3}s", + estimate_start.elapsed().as_secs_f64() + ); + println!( + "Anticipated fixed-rank sample budget: {samples} (rank = {rank}, p = {}, mode = {:?})", + options.sketching.oversampling, + options.sketching.fixed_rank_sampling_mode, + ); + Some(sample_budget) + } else { + None + }; let target_inds: Inds = Vec::new(); let near_inds: Inds = Vec::new(); let ind_s: Inds = Vec::new(); let ind_r: Inds = Vec::new(); let box_types: Vec>> = Vec::new(); - //et dim = space.dimension(); - let y_data: SketchData = SketchData::new(dim, false); - let z_data: SketchData = SketchData::new(dim, true); + let y_data: SketchData = SketchData::new(dim, TransMode::NoTrans); + let z_data: SketchData = SketchData::new(dim, TransMode::Trans); let id_times = Vec::new(); let lu_times = Vec::new(); let update_times = Vec::new(); @@ -359,10 +797,13 @@ where min_samples: 0, max_level: 0, limiting_level, + leaf_count: 0, }; + println!("Configured number of threads = {}", options.num_threads); let stats = Stats { sampling_time: Vec::new(), + sample_loading_time: 0_u128, sampling_extraction_time: 0_u128, id_times, tot_id_time: 0_u128, @@ -382,6 +823,14 @@ where residual_calculation: 0_u128, limiting_factors, dim, + level_effort: Vec::new(), + mv_avg_time: Vec::new(), + memory_snapshots: Vec::new(), + run_start_rss_bytes: None, + max_sample_buffer_bytes: 0, + max_factor_bytes: 0, + max_accounted_factorization_bytes: 0, + max_estimated_temporary_runtime_bytes: None, }; Self { @@ -397,6 +846,7 @@ where stats, active_samples: 0, options, + anticipated_fixed_rank_samples, } } @@ -423,38 +873,74 @@ where pub fn run, OpImpl: AsApply>( &mut self, operator: Operator, + ) -> RsrsFactors { + self.run_with_seed(operator, default_run_seed()) + } + + pub fn run_with_seed< + Space: SamplingSpace, + OpImpl: AsApply, + >( + &mut self, + operator: Operator, + seed: u64, ) -> RsrsFactors { let num_levels: usize = self.level_indexing.max_level; let algo_start: Instant = Instant::now(); - let mut rsrs_factors = - as RsrsFactorsImpl>::new(num_levels, self.dim); + let mut rsrs_factors = as RsrsFactorsImpl>::new( + num_levels, + self.dim, + &self.options.fact_type, + self.options.num_threads, + ); + self.capture_memory_snapshot("run start", None); let start: Instant = Instant::now(); - self.tree_cycle(operator.r(), &mut rsrs_factors); + self.tree_cycle(operator.r(), &mut rsrs_factors, seed); let duration = start.elapsed(); println!("Tree cycle elapsed time: {} s", duration.as_secs()); + self.capture_memory_snapshot("after tree cycle", Some(&rsrs_factors)); println!( "Extracting diagonal blocks with {} active samples", self.active_samples ); + self.capture_memory_snapshot("before diagonal extraction", Some(&rsrs_factors)); let start: Instant = Instant::now(); - let (diag_box_factors, rows, cols) = self.extract_step(); + let (mut diag_box_factors, rows, cols) = self.extract_step(); + if self.options.flush_factors { + diag_box_factors.flush(); + } rsrs_factors.diag_box_factors = diag_box_factors; rsrs_factors.perm_factor.orig_indices = cols; rsrs_factors.perm_factor.perm_indices = rows; let extraction_time = start.elapsed(); - println!("Extraction time: {:?}s\n", extraction_time.as_secs()); + println!( + "Extraction time: {:.3} ms ({extraction_time:?})\n", + extraction_time.as_secs_f64() * 1.0e3 + ); + self.capture_memory_snapshot("after diagonal extraction", Some(&rsrs_factors)); self.stats.extraction_time = extraction_time.as_millis(); let duration = algo_start.elapsed(); self.stats.total_elapsed_time = duration.as_millis(); - let sampling_time = - self.stats.sampling_extraction_time + self.stats.sampling_time.iter().sum::(); + let sampling_time = self.stats.sample_loading_time + + self.stats.sampling_extraction_time + + self.stats.sampling_time.iter().sum::(); self.stats.total_elapsed_time_wo_sampling = self.stats.total_elapsed_time.saturating_sub(sampling_time); println!( - "Total elapsed time: {:?} ({}ms for sampling, {}ms for RSRS), with {} active samples\n", + "Total elapsed time: {:?} ({}ms for loading/sampling, {}ms for RSRS), with {} active samples\n", duration, sampling_time, self.stats.total_elapsed_time_wo_sampling, self.active_samples ); + println!( + "Memory maxima: sampling ~= {}, factors ~= {}, accounted factorization ~= {}, est temp/runtime ~= {}\n", + format_bytes(self.stats.max_sample_buffer_bytes), + format_bytes(self.stats.max_factor_bytes), + format_bytes(self.stats.max_accounted_factorization_bytes), + self.stats + .max_estimated_temporary_runtime_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()) + ); println!("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\n"); rsrs_factors @@ -467,6 +953,7 @@ where &mut self, operator: Operator, rsrs_factors: &mut RsrsFactors, + seed: u64, ) { let mut level: usize = self.level_indexing.max_level; let mut level_it = 0; @@ -478,15 +965,22 @@ where self.get_level_indices(level); let duration: Duration = start.elapsed(); println!("Current Level: {level}. Indices computed in {duration:?}\n\n"); + let active_boxes = self.ind_s.iter().filter(|v| !v.is_empty()).count(); self.stats.index_calculation += duration.as_millis(); + let len_r_s: usize = self + .ind_r + .iter() + .map(|residual_inds| residual_inds.len()) + .sum(); let start: Instant = Instant::now(); - self.level_cycle(operator.r(), rsrs_factors, level_it); + let (level_duration, num_batches) = + self.level_cycle(operator.r(), rsrs_factors, level_it, seed); println!("End level cycle. Summary:"); println!("-------------------------"); let duration: Duration = start.elapsed(); println!("Elapsed time: {} s", duration.as_secs()); - + println!("Leaf count: {}", self.stats.limiting_factors.leaf_count); if self.stats.limiting_factors.limiting_level.level == self.level_indexing.current_level { self.stats.limiting_factors.limiting_level.elapsed_time = duration.as_millis() @@ -501,14 +995,24 @@ where let len_s = self.dim - len_r; let duration: Duration = start.elapsed(); self.stats.residual_calculation += duration.as_millis(); - - println!("Sketch Points: {len_s}"); + let active_indices: usize = self.ind_s.iter().map(Vec::len).sum(); + println!("Sketch Points: {len_s} ({active_indices})"); println!("Residual Points: {len_r}"); println!( "Current Number of Samples: {} of which {} are active\n", self.y_data.test.shape()[0], self.active_samples ); + self.capture_memory_snapshot(&format!("level {level} summary"), Some(rsrs_factors)); + let level_effort = LevelEffort { + time: level_duration, + num_boxes: active_boxes, + num_batches, + effective_dofs: len_r - len_r_s, + residual_len: len_r, + sketch_len: active_indices, + }; + self.stats.level_effort.push(level_effort); level -= 1; level_it += 1; @@ -517,19 +1021,45 @@ where println!("-------------------------"); println!("\nReached lower level: {level}"); self.stats.residual_size = len_r; - let min_oversamples = - oversample(len_s, self.options.sketching.oversampling_diag_blocks); + let active_before = self.active_samples; + let min_oversamples = oversample::( + len_s, + self.options.sketching.oversampling_diag_blocks, + num::One::one(), + self.options.sketching.min_num_samples, + ); + let (sample_target, active_target) = + if let Some(fixed_rank_budget) = self.anticipated_fixed_rank_samples.as_ref() { + let total_samples = fixed_rank_budget.total_samples(); + (total_samples, total_samples) + } else { + (min_oversamples, min_oversamples) + }; + let active_after = active_target.max(active_before); + println!( + "[sampling][diag] level={} residual_len={} min_oversamples={} sampling_target={} active_target={} active_before={} active_after={} total_loaded={}", + self.level_indexing.current_level, + len_s, + min_oversamples, + sample_target, + active_target, + active_before, + active_after, + self.y_data.test.shape()[0] + ); println!("Minimum samples: {min_oversamples}"); let (tot_sampling_time, tot_id_update, tot_lu_update) = self.add_samples( - min_oversamples, + sample_target, + active_target, operator.r(), rsrs_factors, level_it, false, - 0_u64, + false, + scoped_seed(seed, level_it, 0xD1A6_0001), ); - self.active_samples = min_oversamples.max(self.active_samples); + self.active_samples = active_after; self.stats.sampling_extraction_time = tot_sampling_time; let mut update_times = UpdateTimes::new(); @@ -549,7 +1079,8 @@ where operator: Operator, rsrs_factors: &mut RsrsFactors, level_it: usize, - ) { + seed: u64, + ) -> (u128, usize) { let merged_count = self .box_types .iter() @@ -558,29 +1089,20 @@ where println!("Number of merged boxes: {merged_count}\n"); let current_box_indices = - self.sampling_step(operator.r(), rsrs_factors, level_it == 0, level_it); - let id_step_start: Instant = Instant::now(); - let (id_factors_res, current_box_indices, level_ind_r) = - self.id_level_iteration::(¤t_box_indices); - rsrs_factors.id_factors[level_it] = id_factors_res; - let id_step_duration = id_step_start.elapsed(); - self.stats.tot_id_time += id_step_duration.as_millis(); - - println!("ID step in {id_step_duration:?}"); - - let start_id_update: Instant = Instant::now(); - let update_type = UpdateType::Id(&rsrs_factors.id_factors[level_it]); - self.update_samples(0, self.active_samples, level_it, &update_type); - let update_id_time: Duration = start_id_update.elapsed(); + self.sampling_step(operator.r(), rsrs_factors, level_it == 0, level_it, seed); - let mut update_times = UpdateTimes::new(); - update_times.sum(update_id_time.as_millis(), 0_u128); - self.stats.update_times.push(update_times); - - println!("ID updated in {update_id_time:?}\n"); - - rsrs_factors.lu_factors[level_it] = - self.lu_level_iteration::(¤t_box_indices, &level_ind_r, level_it); + let level_start: Instant = Instant::now(); + let num_batches = match self.options.fact_type { + FactType::Joint => { + self.joint_id_and_lu::(rsrs_factors, ¤t_box_indices, level_it) + } + FactType::Split => { + self.split_id_and_lu::(rsrs_factors, current_box_indices, level_it); + 0 + } + }; + let level_duration_wo_sampling = level_start.elapsed(); + (level_duration_wo_sampling.as_millis(), num_batches) } fn sampling_step< @@ -592,44 +1114,88 @@ where rsrs_factors: &RsrsFactors, start: bool, level_it: usize, + seed: u64, ) -> Vec { - let mut box_indices: Vec = (0..self.target_inds.len()).collect::>(); - - box_indices = box_indices - .into_iter() - .filter(|&box_ind| !self.ind_s[box_ind].is_empty()) + let fixed_rank = self.anticipated_fixed_rank_samples.is_some(); + let mut box_indices: Vec = (0..self.target_inds.len()) + .filter(|&box_ind| { + should_launch_box(self.ind_s[box_ind].len(), fixed_rank, &self.options) + }) .collect::>(); box_indices.sort_by_key(|&box_ind| { self.ind_s[box_ind].len() + self.get_near_indices(box_ind).len() }); + if box_indices.is_empty() { + return box_indices; + } + let current_box_indices = box_indices; let last_box_index = *current_box_indices.last().unwrap(); + let hardest_box_skeleton = self.ind_s[last_box_index].len(); + let hardest_box_near = self.get_near_indices(last_box_index).len(); - let min_oversamples = oversample( - self.ind_s[last_box_index].len() + self.get_near_indices(last_box_index).len(), + let level_min_oversamples = oversample::( + hardest_box_skeleton + hardest_box_near, self.options.sketching.oversampling, + self.options.id_options.tol_id, + self.options.sketching.min_num_samples, ); - let min_samples = if start { - self.options - .sketching - .initial_num_samples - .max(min_oversamples) - } else { - min_oversamples - }; + let (sample_target, active_target) = + if let Some(fixed_rank_budget) = self.anticipated_fixed_rank_samples.as_ref() { + let total_samples = fixed_rank_budget.total_samples(); + (total_samples, total_samples) + } else if start { + let target = self + .options + .sketching + .initial_num_samples + .max(level_min_oversamples); + (target, target) + } else { + (level_min_oversamples, level_min_oversamples) + }; + let load_samples = start && self.options.sketching.load_samples; + let active_before = self.active_samples; + let active_after = active_target.max(active_before); + println!( + "[sampling][level] level={} level_it={} hardest_box={} skeleton={} near={} level_min_oversamples={} sampling_target={} active_target={} active_before={} active_after={} total_loaded={} fixed_rank={} fixed_rank_sampling_mode={:?} load_samples={}", + self.level_indexing.current_level, + level_it, + last_box_index, + hardest_box_skeleton, + hardest_box_near, + level_min_oversamples, + sample_target, + active_target, + active_before, + active_after, + self.y_data.test.shape()[0], + fixed_rank, + self.options.sketching.fixed_rank_sampling_mode, + load_samples + ); - let (tot_sampling_time, tot_id_update, tot_lu_update) = - self.add_samples(min_samples, operator.r(), rsrs_factors, level_it, start, 1); + let (tot_sampling_time, tot_id_update, tot_lu_update) = self.add_samples( + sample_target, + active_target, + operator.r(), + rsrs_factors, + level_it, + start, + load_samples, + scoped_seed( + seed, + level_it, + if start { 0x5A11_0001 } else { 0x5A11_0002 }, + ), + ); - self.stats.limiting_factors.min_samples = self - .stats - .limiting_factors - .min_samples - .max(self.active_samples); - self.active_samples = min_oversamples.max(self.active_samples); + self.stats.limiting_factors.min_samples = + self.stats.limiting_factors.min_samples.max(active_after); + self.active_samples = active_after; self.stats.sampling_time.push(tot_sampling_time); let mut update_times = UpdateTimes::new(); @@ -637,50 +1203,165 @@ where self.stats.update_times.push(update_times); println!("Active samples: {}\n", self.active_samples); + self.capture_memory_snapshot( + &format!("level {} after sampling", self.level_indexing.current_level), + Some(rsrs_factors), + ); current_box_indices } + #[allow(clippy::too_many_arguments)] fn add_samples< Space: SamplingSpace, OpImpl: AsApply, >( &mut self, - min_samples: usize, + sample_target: usize, + active_target: usize, operator: Operator, rsrs_factors: &RsrsFactors, level_it: usize, start: bool, - _seed: u64, + load_samples: bool, + seed: u64, ) -> (u128, u128, u128) { + let preferred_sampling_dir = self.options.sketching.sample_storage_dir.as_deref(); + let current_shift = shift_alpha(&self.options.sketching.shift); + let mut active_sampling_dir: Option = None; + + if load_samples { + let load_start = Instant::now(); + if let Some(y_sampling_dir) = + resolve_sampling_dir(preferred_sampling_dir, &["y_test_file", "y_sketch_file"]) + .unwrap() + { + >::load_into_in_dir( + &mut self.y_data.test, + self.dim, + "y_test_file", + Some(y_sampling_dir.as_path()), + ) + .unwrap(); + >::load_into_in_dir( + &mut self.y_data.sketch, + self.dim, + "y_sketch_file", + Some(y_sampling_dir.as_path()), + ) + .unwrap(); + let num_existing_samples = self.y_data.test.shape()[0]; + if current_shift.abs() > f64::EPSILON { + apply_shift_delta(&mut self.y_data.sketch, &self.y_data.test, current_shift); + } + println!( + "{} samples loaded from '{}' and {} stored / {} active target samples", + num_existing_samples, + y_sampling_dir.display(), + sample_target, + active_target + ); + active_sampling_dir = Some(y_sampling_dir); + } + + if !self.options.symmetry.symm_val() { + if let Some(z_sampling_dir) = + resolve_sampling_dir(preferred_sampling_dir, &["z_test_file", "z_sketch_file"]) + .unwrap() + { + >::load_into_in_dir( + &mut self.z_data.test, + self.dim, + "z_test_file", + Some(z_sampling_dir.as_path()), + ) + .unwrap(); + >::load_into_in_dir( + &mut self.z_data.sketch, + self.dim, + "z_sketch_file", + Some(z_sampling_dir.as_path()), + ) + .unwrap(); + let num_existing_samples = self.z_data.test.shape()[0]; + if current_shift.abs() > f64::EPSILON { + apply_shift_delta( + &mut self.z_data.sketch, + &self.z_data.test, + current_shift, + ); + } + println!( + "{} samples loaded from '{}' and {} stored / {} active target samples", + num_existing_samples, + z_sampling_dir.display(), + sample_target, + active_target + ); + active_sampling_dir.get_or_insert(z_sampling_dir); + } + } + let load_duration = load_start.elapsed().as_millis(); + self.stats.sample_loading_time += load_duration; + println!("Sample loading time: {load_duration}ms"); + } + + let sample_storage_dir = active_sampling_dir + .as_deref() + .or_else(|| preferred_sampling_dir.map(Path::new)); + let mut tot_sampling_time = 0_u128; let test_shape = self.y_data.test.shape(); - if min_samples > test_shape[0] { - let extra_samples = min_samples.saturating_sub(self.y_data.test.shape()[0]); + + if sample_target > test_shape[0] { + let extra_samples = sample_target.saturating_sub(self.y_data.test.shape()[0]); println!("Sampling step. Sampling new {extra_samples} vectors\n"); - tot_sampling_time += self.y_data.add_samples(extra_samples, operator.r(), 0_u64); + tot_sampling_time += self.y_data.add_samples( + extra_samples, + operator.r(), + &self.options.sketching.shift, + self.options.sketching.save_samples, + sample_storage_dir, + mix_seed(seed ^ 0x59_5F33_DA7A_0001), + ); - if !self.options.hermitian { - let tot_z_sampling_time = - self.z_data.add_samples(extra_samples, operator.r(), 0_u64); + if !self.options.symmetry.symm_val() { + let tot_z_sampling_time = self.z_data.add_samples( + extra_samples, + operator.r(), + &self.options.sketching.shift, + self.options.sketching.save_samples, + sample_storage_dir, + mix_seed(seed ^ 0x5A_5F33_DA7A_0002), + ); tot_sampling_time += tot_z_sampling_time; } println!("Total samples: {}", self.y_data.test.shape()[0]); println!("Sampling time: {tot_sampling_time}ms\n"); } - if !start && min_samples > self.active_samples { - let extra_active_samples = min_samples.saturating_sub(self.active_samples); - println!("New {extra_active_samples} samples, with {min_samples} min samples."); + if !start && active_target > self.active_samples { + let extra_active_samples = active_target.saturating_sub(self.active_samples); + println!( + "New {extra_active_samples} active samples, with {active_target} active target." + ); let update_start = self.active_samples; - let (tot_id_update, tot_lu_update) = self.update_samples( + let (tot_id_update, tot_lu_update, avg_mv_time) = self.update_samples( update_start, extra_active_samples, level_it, &UpdateType::Both(rsrs_factors), ); + match avg_mv_time { + Some(avg_time) => { + println!("Average update time: {} ms", avg_time); + self.stats.mv_avg_time.push(avg_time) + } + None => todo!(), + }; + println!("Update times: {tot_id_update}ms (ID), {tot_lu_update}ms (LU)"); return (tot_sampling_time, tot_id_update, tot_lu_update); } @@ -693,20 +1374,374 @@ where samples_to_update: usize, level: usize, update_type: &UpdateType, - ) -> (u128, u128) { - let (mut tot_id_update, mut tot_lu_update) = - self.y_data - .update_samples(update_start, samples_to_update, level, update_type); - - if !self.options.hermitian { - let (tot_z_id_update, tot_z_lu_update) = - self.z_data - .update_samples(update_start, samples_to_update, level, update_type); + ) -> (u128, u128, Option) { + let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let (mut tot_id_update, mut tot_lu_update) = self.y_data.update_samples( + update_start, + samples_to_update, + level, + update_type, + &self.options.fact_type, + &thread_pool, + self.options.num_threads, + ); + + let avg_update_time = if samples_to_update == 0 { + None + } else { + Some((tot_id_update + tot_lu_update) / (samples_to_update as u128)) + }; + + if !self.options.symmetry.symm_val() { + let (tot_z_id_update, tot_z_lu_update) = self.z_data.update_samples( + update_start, + samples_to_update, + level, + update_type, + &self.options.fact_type, + &thread_pool, + self.options.num_threads, + ); tot_id_update += tot_z_id_update; tot_lu_update += tot_z_lu_update; } - (tot_id_update, tot_lu_update) + (tot_id_update, tot_lu_update, avg_update_time) + } + + fn split_id_and_lu>( + &mut self, + rsrs_factors: &mut RsrsFactors, + current_box_indices: Vec, + level_it: usize, + ) { + let id_step_start: Instant = Instant::now(); + let (id_factors_res, current_box_indices, level_ind_r) = + self.id_level_iteration::(¤t_box_indices); + let id_step_duration = id_step_start.elapsed(); + self.stats.tot_id_time += id_step_duration.as_millis(); + println!("ID step in {id_step_duration:?}"); + match &mut rsrs_factors.id_factors { + MultiLevelIdFactors::Single(level_factors) => { + level_factors[level_it] = id_factors_res; + let start_id_update: Instant = Instant::now(); + let update_type = UpdateType::Id(&level_factors[level_it]); + self.update_samples(0, self.active_samples, level_it, &update_type); + let update_id_time: Duration = start_id_update.elapsed(); + + let mut update_times = UpdateTimes::new(); + update_times.sum(update_id_time.as_millis(), 0_u128); + self.stats.update_times.push(update_times); + + println!("ID updated in {update_id_time:?}\n"); + } + MultiLevelIdFactors::Batched(_) => todo!(), + } + //rsrs_factors.id_factors[level_it] = id_factors_res; + + rsrs_factors.lu_factors[level_it] = + self.lu_level_iteration::(¤t_box_indices, &level_ind_r, level_it); + } + fn joint_id_and_lu>( + &mut self, + rsrs_factors: &mut RsrsFactors, + current_box_indices: &[usize], + level_it: usize, + ) -> usize { + println!("ID and LU step"); + let start = Instant::now(); + + // 1. Build independent batches (MIS-based) + let independent_near_fields = self.group_near_fields(current_box_indices); + + // 2. Build level_near_field_inds once (per level) + let mut level_near_field_inds: Vec<_> = current_box_indices + .iter() + .map(|&box_ind| self.get_near_indices(box_ind)) + .collect(); + + let time_independent_nf = start.elapsed(); + self.stats.sorting_near_field += time_independent_nf.as_millis(); + + let num_batches = independent_near_fields.len(); + println!( + "Batches computed in {time_independent_nf:?}, number of batches: {}", + independent_near_fields.len() + ); + + // --- timing / stats accumulators over all batches at this level --- + let mut update_lu_batch_time: u128 = 0; + let mut update_id_batch_time: u128 = 0; + let mut lu_step_duration: u128 = 0; + let mut id_step_duration: u128 = 0; + + let mut level_ind_r: Vec> = vec![Vec::new(); current_box_indices.len()]; + let mut inactive_inds: Vec = Vec::new(); + + let mut lu_times = LuTimes::new(); + let mut id_times = IdTimes::new(); + let mut update_times = UpdateTimes::new(); + let mut num_dec_boxes = 0; + let mut len_sketch = 0; + let mut len_full_rank = 0; + let current_level = self.level_indexing.current_level; + let current_level_keys: Vec = + self.level_indexing.level_keys.iter().copied().collect(); + // Process all batches *sequentially*, each batch using Rayon internally + let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let batches_res: Vec<_> = thread_pool.install(|| { + independent_near_fields + .into_iter() + .map(|batch| { + // One batch = independent set of LOCAL indices into current_box_indices + let mut id_batch: CommutativeFactors = + CommutativeFactorsOperations::new(); + let mut id_batch_time = IdTimes::new(); + let mut lu_batch: CommutativeFactors = + CommutativeFactorsOperations::new(); + let mut lu_batch_time = LuTimes::new(); + + // ---- ID STEP ---- + let id_step_start = Instant::now(); + + let num_threads = rayon::current_num_threads(); + let min_len_id = auto_min_len(batch.len(), num_threads); + + let id_batch_res: Vec<_> = batch + .par_iter() + .with_min_len(min_len_id) + .map_init(ExtractionScratch::::new, |scratch, box_num| { + let box_ind = current_box_indices[*box_num]; + let mut skel_box = ::default(); + let per_box_samples = self + .anticipated_fixed_rank_samples + .as_ref() + .and_then(|budget| { + budget.per_level_box_samples( + current_level, + ¤t_level_keys[box_ind], + ) + }); + let min_num_samples = local_sample_count( + per_box_samples, + level_near_field_inds[*box_num].len(), + self.active_samples, + self.anticipated_fixed_rank_samples.is_some(), + self.options.sketching.fixed_rank_sampling_mode, + &self.options, + ); + assert_post_nullification_headroom( + self.ind_s[box_ind].len(), + level_near_field_inds[*box_num].len(), + min_num_samples, + &self.options, + ); + + let rank = >::id_step( + &mut skel_box, + scratch, + &self.box_types[box_ind], + &self.ind_s[box_ind], + &level_near_field_inds[*box_num], + &self.y_data, + &self.z_data, + min_num_samples, + &self.options, + ); + + let leaf_counter = match self.box_types[box_ind] { + BoxType::Merged(_) => 0, + BoxType::Full(_) => match rank { + Rank::Low(_) => 1, + Rank::Full(_) => 0, + }, + }; + + (*box_num, box_ind, rank, leaf_counter) + }) + .collect(); + + let mut active_batch: Vec = Vec::new(); + + id_batch_res.into_iter().for_each( + |(box_num, box_ind, result, leaf_counter)| match result { + Rank::Low(low_rank_result) => { + active_batch.push(box_num); + + level_ind_r[box_num] = low_rank_result.id_factor.ind_r.clone(); + self.target_inds[box_ind] = low_rank_result.target_inds.clone(); + self.ind_s[box_ind] = low_rank_result.id_factor.ind_s.clone(); + self.ind_r.push(low_rank_result.id_factor.ind_r.clone()); + + let box_size = self.target_inds[box_ind].len(); + + let res_id_times = match low_rank_result.id_times { + Times::Lu(_) => IdTimes { + nullification: 0, + id: 0, + }, + Times::Id(id_times) => id_times, + }; + + id_batch_time.sum(res_id_times.nullification, res_id_times.id); + + self.stats.ranks.push(self.ind_s[box_ind].len()); + self.stats.box_sizes.push(box_size); + self.stats + .near_field_sizes + .push(low_rank_result.near_field_inds.len()); + + len_sketch += self.ind_s[box_ind].len(); + num_dec_boxes += 1; + self.stats.limiting_factors.leaf_count += leaf_counter; + + id_batch.add_factor(Factor::Id(low_rank_result.id_factor)); + } + Rank::Full(full_rank_result) => { + if let Times::Id(id_times) = full_rank_result.id_times { + id_batch_time.sum(id_times.nullification, id_times.id); + } + self.stats.ranks.push(full_rank_result.len_target_inds); + self.stats.box_sizes.push(full_rank_result.len_target_inds); + self.stats + .near_field_sizes + .push(full_rank_result.len_near_field_inds); + self.stats.limiting_factors.leaf_count += leaf_counter; + len_full_rank += self.ind_s[box_ind].len(); + } + }, + ); + + id_step_duration += id_step_start.elapsed().as_millis(); + + // update samples after ID + let id_batch_start = Instant::now(); + let update_type = UpdateType::Id(&id_batch); + self.update_samples(0, self.active_samples, level_it, &update_type); + update_id_batch_time += id_batch_start.elapsed().as_millis(); + + if self.options.flush_factors { + id_batch.flush(); + } + // ---- LU STEP ---- + if !active_batch.is_empty() { + let lu_step_start = Instant::now(); + + let min_len_lu = auto_min_len(active_batch.len(), num_threads); + let lu_batch_res: Vec<(Times, LuFactor, Vec)> = active_batch + .par_iter() + .with_min_len(min_len_lu) + .map_init(ExtractionScratch::::new, |scratch, box_num| { + let skel_box = ::default(); + let box_ind = current_box_indices[*box_num]; + let per_box_samples = self + .anticipated_fixed_rank_samples + .as_ref() + .and_then(|budget| { + budget.per_level_box_samples( + current_level, + ¤t_level_keys[box_ind], + ) + }); + let min_num_samples = local_sample_count( + per_box_samples, + level_near_field_inds[*box_num].len(), + self.active_samples, + self.anticipated_fixed_rank_samples.is_some(), + self.options.sketching.fixed_rank_sampling_mode, + &self.options, + ); + + >::lu_step( + &skel_box, + scratch, + &self.y_data, + &self.z_data, + &level_ind_r[*box_num], + &level_near_field_inds[*box_num], + &inactive_inds, + min_num_samples, + &self.options, + ) + .map(|(lu_factor, lu_times)| { + (lu_times, lu_factor, level_ind_r[*box_num].clone()) + }) + }) + .flatten() + .collect(); + + lu_batch_res + .into_iter() + .for_each(|(it_lu_times, lu_factor, r_inds)| { + lu_batch.add_factor(Factor::Lu(lu_factor)); + inactive_inds.extend_from_slice(&r_inds); + if let Times::Lu(lu_times) = it_lu_times { + lu_batch_time.sum(lu_times.lu, lu_times.extraction) + } + }); + + inactive_inds.sort_unstable(); + inactive_inds.dedup(); + + lu_step_duration += lu_step_start.elapsed().as_millis(); + + // update samples after LU + let lu_batch_start = Instant::now(); + let update_type = UpdateType::Lu(&lu_batch); + self.update_samples(0, self.active_samples, level_it, &update_type); + update_lu_batch_time += lu_batch_start.elapsed().as_millis(); + } + + level_near_field_inds = level_near_field_inds + .iter() + .map(|inds| { + inds.iter() + .filter(|el| inactive_inds.binary_search(el).is_err()) + .cloned() + .collect() + }) + .collect(); + + if self.options.flush_factors { + lu_batch.flush(); + } + + (id_batch_time, id_batch, lu_batch_time, lu_batch) + }) + .collect() + }); + + // accumulate over batches + batches_res + .into_iter() + .for_each(|(id_batch_time, id_batch, lu_batch_time, lu_batch)| { + id_times.sum(id_batch_time.nullification, id_batch_time.id); + lu_times.sum(lu_batch_time.extraction, lu_batch_time.lu); + + rsrs_factors.lu_factors[level_it].push(lu_batch); + match &mut rsrs_factors.id_factors { + MultiLevelIdFactors::Single(_) => todo!(), + MultiLevelIdFactors::Batched(level_factors) => { + level_factors[level_it].push(id_batch) + } + } + }); + + update_times.sum(update_id_batch_time, update_lu_batch_time); + + self.stats.dec_boxes_per_level.push(num_dec_boxes); + self.stats.id_times.push(id_times); + self.stats.lu_times.push(lu_times); + self.stats.update_times.push(update_times); + + self.stats.tot_id_time += id_step_duration; + self.stats.tot_lu_time += lu_step_duration; + + println!("ID and LU steps completed"); + println!("ID step in {id_step_duration}ms, with updates in {update_id_batch_time}ms"); + println!("LU step in {lu_step_duration}ms, with updates in {update_lu_batch_time}ms\n"); + + num_batches } fn id_level_iteration>( @@ -723,42 +1758,68 @@ where .enumerate() .map(|(box_num, box_ind)| (*box_ind, box_num)) .collect(); + let current_level = self.level_indexing.current_level; + let current_level_keys: Vec = + self.level_indexing.level_keys.iter().copied().collect(); let mut current_box_indices = current_box_indices.to_vec(); current_box_indices.sort_by_key(|&box_ind| { let box_num = *current_near_field_ind_to_num.get(&box_ind).unwrap(); - let near_field_len = current_near_field_indices[box_num].len(); - let source_len = self.ind_s[box_ind].len(); - oversample( - near_field_len + source_len, + oversample::( + current_near_field_indices[box_num].len() + self.ind_s[box_ind].len(), self.options.sketching.oversampling, + self.options.id_options.tol_id, + self.options.sketching.min_num_samples, ) }); let start = Instant::now(); - let id_level_iteration_res: Vec<_> = current_box_indices - .par_iter() - .map(|&box_ind| { - let box_num = *current_near_field_ind_to_num.get(&box_ind).unwrap(); - let near_field_inds = ¤t_near_field_indices[box_num]; - let min_box_samples = oversample( - near_field_inds.len() + self.ind_s[box_ind].len(), - self.options.sketching.oversampling, - ); - let mut skel_box = ::default(); - - let rank = >::id_step( - &mut skel_box, - &self.box_types[box_ind], - &self.ind_s[box_ind], - near_field_inds, - &self.y_data, - &self.z_data, - min_box_samples, - &self.options, - ); - (box_ind, rank) - }) - .collect(); + let thread_pool = build_rsrs_thread_pool(self.options.num_threads); + let id_level_iteration_res: Vec<_> = thread_pool.install(|| { + current_box_indices + .par_iter() + .map_init(ExtractionScratch::::new, |scratch, &box_ind| { + let box_num = *current_near_field_ind_to_num.get(&box_ind).unwrap(); + let near_field_inds = ¤t_near_field_indices[box_num]; + let per_box_samples = + self.anticipated_fixed_rank_samples + .as_ref() + .and_then(|budget| { + budget.per_level_box_samples( + current_level, + ¤t_level_keys[box_ind], + ) + }); + let min_box_samples = local_sample_count( + per_box_samples, + near_field_inds.len(), + self.active_samples, + self.anticipated_fixed_rank_samples.is_some(), + self.options.sketching.fixed_rank_sampling_mode, + &self.options, + ); + assert_post_nullification_headroom( + self.ind_s[box_ind].len(), + near_field_inds.len(), + min_box_samples, + &self.options, + ); + let mut skel_box = ::default(); + + let rank = >::id_step( + &mut skel_box, + scratch, + &self.box_types[box_ind], + &self.ind_s[box_ind], + near_field_inds, + &self.y_data, + &self.z_data, + min_box_samples, + &self.options, + ); + (box_ind, rank) + }) + .collect() + }); let id_level_duration = start.elapsed(); println!("ID calculations in {id_level_duration:?}",); @@ -803,9 +1864,14 @@ where id_level.add_factor(Factor::Id(low_rank_result.id_factor)); } - Rank::Full(it_id_times) => { + Rank::Full(full_rank_result) => { len_full_rank += self.ind_s[box_ind].len(); - let res_id_times = match it_id_times { + self.stats.ranks.push(full_rank_result.len_target_inds); + self.stats.box_sizes.push(full_rank_result.len_target_inds); + self.stats + .near_field_sizes + .push(full_rank_result.len_near_field_inds); + let res_id_times = match full_rank_result.id_times { Times::Lu(_lu_times) => IdTimes { nullification: 0, id: 0, @@ -840,6 +1906,10 @@ where .iter() .map(|&box_ind| self.get_near_indices(box_ind)) .collect(); + let current_level = self.level_indexing.current_level; + let current_level_keys: Vec = + self.level_indexing.level_keys.iter().copied().collect(); + let time_independent_nf = start.elapsed(); self.stats.sorting_near_field += time_independent_nf.as_millis(); @@ -849,38 +1919,63 @@ where let mut update_parallel_batch_time = 0; let mut lu_times = LuTimes::new(); let mut update_times = UpdateTimes::new(); - let lu_step_start: Instant = Instant::now(); + + let mut inactive_inds = Vec::new(); + let thread_pool = build_rsrs_thread_pool(self.options.num_threads); let batches_res: Vec<_> = independent_near_fields .into_iter() .map(|batch| { let mut lu_batch: CommutativeFactors = CommutativeFactorsOperations::new(); let mut lu_batch_time = LuTimes::new(); - let lu_times_and_factor: Vec<_> = batch - .par_iter() - .map(|box_num| { - let skel_box = ::default(); - let box_ind = current_box_indices[*box_num]; - let min_num_samples = oversample( - self.target_inds[box_ind].len() + level_near_field_inds[*box_num].len(), - self.options.sketching.oversampling, - ); - let (lu_factor, lu_times) = >::lu_step( - &skel_box, - &self.y_data, - &self.z_data, - &mut level_ind_r[*box_num].clone(), - &mut level_near_field_inds[*box_num].clone(), - min_num_samples, - &self.options, - ); - (lu_times, lu_factor) - }) - .collect(); - lu_times_and_factor + let lu_times_and_factors: Vec<(Times, LuFactor, Vec)> = thread_pool + .install(|| { + batch + .par_iter() + .map_init(ExtractionScratch::::new, |scratch, box_num| { + let skel_box = ::default(); + let box_ind = current_box_indices[*box_num]; + let per_box_samples = self + .anticipated_fixed_rank_samples + .as_ref() + .and_then(|budget| { + budget.per_level_box_samples( + current_level, + ¤t_level_keys[box_ind], + ) + }); + let min_num_samples = local_sample_count( + per_box_samples, + level_near_field_inds[*box_num].len(), + self.active_samples, + self.anticipated_fixed_rank_samples.is_some(), + self.options.sketching.fixed_rank_sampling_mode, + &self.options, + ); + >::lu_step( + &skel_box, + scratch, + &self.y_data, + &self.z_data, + &level_ind_r[*box_num].clone(), + &level_near_field_inds[*box_num].clone(), + &inactive_inds, + min_num_samples, + &self.options, + ) + .map(|(lu_factor, lu_times)| { + (lu_times, lu_factor, level_ind_r[*box_num].clone()) + }) + }) + .flatten() + .collect() + }); + + lu_times_and_factors .into_iter() - .for_each(|(lu_time, lu_factor)| { + .for_each(|(lu_time, lu_factor, r_inds)| { lu_batch.add_factor(Factor::Lu(lu_factor)); + inactive_inds.extend_from_slice(&r_inds); match lu_time { Times::Lu(lu_times) => { lu_batch_time.sum(lu_times.lu, lu_times.extraction) @@ -908,7 +2003,6 @@ where .collect(); update_times.sum(0_u128, update_parallel_batch_time); - self.stats.lu_times.push(lu_times); self.stats.update_times.push(update_times); let lu_step_duration = lu_step_start.elapsed().as_millis() - update_parallel_batch_time; @@ -922,54 +2016,160 @@ where fn extract_step(&self) -> (CommutativeFactors, Vec, Vec) { let rows: Vec = (0..self.y_data.dim).collect(); - let mut acc_ind_s = Vec::new(); - let mut acc_ind_r = Vec::new(); - for inds in self.ind_s.iter() { + let mut acc_ind_s = Vec::with_capacity(self.ind_s.iter().map(Vec::len).sum()); + let mut acc_ind_r = Vec::with_capacity(self.ind_r.iter().map(Vec::len).sum()); + + let diag_box_count = self.ind_r.iter().filter(|inds| !inds.is_empty()).count(); + let max_diag_box_size = self.ind_r.iter().map(Vec::len).max().unwrap_or(0); + let total_diag_rows: usize = self.ind_r.iter().map(Vec::len).sum(); + let avg_diag_box_size = if diag_box_count == 0 { + 0.0 + } else { + total_diag_rows as f64 / diag_box_count as f64 + }; + + for inds in &self.ind_s { acc_ind_s.extend_from_slice(inds); } - for inds in self.ind_r.iter() { + for inds in &self.ind_r { acc_ind_r.extend_from_slice(inds); } - let mut cols = acc_ind_r; + let mut cols = Vec::with_capacity(acc_ind_r.len() + acc_ind_s.len() + self.y_data.dim); + cols.extend_from_slice(&acc_ind_r); cols.extend_from_slice(&acc_ind_s); + let mut seen = vec![false; self.y_data.dim]; + for &c in &cols { + seen[c] = true; + } + let remaining_indices = rows - .clone() - .into_iter() - .filter(|&el| !cols.contains(&el)) + .iter() + .copied() + .filter(|&el| !seen[el]) .collect::>(); + cols.extend_from_slice(&remaining_indices); - let mut diag_box_factors: CommutativeFactors = CommutativeFactorsOperations::new(); - let mut diag_box_res: Vec<_> = self - .ind_r - .par_iter() - .map(|inds| { - DiagBoxFactor::new( - &mut inds.to_vec(), + let mut diag_box_factors = CommutativeFactors::new(); + let fixed_rank = self.anticipated_fixed_rank_samples.is_some(); + let chunk_size = self.options.num_threads.max(1); + + println!( + "[diag] boxes={} avg_box_size={avg_diag_box_size:.2} max_box_size={} chunk_size={} fixed_rank={} method={:?}", + diag_box_count, + max_diag_box_size, + chunk_size, + fixed_rank, + self.options.extract_db_options.block_extraction_method + ); + + let extraction_start = Instant::now(); + + for inds in &self.ind_r { + if inds.is_empty() { + continue; + } + + let mut diag_scratch = DiagExtractionScratch::::new(); + let factor = if self.options.symmetry.complex_symmetric_val::() { + DiagBoxFactor::new_complex_symm_with_scratch( + inds.to_vec(), &self.y_data, self.active_samples, - &self.options, + fixed_rank, + &self.options.extract_db_options, + &mut diag_scratch, ) - }) - .collect(); + } else if self.options.symmetry.symm_val() { + DiagBoxFactor::new_symm_with_scratch( + inds.to_vec(), + &self.y_data, + self.active_samples, + fixed_rank, + &self.options.extract_db_options, + &mut diag_scratch, + matches!(self.options.symmetry, Symmetry::Hermitian), + ) + } else { + DiagBoxFactor::new_no_symm_with_scratch( + inds.to_vec(), + &self.y_data, + &self.z_data, + self.active_samples, + fixed_rank, + &self.options.extract_db_options, + &mut diag_scratch, + ) + }; + diag_box_factors.add_factor(Factor::Diag(factor.0.unwrap())); + } + + let mut diag_scratch = DiagExtractionScratch::::new(); + let skeleton_diag = if self.options.symmetry.complex_symmetric_val::() { + DiagBoxFactor::new_complex_symm_with_scratch( + acc_ind_s.to_vec(), + &self.y_data, + self.active_samples, + fixed_rank, + &self.options.extract_db_options, + &mut diag_scratch, + ) + } else if self.options.symmetry.symm_val() { + DiagBoxFactor::new_symm_with_scratch( + acc_ind_s.to_vec(), + &self.y_data, + self.active_samples, + fixed_rank, + &self.options.extract_db_options, + &mut diag_scratch, + matches!(self.options.symmetry, Symmetry::Hermitian), + ) + } else { + DiagBoxFactor::new_no_symm_with_scratch( + acc_ind_s.to_vec(), + &self.y_data, + &self.z_data, + self.active_samples, + fixed_rank, + &self.options.extract_db_options, + &mut diag_scratch, + ) + }; - diag_box_res.push(DiagBoxFactor::new( - &mut acc_ind_s.to_vec(), - &self.y_data, - self.active_samples, - &self.options, - )); + /*if self.options.symmetry.symm_val() { + diag_box_res.push(DiagBoxFactor::new( + &mut acc_ind_s.to_vec(), + &self.y_data, + self.active_samples, + &self.options.extract_db_options, + )); + } else { + diag_box_res.push(DiagBoxFactor::new_no_symm( + &mut acc_ind_s.to_vec(), + &self.y_data, + &self.z_data, + self.active_samples, + &self.options.extract_db_options, + )); + }*/ - diag_box_res.into_iter().for_each(|(dbres, _dbtime)| { + println!( + "Extraction time: {:.3} ms ({:?})", + extraction_start.elapsed().as_secs_f64() * 1.0e3, + extraction_start.elapsed() + ); + + std::iter::once(skeleton_diag).for_each(|(dbres, _dbtime)| { diag_box_factors.add_factor(Factor::Diag(dbres.unwrap())); }); (diag_box_factors, cols, rows) } + fn get_near_indices(&mut self, box_ind: usize) -> Vec { let mut near_indices = Vec::new(); for ind in self.near_inds[box_ind].iter() { @@ -1065,12 +2265,9 @@ where // Step 9: Debug / info output let boxes_lengths: Vec<_> = self.ind_s.iter().map(Vec::len).collect(); let active_indices = boxes_lengths.iter().sum::(); + let active_boxes = self.ind_s.iter().filter(|v| !v.is_empty()).count(); - println!( - "New {} boxes, with {} active indices.", - self.ind_s.len(), - active_indices, - ); + println!("New {active_boxes} active boxes of a total of {num_boxes}, and active indices: {active_indices}"); if self.stats.limiting_factors.limiting_level.active_points < active_indices { self.stats.limiting_factors.limiting_level.level = @@ -1123,39 +2320,58 @@ where // Print box info let total_active: usize = self.target_inds.iter().map(Vec::len).sum(); - println!("New {num_boxes} boxes, and active indices: {total_active}"); + let active_boxes = self.ind_s.iter().filter(|v| !v.is_empty()).count(); + println!("New {active_boxes} active boxes of a total of {num_boxes}, and active indices: {total_active}"); self.stats.limiting_factors.max_level = self.level_indexing.current_level; } } fn group_near_fields(&mut self, current_box_indices: &[usize]) -> Vec> { - // Get the next level's keys and the current level's keys - let num_indices = current_box_indices.len(); + + let is_occupied = |b: usize| !self.ind_s[b].is_empty(); + let current_set: FxHashSet = current_box_indices.iter().copied().collect(); let mut group_contents: Vec> = Vec::with_capacity(num_indices); let mut group_indices: Vec> = Vec::with_capacity(num_indices); - let inds = (0..num_indices).collect::>(); // optional: sort here by neighbor size - - 'outer: for ind in inds { - let current_neighbors = &self.near_inds[current_box_indices[ind]]; + 'outer: for (ind, &g) in current_box_indices.iter().enumerate().take(num_indices) { + if !is_occupied(g) { + continue; + } for (group_set, group) in group_contents.iter_mut().zip(group_indices.iter_mut()) { - let has_overlap = current_neighbors.iter().any(|x| group_set.contains(x)); - if !has_overlap { - group_set.extend(current_neighbors.iter().copied()); + let conflict = self.near_inds[g] + .iter() + .copied() + .filter(|&n| current_set.contains(&n) && is_occupied(n)) + .any(|n| group_set.contains(&n)); + + if !conflict { + group_set.insert(g); + group_set.extend( + self.near_inds[g] + .iter() + .copied() + .filter(|&n| current_set.contains(&n) && is_occupied(n)), + ); group.push(ind); continue 'outer; } } - // No compatible group found, create a new one let mut new_set = FxHashSet::default(); - new_set.extend(current_neighbors.iter().copied()); + new_set.insert(g); + new_set.extend( + self.near_inds[g] + .iter() + .copied() + .filter(|&n| current_set.contains(&n) && is_occupied(n)), + ); group_contents.push(new_set); group_indices.push(vec![ind]); } + group_indices } } @@ -1228,3 +2444,41 @@ fn pick_ranks( RankPicking::Tol => None, } } + +#[cfg(test)] +mod tests { + use super::{fixed_rank_skeleton_upper_size, FixedRankSampleBudget}; + use std::collections::HashMap; + + #[test] + fn fixed_rank_floor_override_uses_constant_budget() { + let budget = FixedRankSampleBudget::PerLevel { + total_samples: 315, + active_samples_by_level: HashMap::from([(3, 105), (2, 315)]), + box_samples_by_level: HashMap::new(), + }; + + let floored = budget.with_global_min_samples(400); + + assert_eq!(floored, FixedRankSampleBudget::Constant { samples: 400 }); + } + + #[test] + fn fixed_rank_floor_override_preserves_budget_when_floor_is_lower() { + let budget = FixedRankSampleBudget::PerLevel { + total_samples: 315, + active_samples_by_level: HashMap::from([(3, 105), (2, 315)]), + box_samples_by_level: HashMap::new(), + }; + + let floored = budget.clone().with_global_min_samples(200); + + assert_eq!(floored, budget); + } + + #[test] + fn fixed_rank_estimate_caps_only_current_level_boxes() { + assert_eq!(fixed_rank_skeleton_upper_size(3, 3, 1, 118, 20), 20); + assert_eq!(fixed_rank_skeleton_upper_size(2, 3, 1, 118, 20), 118); + } +} diff --git a/src/rsrs/rsrs_factors.rs b/src/rsrs/rsrs_factors.rs deleted file mode 100644 index 9321756..0000000 --- a/src/rsrs/rsrs_factors.rs +++ /dev/null @@ -1,3023 +0,0 @@ -use super::{ - rsrs_cycle::{BoxType, ExtractOptions, RsrsOptions}, - sketch::SketchData, -}; -use crate::{ - rsrs::sketch::SamplingSpace, - utils::{ - data_ins_ext::{ExtInsType, Extraction, MatrixExtraction}, - elementary_matrix::{col_perm, col_subs, ext_cols, ext_rows, row_perm, row_subs}, - least_squares_and_null::{block_extraction, nullify_near_sketch}, - }, -}; -use itertools::min; -use mpi::{ - topology::SimpleCommunicator, - traits::{Communicator, Equivalence}, -}; -use rand_distr::{Distribution, Standard, StandardNormal}; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use rlst::{ - dense::{ - linalg::{ - interpolative_decomposition::Accuracy, lu::MatrixLu, - triangular_arrays::TriangularOperations, - }, - tools::RandScalar, - }, - prelude::*, -}; -use serde::{Deserialize, Serialize}; -use std::{ - collections::{HashMap, HashSet}, - rc::Rc, - time::{Duration, Instant}, -}; -type Real = ::Real; - -#[derive(Clone)] -pub struct MulOptions { - /// Inverse operation - pub inv: bool, - /// Transpose operation - pub trans: bool, - pub side: Side, - pub factor_type: FactorType, - pub t_trans: bool, -} - -pub enum OpInfo { - DecFact( - DynamicArray, - DynamicArray, - Vec, - Vec, - ), - DiagBlocks(Vec>), - Perm(Vec, Vec), -} - -#[derive(PartialEq)] -pub enum RsrsSide { - Squeeze, - Left, - Right, -} - -#[derive(Clone)] -pub enum FactorType { - F, - S, -} - -pub enum SquareArr { - Reg(RegDBox), - Lu(LuDBox), -} - -impl SquareArr -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - fn left_mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - factor_options: &MulOptions, - ) { - match self { - SquareArr::Reg(ref reg) => { - let mut new_right_arr = empty_array(); - let trans_mode = if factor_options.trans { - TransMode::ConjTrans - } else { - TransMode::NoTrans - }; - if factor_options.inv { - new_right_arr.r_mut().mult_into_resize( - trans_mode, - TransMode::NoTrans, - num::One::one(), - reg.inv_arr.r(), - right_arr.r(), - num::Zero::zero(), - ); - } else { - new_right_arr.r_mut().mult_into_resize( - trans_mode, - TransMode::NoTrans, - num::One::one(), - reg.arr.r(), - right_arr.r(), - num::Zero::zero(), - ); - } - right_arr.r_mut().fill_from(new_right_arr.r()); - } - SquareArr::Lu(ref lu) => { - if factor_options.inv { - match factor_options.trans { - false => { - lu.perm.left_mul(right_arr, factor_options); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - } - true => { - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - lu.perm.left_mul(right_arr, factor_options); - } - } - } else { - match factor_options.trans { - false => { - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - lu.perm.left_mul(right_arr, factor_options); - } - true => { - lu.perm.left_mul(right_arr, factor_options); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - } - } - } - } - } - } - - fn right_mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - factor_options: &MulOptions, - ) { - match self { - SquareArr::Reg(ref reg) => { - let mut new_right_arr = empty_array(); - let trans_mode = if factor_options.trans { - TransMode::ConjTrans - } else { - TransMode::NoTrans - }; - - if factor_options.inv { - new_right_arr.r_mut().mult_into_resize( - TransMode::NoTrans, - trans_mode, - num::One::one(), - right_arr.r(), - reg.inv_arr.r(), - num::Zero::zero(), - ); - } else { - new_right_arr.r_mut().mult_into_resize( - TransMode::NoTrans, - trans_mode, - num::One::one(), - right_arr.r(), - reg.arr.r(), - num::Zero::zero(), - ); - } - right_arr.r_mut().fill_from(new_right_arr.r()); - } - SquareArr::Lu(ref lu) => { - if factor_options.inv { - match factor_options.trans { - false => { - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - - lu.perm.right_mul(right_arr, factor_options); - } - true => { - lu.perm.right_mul(right_arr, factor_options); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Right, - TransMode::ConjTrans, - ); - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Right, - TransMode::ConjTrans, - ); - } - } - } else { - match factor_options.trans { - false => { - lu.perm.right_mul(right_arr, factor_options); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - } - true => { - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - lu.perm.left_mul(right_arr, factor_options); - } - } - } - } - } - } - - pub fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - side: Side, - factor_options: &MulOptions, - ) { - match side { - Side::Left => self.left_mul(right_arr, factor_options), - Side::Right => self.right_mul(right_arr, factor_options), - } - } - - fn cond(&self) -> (Real, Real) { - match self { - SquareArr::Reg(reg_dbox) => (condition_number(®_dbox.arr), num::Zero::zero()), - SquareArr::Lu(lu_dbox) => ( - condition_number(&lu_dbox.l_arr.tri), - condition_number(&lu_dbox.u_arr.tri), - ), - } - } -} - -pub struct ComposedFactorData { - sq: SquareArr, - rectg: DynamicArray, -} - -impl ComposedFactorData -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - target_arr: &Array, - factor_options: &MulOptions, - ) -> DynamicArray { - let mut res_mul: DynamicArray = empty_array::(); - let mut sq_factor_options = factor_options.clone(); - sq_factor_options.inv = true; - match factor_options.side { - Side::Left => { - if !factor_options.trans { - res_mul.r_mut().mult_into_resize( - TransMode::NoTrans, - TransMode::NoTrans, - num::One::one(), - self.rectg.r(), - target_arr.r(), - num::Zero::zero(), - ); - self.sq - .mul(&mut res_mul, factor_options.side, &sq_factor_options); - } else { - let mut aux_target_arr = empty_array(); - aux_target_arr.r_mut().fill_from_resize(target_arr.r()); - self.sq.mul( - &mut aux_target_arr.r_mut(), - factor_options.side, - &sq_factor_options, - ); - res_mul.r_mut().mult_into_resize( - TransMode::ConjTrans, - TransMode::NoTrans, - num::One::one(), - self.rectg.r(), - aux_target_arr.r(), - num::Zero::zero(), - ); - } - } - Side::Right => { - if !factor_options.trans { - let mut aux_target_arr = empty_array(); - aux_target_arr.r_mut().fill_from_resize(target_arr.r()); - self.sq.mul( - &mut aux_target_arr.r_mut(), - factor_options.side, - &sq_factor_options, - ); - res_mul.r_mut().mult_into_resize( - TransMode::NoTrans, - TransMode::NoTrans, - num::One::one(), - aux_target_arr.r(), - self.rectg.r(), - num::Zero::zero(), - ); - } else { - res_mul.r_mut().mult_into_resize( - TransMode::NoTrans, - TransMode::ConjTrans, - num::One::one(), - target_arr.r(), - self.rectg.r(), - num::Zero::zero(), - ); - self.sq - .mul(&mut res_mul, factor_options.side, &sq_factor_options); - } - } - } - - res_mul - } - - #[allow(clippy::type_complexity)] - fn cond(&self) -> (Real, Option<(Real, Real)>) { - (condition_number(&self.rectg), Some(self.sq.cond())) - } -} - -pub enum FactorData { - Comp(ComposedFactorData), - Reg(DynamicArray), -} - -impl FactorData -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - fn mul< - ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &Array, - factor_options: &MulOptions, - c_indices: &[usize], - r_indices: &[usize], - ) -> DynamicArray { - match factor_options.side { - Side::Left => { - let row_indices: Vec; - let col_indices: Vec; - - if factor_options.trans { - col_indices = r_indices.to_vec(); - row_indices = c_indices.to_vec(); - } else { - col_indices = c_indices.to_vec(); - row_indices = r_indices.to_vec(); - } - - let (axis, transposed) = if factor_options.t_trans { - (1, true) - } else { - (0, false) - }; - - let mut subarr_rows: DynamicArray = - as MatrixExtraction>::new( - right_arr, - ExtInsType::Axis(row_indices.clone(), axis, transposed), - ) - .unwrap() - .ext; - - let mut subarr_cols: DynamicArray = - as MatrixExtraction>::new( - right_arr, - ExtInsType::Axis(col_indices.clone(), axis, transposed), - ) - .unwrap() - .ext; - - let res_mul = match self { - FactorData::Comp(composed_factor_data) => { - composed_factor_data.mul(&subarr_cols, factor_options) - } - FactorData::Reg(array) => { - let mut res_mul: DynamicArray = empty_array::(); - if factor_options.trans { - res_mul.r_mut().mult_into_resize( - TransMode::ConjTrans, - TransMode::NoTrans, - num::One::one(), - array.r(), - subarr_cols.r_mut(), - num::Zero::zero(), - ); - } else { - res_mul.r_mut().mult_into_resize( - TransMode::NoTrans, - TransMode::NoTrans, - num::One::one(), - array.r(), - subarr_cols.r_mut(), - num::Zero::zero(), - ); - } - res_mul - } - }; - - if factor_options.inv { - subarr_rows.sub_into(res_mul.r()); - } else { - subarr_rows.sum_into(res_mul.r()); - } - - subarr_rows - } - Side::Right => { - let row_indices: Vec; - let col_indices: Vec; - - if factor_options.trans { - col_indices = r_indices.to_vec(); - row_indices = c_indices.to_vec(); - } else { - col_indices = c_indices.to_vec(); - row_indices = r_indices.to_vec(); - } - - let (axis, transposed) = if factor_options.t_trans { - (0, true) - } else { - (1, false) - }; - - let mut subarr_rows: DynamicArray = - as MatrixExtraction>::new( - right_arr, - ExtInsType::Axis(row_indices.clone(), axis, transposed), - ) - .unwrap() - .ext; - - let mut subarr_cols: DynamicArray = - as MatrixExtraction>::new( - right_arr, - ExtInsType::Axis(col_indices.clone(), axis, transposed), - ) - .unwrap() - .ext; - - let res_mul = match self { - FactorData::Comp(composed_factor_data) => { - composed_factor_data.mul(&subarr_rows, factor_options) - } - FactorData::Reg(array) => { - let mut res_mul: DynamicArray = empty_array::(); - if factor_options.trans { - res_mul.r_mut().mult_into_resize( - TransMode::NoTrans, - TransMode::ConjTrans, - num::One::one(), - subarr_rows.r_mut(), - array.r(), - num::Zero::zero(), - ); - } else { - res_mul.r_mut().mult_into_resize( - TransMode::NoTrans, - TransMode::NoTrans, - num::One::one(), - subarr_rows.r_mut(), - array.r(), - num::Zero::zero(), - ); - } - res_mul - } - }; - - if factor_options.inv { - subarr_cols.sub_into(res_mul.r()); - } else { - subarr_cols.sum_into(res_mul.r()); - } - - subarr_cols - } - } - } - - fn cond(&self) -> CondType { - match self { - FactorData::Comp(composed_factor_data) => composed_factor_data.cond(), - FactorData::Reg(array) => (condition_number(array), None), - } - } -} - -type CondType = (Real, Option<(Real, Real)>); -pub struct IdFactor { - data: FactorData, - pub perm: Vec, - pub ind_r: Vec, //row_indices - pub ind_s: Vec, //col_indices - pub ind_f: Vec, -} - -pub struct LuFactor { - l_arr: FactorData, - u_arr: FactorData, - hermitian: bool, - pub ind_r: Vec, //cols - pub ind_t: Vec, //rows -} - -pub struct LuDBox { - pub u_arr: TriangularMatrix, - pub l_arr: TriangularMatrix, - pub perm: PermFactor, -} - -pub struct RegDBox { - pub arr: DynamicArray, - pub inv_arr: DynamicArray, -} - -pub enum DiagBoxType { - Reg(RegDBox), - Lu(LuDBox), -} - -type DiagBoxArr = DiagBoxType; -pub struct PermFactor { - pub orig_indices: Vec, - pub perm_indices: Vec, -} - -pub struct DiagBoxFactor { - pub arr: DiagBoxArr, - pub inds: Vec, -} - -pub enum Factor { - Lu(LuFactor), - Id(IdFactor), - Diag(DiagBoxFactor), -} - -pub struct RsrsFactors { - pub num_levels: usize, - pub id_factors: LevelIdFactors, - pub lu_factors: LevelLuFactors, - pub near_field_inds: LevelNearFieldInds, - pub perm_factor: PermFactor, - pub diag_box_factors: DiagBoxFactors, - pub dim: usize, -} - -pub struct RsrsMulType { - pub side: RsrsSide, - pub factor_type: FactorType, - pub t_trans: bool, -} - -type DiagBoxFactors = CommutativeFactors; -type LevelLuFactors = Vec>>; -type LevelIdFactors = Vec>; -type LevelNearFieldInds = Vec>>; -pub type CommutativeFactors = Vec>; - -#[derive(Debug, Serialize, Clone)] -pub struct LuTimes { - pub extraction: u128, - pub lu: u128, -} - -#[derive(Debug, Serialize, Clone)] -pub struct IdTimes { - pub nullification: u128, - pub id: u128, -} - -pub enum Times { - Lu(LuTimes), - Id(IdTimes), -} - -pub fn condition_number(mat: &DynamicArray) -> Real { - let shape = mat.shape(); - let dim: usize = min(shape).unwrap(); - let mut singular_values: DynamicArray, 1> = rlst_dynamic_array1!(Real, [dim]); - let mode: SvdMode = SvdMode::Reduced; - let mut u: DynamicArray = rlst_dynamic_array2!(Item, [shape[0], dim]); - let mut vt: DynamicArray = rlst_dynamic_array2!(Item, [dim, shape[1]]); - - let mut aux_data = empty_array(); - aux_data.fill_from_resize(mat.r()); - - aux_data - .r_mut() - .into_svd_alloc(u.r_mut(), vt.r_mut(), singular_values.data_mut(), mode) - .unwrap(); - - let sigma_max = singular_values[[0]]; - let sigma_min = singular_values[[dim - 1]]; - - sigma_max / sigma_min -} - -fn get_far_indices(n: usize, near_indices: Vec) -> Vec { - let near_set: HashSet = near_indices.into_iter().collect(); - (0..n).filter(|x| !near_set.contains(x)).collect() -} - -fn null_sketch_near_field< - Item: RlstScalar + MatrixId + MatrixInverse + MatrixPseudoInverse + RandScalar + MatrixLu + MatrixQr, ->( - target_inds: &[usize], - near_field_inds: &[usize], - sketch: &DynamicArray, - test: &DynamicArray, - subs_sample_dim: usize, - rsrs_options: &RsrsOptions, -) -> DynamicArray -where - StandardNormal: Distribution, - Standard: Distribution, - QrDecomposition, 2>>: - MatrixQrDecomposition, - LuDecomposition, 2>>: - MatrixLuDecomposition, -{ - let dim = test.shape()[1]; - let sub_test = test.r().into_subview([0, 0], [subs_sample_dim, dim]); - let sub_sketch = sketch.r().into_subview([0, 0], [subs_sample_dim, dim]); - let mut sketch_t = as MatrixExtraction>::new( - &sub_sketch, - ExtInsType::Axis(target_inds.to_vec(), 1, false), - ) - .unwrap() - .ext; - let test_n = as MatrixExtraction>::new( - &sub_test, - ExtInsType::Axis(near_field_inds.to_vec(), 1, false), - ) - .unwrap() - .ext; - nullify_near_sketch(&test_n, &mut sketch_t, &rsrs_options.id_options); - sketch_t -} - -fn null_near_field< - Item: RlstScalar + MatrixId + MatrixInverse + MatrixPseudoInverse + RandScalar + MatrixLu + MatrixQr, ->( - target_inds: &[usize], - near_field_inds: &[usize], - y_data: &SketchData, - z_data: &SketchData, - subs_sample_dim: usize, - rsrs_options: &RsrsOptions, -) -> DynamicArray -where - StandardNormal: Distribution, - Standard: Distribution, - QrDecomposition, 2>>: - MatrixQrDecomposition, - LuDecomposition, 2>>: - MatrixLuDecomposition, -{ - let far_field_sketch = if rsrs_options.hermitian { - null_sketch_near_field( - target_inds, - near_field_inds, - &y_data.sketch, - &y_data.test, - subs_sample_dim, - rsrs_options, - ) - } else { - let null_y_sketch = null_sketch_near_field( - target_inds, - near_field_inds, - &y_data.sketch, - &y_data.test, - subs_sample_dim, - rsrs_options, - ); - let null_z_sketch = null_sketch_near_field( - target_inds, - near_field_inds, - &z_data.sketch, - &z_data.test, - subs_sample_dim, - rsrs_options, - ); - let mut sketch_sum = empty_array(); - sketch_sum.fill_from_resize(null_y_sketch.r() + null_z_sketch.r()); - sketch_sum - }; - - far_field_sketch -} - -fn near_box_extraction( - ind_r: &[usize], - near_field_inds: &[usize], - sketch_data: &SketchData, - subs_sample_dim: usize, - lu_options: &ExtractOptions, - r_numbering: &[usize], - t_numbering: &[usize], -) -> ( - DynamicArray, - DynamicArray, - (Duration, Duration), -) -where - LuDecomposition, 2>>: - MatrixLuDecomposition, -{ - let dim = sketch_data.test.shape()[1]; - let test_subview = sketch_data - .test - .r() - .into_subview([0, 0], [subs_sample_dim, dim]); - let sketch_subview = sketch_data - .sketch - .r() - .into_subview([0, 0], [subs_sample_dim, dim]); - let start = Instant::now(); - let sketch_r: DynamicArray = as MatrixExtraction>::new( - &sketch_subview, - ExtInsType::Axis(ind_r.to_vec(), 1, false), - ) - .unwrap() - .ext; - let mut test_n: DynamicArray = as MatrixExtraction>::new( - &test_subview, - ExtInsType::Axis(near_field_inds.to_vec(), 1, false), - ) - .unwrap() - .ext; - - let mut lu_io_time = start.elapsed(); - let start = Instant::now(); - let near_box = block_extraction(&mut test_n, &sketch_r, lu_options); - let lu_b_ext_time = start.elapsed(); - let start = Instant::now(); - let data_r = as MatrixExtraction>::new( - &near_box, - ExtInsType::Axis(r_numbering.to_vec(), 0, false), - ) - .unwrap() - .ext; - let data_n = as MatrixExtraction>::new( - &near_box, - ExtInsType::Axis(t_numbering.to_vec(), 0, false), - ) - .unwrap() - .ext; - let lu_small_io_time = start.elapsed(); - lu_io_time += lu_small_io_time; - (data_r, data_n, (lu_io_time, lu_b_ext_time)) -} - -pub trait FactorOperations: Sized { - type Item: RlstScalar; - - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &mut Array, - options: &MulOptions, - ); - - fn mul_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &Array, - options: &MulOptions, - ) -> DynamicArray; - - fn ins_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - source_arr: &DynamicArray, - target_arr: &mut Array, - options: &MulOptions, - ); - - //fn cond(&self) -> (Real, Real); -} - -impl< - Item: RlstScalar - + MatrixId - + MatrixInverse - + MatrixPseudoInverse - + RandScalar - + MatrixLu - + MatrixQr, - > IdFactor -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - pub fn new( - target_inds: &mut [usize], - near_field_inds: &mut [usize], - y_data: &SketchData, - z_data: &SketchData, - subs_sample_dim: usize, - rank_par: &BoxType>, - options: &RsrsOptions, - ) -> (Option, Times) - where - StandardNormal: Distribution, - Standard: Distribution, - LuDecomposition, 2>>: - MatrixLuDecomposition, - QrDecomposition, 2>>: - MatrixQrDecomposition, - { - let start: Instant = Instant::now(); - let test_shape = [subs_sample_dim, near_field_inds.len()]; - let sketch_shape = [subs_sample_dim, target_inds.len()]; - let null_shape = [test_shape[0] - test_shape[1], sketch_shape[1]]; - - let far_field_sketch = null_near_field( - target_inds, - near_field_inds, - y_data, - z_data, - subs_sample_dim, - options, - ); - - let nullification_time: Duration = start.elapsed(); - let start: Instant = Instant::now(); - let max_rank: usize = *far_field_sketch.shape().iter().min().unwrap(); - let id_sketch = match rank_par { - BoxType::Full(tol) => { - if *tol < num::One::one() { - far_field_sketch - .into_subview([0, 0], null_shape) - .into_id_alloc(Accuracy::Tol(*tol), TransMode::Trans) - .unwrap() - } else { - far_field_sketch - .into_subview([0, 0], null_shape) - .into_id_alloc( - Accuracy::FixedRank(num::ToPrimitive::to_usize(tol).unwrap()), - TransMode::Trans, - ) - .unwrap() - } - } - BoxType::Merged(rank) => far_field_sketch - .into_subview([0, 0], null_shape) - .into_id_alloc(Accuracy::FixedRank(*rank), TransMode::Trans) - .unwrap(), - }; - let k: usize = id_sketch.rank; - let mut ind_r = Vec::new(); - let mut ind_s = Vec::new(); - let id_time = start.elapsed(); - - let id_times = IdTimes { - nullification: nullification_time.as_millis(), - id: id_time.as_millis(), - }; - - let times = Times::Id(id_times); - - if id_sketch.rank < max_rank { - let aux_indices = target_inds.to_vec(); - - for (id, &elem) in id_sketch.perm.iter().enumerate() { - let val = aux_indices[elem]; - target_inds[id] = val; - near_field_inds[id] = val; - } - - ind_r.extend_from_slice(&target_inds[k..]); - ind_s.extend_from_slice(&target_inds[..k]); - let ind_f = get_far_indices(y_data.dim, near_field_inds.to_vec()); - - ( - Some(Self { - data: FactorData::Reg(id_sketch.id_mat), - perm: id_sketch.perm, - ind_r, - ind_s, - ind_f, - }), - times, - ) - } else { - (None, times) - } - } - - pub fn cond(&self) -> (CondType, Option>) { - (self.data.cond(), None) - } -} - -impl< - Item: RlstScalar - + MatrixId - + MatrixInverse - + MatrixPseudoInverse - + RandScalar - + MatrixLu - + MatrixQr, - > FactorOperations for IdFactor -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - type Item = Item; - - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - ) { - let target_block = self.mul_data(target_arr, factor_options); - let t_arr_mutex = std::sync::Mutex::new(target_arr); - self.ins_data(&target_block, *t_arr_mutex.lock().unwrap(), factor_options); - } - - fn mul_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &Array, - options: &MulOptions, - ) -> DynamicArray { - let mut trans = options.trans; - - match options.factor_type { - FactorType::F => {} - FactorType::S => { - trans = !trans; - } - } - - let mut aux_options = options.clone(); - aux_options.trans = trans; - - self.data - .mul(target_arr, &aux_options, &self.ind_s, &self.ind_r) - } - - fn ins_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - source_arr: &DynamicArray, - target_arr: &mut Array, - options: &MulOptions, - ) { - let mut trans = options.trans; - - match options.factor_type { - FactorType::F => {} - FactorType::S => { - trans = !trans; - } - } - - match options.side { - Side::Left => { - row_subs( - self.ind_s.clone(), - self.ind_r.clone(), - source_arr, - target_arr, - trans, - options.t_trans, - ); - } - Side::Right => { - col_subs( - self.ind_s.clone(), - self.ind_r.clone(), - source_arr, - target_arr, - trans, - options.t_trans, - ); - } - } - } -} - -#[derive(Debug, Clone, Deserialize)] -pub enum PivotMethod { - DirectInversion, - Lu, -} - -pub fn inv_diagonal(arr: &DynamicArray) -> DynamicArray { - let shape = arr.shape(); - let mut d_inv = rlst_dynamic_array2!(Item, shape); - let mut view_1 = d_inv.r_mut(); - let view_2 = arr.r(); - for i in 0..shape[0] { - view_1[[i, i]] = ::one() / view_2[[i, i]]; - } - d_inv -} - -impl LuFactor -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - pub fn new( - ind_r: &mut [usize], - near_field_inds: &mut [usize], - y_data: &SketchData, - z_data: &SketchData, - subs_sample_dim: usize, - options: &RsrsOptions, - ) -> (Option, Times) - where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, - { - let mut r_numbering: Vec = Vec::new(); - let mut t_numbering: Vec = Vec::new(); - let mut ind_t = Vec::new(); - let near_field_ind_to_num: HashMap<_, _> = near_field_inds - .iter() - .enumerate() - .map(|(num, ind)| (ind, num)) - .collect(); - - for &elem in ind_r.iter() { - r_numbering.push(*near_field_ind_to_num.get(&elem).unwrap()); - } - - for (pos, &elem) in near_field_inds.iter().enumerate() { - if !ind_r.contains(&elem) { - t_numbering.push(pos); - ind_t.push(elem); - } - } - let (y_r, y_n, (_y_lu_io_time, y_lu_b_ext_time)) = near_box_extraction( - ind_r, - near_field_inds, - y_data, - subs_sample_dim, - &options.lu_options, - &r_numbering, - &t_numbering, - ); - let start = Instant::now(); - - let u_arr = match options.lu_options.pivot_method { - PivotMethod::DirectInversion => { - let mut y_r_inv = empty_array(); - y_r_inv.r_mut().fill_from_resize(y_r.r().transpose()); - y_r_inv.r_mut().into_inverse_alloc().unwrap(); - let mut rectg = empty_array(); - rectg.fill_from_resize(y_n.transpose()); - - let sq = RegDBox { - arr: y_r, - inv_arr: y_r_inv, - }; - let factor = ComposedFactorData { - sq: SquareArr::Reg(sq), - rectg, - }; - FactorData::Comp(factor) - } - PivotMethod::Lu => { - let shape = y_r.shape(); - let mut y_r_trans = empty_array(); - y_r_trans.fill_from_resize(y_r.r().transpose()); - let lu: LuDecomposition, 2>> = - ::into_lu_alloc(y_r_trans).unwrap(); - let mut l = rlst_dynamic_array2!(Item, shape); - let mut u = rlst_dynamic_array2!(Item, shape); - as MatrixLuDecomposition>::get_l(&lu, l.r_mut()); - as MatrixLuDecomposition>::get_u(&lu, u.r_mut()); - - let perm = as MatrixLuDecomposition>::get_perm(&lu); - - let orig: Vec<_> = (0..shape[1]).collect(); - - let lu_arr = LuDBox { - l_arr: TriangularMatrix::new(&l, TriangularType::Lower).unwrap(), - u_arr: TriangularMatrix::new(&u, TriangularType::Upper).unwrap(), - perm: PermFactor::new(orig, perm).unwrap(), - }; - - let sq = SquareArr::Lu(lu_arr); - let mut rectg = empty_array(); - rectg.fill_from_resize(y_n.transpose()); - let factor = ComposedFactorData { sq, rectg }; - FactorData::Comp(factor) - } - }; - - let u_assembly = start.elapsed(); - - let lu_b_ext_time; - let lu_assembly_time; - - let l_arr = if !options.hermitian { - let (z_r, z_n, (_z_lu_io_time, z_lu_b_ext_time)) = near_box_extraction( - ind_r, - near_field_inds, - z_data, - subs_sample_dim, - &options.lu_options, - &r_numbering, - &t_numbering, - ); - - let start = Instant::now(); - - let l_arr = match options.lu_options.pivot_method { - PivotMethod::DirectInversion => { - let mut z_r_inv = empty_array(); - z_r_inv.r_mut().fill_from_resize(z_r.r()); - z_r_inv.r_mut().into_inverse_alloc().unwrap(); - let mut rectg = empty_array(); - rectg.fill_from_resize(z_n); - - let sq = RegDBox { - arr: z_r, - inv_arr: z_r_inv, - }; - - let factor = ComposedFactorData { - sq: SquareArr::Reg(sq), - rectg, - }; - FactorData::Comp(factor) - } - PivotMethod::Lu => { - let shape = z_r.shape(); - let mut inv_arr = empty_array(); - inv_arr.fill_from_resize(z_r.r()); - let lu = ::into_lu_alloc(inv_arr).unwrap(); - let mut l = rlst_dynamic_array2!(Item, shape); - let mut u = rlst_dynamic_array2!(Item, shape); - - as MatrixLuDecomposition>::get_l(&lu, l.r_mut()); - as MatrixLuDecomposition>::get_u(&lu, u.r_mut()); - let perm = as MatrixLuDecomposition>::get_perm(&lu); - - let orig: Vec<_> = (0..shape[1]).collect(); - - let lu_arr = LuDBox { - l_arr: TriangularMatrix::new(&l, TriangularType::Lower).unwrap(), - u_arr: TriangularMatrix::new(&u, TriangularType::Upper).unwrap(), - perm: PermFactor::new(orig, perm).unwrap(), - }; - let sq = SquareArr::Lu(lu_arr); - let factor = ComposedFactorData { sq, rectg: z_n }; - FactorData::Comp(factor) - } - }; - - let l_assembly = start.elapsed(); - lu_b_ext_time = y_lu_b_ext_time + z_lu_b_ext_time; - lu_assembly_time = u_assembly + l_assembly; - l_arr - } else { - lu_b_ext_time = y_lu_b_ext_time; - lu_assembly_time = u_assembly; - FactorData::Reg(empty_array()) - }; - - let lu_times = LuTimes { - extraction: lu_b_ext_time.as_millis(), - lu: lu_assembly_time.as_millis(), - }; - - let times = Times::Lu(lu_times); - - let hermitian = options.hermitian; - ( - Some(Self { - l_arr, - u_arr, - hermitian, - ind_r: ind_r.to_vec(), - ind_t, - }), - times, - ) - } - - pub fn cond(&self) -> (CondType, Option>) { - if !self.hermitian { - (self.l_arr.cond(), Some(self.u_arr.cond())) - } else { - (self.u_arr.cond(), None) - } - } -} - -impl FactorOperations - for LuFactor -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - type Item = Item; - - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - ) { - let target_block = self.mul_data(target_arr, factor_options); - - let t_arr_mutex = std::sync::Mutex::new(target_arr); - self.ins_data(&target_block, *t_arr_mutex.lock().unwrap(), factor_options); - } - - fn mul_data< - ArrayImpl: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &Array, - options: &MulOptions, - ) -> DynamicArray { - let mut trans = options.trans; - if self.hermitian { - match options.factor_type { - FactorType::F => { - trans = !trans; - } - FactorType::S => {} - } - - let mut aux_options = options.clone(); - aux_options.trans = trans; - - self.u_arr - .mul(target_arr, &aux_options, &self.ind_t, &self.ind_r) - } else { - match options.factor_type { - FactorType::F => self - .l_arr - .mul(target_arr, options, &self.ind_r, &self.ind_t), - FactorType::S => self - .u_arr - .mul(target_arr, options, &self.ind_t, &self.ind_r), - } - } - } - - fn ins_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - source_arr: &DynamicArray, - target_arr: &mut Array, - options: &MulOptions, - ) { - let mut trans = options.trans; - - if self.hermitian { - match options.factor_type { - FactorType::F => { - trans = !trans; - } - FactorType::S => {} - } - - match options.side { - Side::Left => row_subs( - self.ind_t.clone(), - self.ind_r.clone(), - source_arr, - target_arr, - trans, - options.t_trans, - ), - Side::Right => col_subs( - self.ind_t.clone(), - self.ind_r.clone(), - source_arr, - target_arr, - trans, - options.t_trans, - ), - }; - } else { - match options.factor_type { - FactorType::F => { - match options.side { - Side::Left => row_subs( - self.ind_r.clone(), - self.ind_t.clone(), - source_arr, - target_arr, - options.trans, - options.t_trans, - ), - Side::Right => col_subs( - self.ind_r.clone(), - self.ind_t.clone(), - source_arr, - target_arr, - options.trans, - options.t_trans, - ), - }; - } - FactorType::S => { - match options.side { - Side::Left => row_subs( - self.ind_t.clone(), - self.ind_r.clone(), - source_arr, - target_arr, - options.trans, - options.t_trans, - ), - Side::Right => col_subs( - self.ind_t.clone(), - self.ind_r.clone(), - source_arr, - target_arr, - options.trans, - options.t_trans, - ), - }; - } - } - } - } -} - -impl PermFactor { - fn new(orig_indices: Vec, perm_indices: Vec) -> RlstResult { - Ok(Self { - orig_indices, - perm_indices, - }) - } - - pub fn left_mul< - Item: RlstScalar, - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - options: &MulOptions, - ) { - let orig_indices: Vec<_> = (0..right_arr.shape()[0]).collect(); - assert_eq!(orig_indices.len(), self.perm_indices.len()); - let mut trans = options.trans; - if options.inv { - trans = !trans; - } - row_perm( - orig_indices.clone(), - self.perm_indices.clone(), - right_arr, - trans, - ); - } - - pub fn right_mul< - Item: RlstScalar, - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - left_arr: &mut Array, - options: &MulOptions, - ) { - let orig_indices: Vec<_> = (0..left_arr.shape()[1]).collect(); - assert_eq!(orig_indices.len(), self.perm_indices.len()); - let mut trans = !options.trans; - if options.inv { - trans = !trans; - } - col_perm( - orig_indices.clone(), - self.perm_indices.clone(), - left_arr, - trans, - ); - } -} - -fn _add_diagonal( - arr: &mut DynamicArray, - val: ::Real, -) { - let shape = arr.shape(); - let mut view = arr.r_mut(); - for i in 0..shape[0] { - view[[i, i]] += Item::from_real(val); - } -} - -impl DiagBoxArr -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - fn new< - ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + RawAccess - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - inds: &[usize], - db_ext_options: &ExtractOptions, - sub_test: &Array, - sub_sketch: &Array, - ) -> Self { - let sketch_r: DynamicArray = as MatrixExtraction>::new( - sub_sketch, - ExtInsType::Axis(inds.to_vec(), 1, false), - ) - .unwrap() - .ext; - let mut test_c: DynamicArray = as MatrixExtraction>::new( - sub_test, - ExtInsType::Axis(inds.to_vec(), 1, false), - ) - .unwrap() - .ext; - let diag_box = block_extraction(&mut test_c, &sketch_r, db_ext_options); - - match db_ext_options.pivot_method { - PivotMethod::DirectInversion => { - let mut inv_arr = empty_array(); - inv_arr.fill_from_resize(diag_box.r().transpose().conj()); - inv_arr.r_mut().into_inverse_alloc().unwrap(); - let reg_arr = RegDBox { - arr: diag_box, - inv_arr, - }; - DiagBoxType::Reg(reg_arr) - } - PivotMethod::Lu => { - let shape = diag_box.shape(); - let mut inv_arr = empty_array(); - inv_arr.fill_from_resize(diag_box.r().transpose().conj()); - let lu = ::into_lu_alloc(inv_arr).unwrap(); - let mut l = rlst_dynamic_array2!(Item, shape); - let mut u = rlst_dynamic_array2!(Item, shape); - - as MatrixLuDecomposition>::get_l(&lu, l.r_mut()); - as MatrixLuDecomposition>::get_u(&lu, u.r_mut()); - - let perm = as MatrixLuDecomposition>::get_perm(&lu); - - let orig: Vec<_> = (0..shape[1]).collect(); - - let lu_arr = LuDBox { - l_arr: TriangularMatrix::new(&l, TriangularType::Lower).unwrap(), - u_arr: TriangularMatrix::new(&u, TriangularType::Upper).unwrap(), - perm: PermFactor::new(orig, perm).unwrap(), - }; - DiagBoxType::Lu(lu_arr) - } - } - } - - fn left_mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - factor_options: &MulOptions, - ) { - match self { - DiagBoxType::Reg(ref reg) => { - let trans_mode = if factor_options.trans { - TransMode::ConjTrans - } else { - TransMode::NoTrans - }; - - let mut new_right_arr = empty_array(); - if factor_options.inv { - new_right_arr.r_mut().mult_into_resize( - trans_mode, - TransMode::NoTrans, - num::One::one(), - reg.inv_arr.r(), - right_arr.r(), - num::Zero::zero(), - ); - } else { - new_right_arr.r_mut().mult_into_resize( - trans_mode, - TransMode::NoTrans, - num::One::one(), - reg.arr.r(), - right_arr.r(), - num::Zero::zero(), - ); - } - right_arr.r_mut().fill_from(new_right_arr.r()); - } - DiagBoxType::Lu(ref lu) => { - if factor_options.inv { - if factor_options.trans { - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - lu.perm.left_mul(right_arr, factor_options); - } else { - lu.perm.left_mul(right_arr, factor_options); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - } - } else if factor_options.trans { - lu.perm.left_mul(right_arr, factor_options); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - } else { - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::NoTrans, - ); - lu.perm.left_mul(right_arr, factor_options); - } - } - } - } - - fn right_mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - factor_options: &MulOptions, - ) { - match self { - DiagBoxType::Reg(ref reg) => { - let trans_mode = if factor_options.trans { - TransMode::ConjTrans - } else { - TransMode::NoTrans - }; - - let mut new_right_arr = empty_array(); - if factor_options.inv { - new_right_arr.r_mut().mult_into_resize( - TransMode::NoTrans, - trans_mode, - num::One::one(), - right_arr.r(), - reg.inv_arr.r(), - num::Zero::zero(), - ); - } else { - new_right_arr.r_mut().mult_into_resize( - TransMode::NoTrans, - trans_mode, - num::One::one(), - right_arr.r(), - reg.arr.r(), - num::Zero::zero(), - ); - } - right_arr.r_mut().fill_from(new_right_arr.r()); - } - DiagBoxType::Lu(ref lu) => { - if factor_options.inv { - if factor_options.trans { - lu.perm.right_mul(right_arr, factor_options); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Right, - TransMode::ConjTrans, - ); - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Right, - TransMode::ConjTrans, - ); - } else { - as TriangularOperations>::solve( - &lu.u_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - as TriangularOperations>::solve( - &lu.l_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - - lu.perm.right_mul(right_arr, factor_options); - } - } else if factor_options.trans { - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Left, - TransMode::ConjTrans, - ); - lu.perm.left_mul(right_arr, factor_options); - } else { - lu.perm.right_mul(right_arr, factor_options); - as TriangularOperations>::mul( - &lu.l_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - as TriangularOperations>::mul( - &lu.u_arr, - right_arr, - Side::Right, - TransMode::NoTrans, - ); - } - } - } - } - - pub fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + Stride<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - right_arr: &mut Array, - side: Side, - factor_options: &MulOptions, - ) { - match side { - Side::Left => self.left_mul(right_arr, factor_options), - Side::Right => self.right_mul(right_arr, factor_options), - } - } -} - -impl DiagBoxFactor -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - pub fn new( - rows: &mut [usize], - y_data: &SketchData, - subs_sample_dim: usize, - options: &RsrsOptions, - ) -> (Option, Times) { - let (sub_test, sub_sketch) = ( - y_data - .test - .r() - .into_subview([0, 0], [subs_sample_dim, y_data.dim]), - y_data - .sketch - .r() - .into_subview([0, 0], [subs_sample_dim, y_data.dim]), - ); - - let diag_times = LuTimes { - //TODO: change this to diag_times - extraction: 0_u128, - lu: 0_u128, - }; - - let times = Times::Lu(diag_times); - - ( - Some(Self { - arr: DiagBoxArr::new(rows, &options.extract_db_options, &sub_test, &sub_sketch), - inds: rows.to_vec(), - }), - times, - ) - } - - pub fn cond(&self) -> (CondType, Option>) { - match &self.arr { - DiagBoxType::Reg(reg_dbox) => ((condition_number(®_dbox.arr), None), None), - DiagBoxType::Lu(lu_dbox) => ( - ( - num::Zero::zero(), - Some(( - condition_number(&lu_dbox.l_arr.tri), - condition_number(&lu_dbox.u_arr.tri), - )), - ), - None, - ), - } - } -} - -impl FactorOperations - for DiagBoxFactor -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - type Item = Item; - - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - ) { - let target_block = self.mul_data(target_arr, factor_options); - let t_arr_mutex = std::sync::Mutex::new(target_arr); - self.ins_data(&target_block, *t_arr_mutex.lock().unwrap(), factor_options); - } - - fn mul_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - target_arr: &Array, - options: &MulOptions, - ) -> DynamicArray { - let trans = options.trans; - - match options.side { - Side::Left => { - let mut target_rows = ext_rows( - self.inds.clone(), - self.inds.clone(), - target_arr, - trans, - options.t_trans, - ); - self.arr.mul(&mut target_rows, Side::Left, options); - target_rows - } - Side::Right => { - let mut target_cols = ext_cols( - self.inds.clone(), - self.inds.clone(), - target_arr, - trans, - options.t_trans, - ); - self.arr.mul(&mut target_cols, Side::Right, options); - target_cols - } - } - } - - fn ins_data< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item>, - >( - &self, - source_arr: &DynamicArray, - target_arr: &mut Array, - options: &MulOptions, - ) { - let trans = options.trans; - - match options.side { - Side::Left => { - row_subs( - self.inds.clone(), - self.inds.clone(), - source_arr, - target_arr, - trans, - options.t_trans, - ); - } - Side::Right => { - col_subs( - self.inds.clone(), - self.inds.clone(), - source_arr, - target_arr, - trans, - options.t_trans, - ); - } - } - } -} - -pub trait CommutativeFactorsOperations: Sized { - type Item: RlstScalar; - fn new() -> Self; - fn add_factor(&mut self, factor: Factor); - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - ); - #[allow(clippy::type_complexity)] - fn get_condition_numbers(&self) -> Vec<(CondType, Option>)>; -} - -impl< - Item: RlstScalar - + MatrixId - + MatrixInverse - + MatrixPseudoInverse - + RandScalar - + MatrixLu - + MatrixQr, - > CommutativeFactorsOperations for CommutativeFactors -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - type Item = Item; - - fn new() -> Self { - Vec::new() - } - fn add_factor(&mut self, factor: Factor) { - self.push(factor); - } - - fn mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Self::Item> - + UnsafeRandomAccessByRef<2, Item = Self::Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - ) where - Self: Sized, - { - let updated_t_arr_blocks: Vec<_> = self - .par_iter() - .enumerate() - .map(|(factor_ind, factor)| { - let target_block = match factor { - Factor::Lu(lu_factor) => lu_factor.mul_data(target_arr, factor_options), - Factor::Id(id_factor) => id_factor.mul_data(target_arr, factor_options), - Factor::Diag(diag_factor) => diag_factor.mul_data(target_arr, factor_options), - }; - (factor_ind, target_block) - }) - .collect(); - - let t_arr_mutex = std::sync::Mutex::new(target_arr); - updated_t_arr_blocks - .par_iter() - .for_each(|(factor_ind, target_block)| { - let factor = &self[*factor_ind]; - match factor { - Factor::Lu(lu_factor) => lu_factor.ins_data( - target_block, - *t_arr_mutex.lock().unwrap(), - factor_options, - ), - Factor::Id(id_factor) => id_factor.ins_data( - target_block, - *t_arr_mutex.lock().unwrap(), - factor_options, - ), - Factor::Diag(diag_factor) => diag_factor.ins_data( - target_block, - *t_arr_mutex.lock().unwrap(), - factor_options, - ), - }; - }); - } - - fn get_condition_numbers(&self) -> Vec<(CondType, Option>)> { - let condition_numbers: Vec<_> = self - .par_iter() - .enumerate() - .map(|(_factor_ind, factor)| match factor { - Factor::Lu(lu_factor) => lu_factor.cond(), - Factor::Id(id_factor) => id_factor.cond(), - Factor::Diag(diag_factor) => diag_factor.cond(), - }) - .collect(); - - condition_numbers - } -} -pub trait RsrsFactorsImpl: Sized { - fn new(num_levels: usize, dim: usize) -> Self; - - fn apply_id_level< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - level_it: usize, - ); - - fn apply_lu_level< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - dec: bool, - level_it: usize, - ); - - fn el_factors_mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - mul_type: RsrsMulType, - factor_options: &MulOptions, - level: bool, - ); - - fn matmul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &mut self, - target_arr: &mut Array, - mul_type: RsrsSide, - factor_options: &MulOptions, - ); - - fn matvec( - &self, - x: &[Item], - y: &mut [Item], - mul_type: RsrsSide, - factor_options: &mut MulOptions, - ); - - fn perm_target_array< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - target_arr: &mut Array, - ); - - fn dim(&self) -> usize; - - #[allow(clippy::type_complexity)] - fn get_condition_numbers( - &self, - ) -> ( - Vec, Option>)>>, - Vec, Option>)>>, - Vec<(CondType, Option>)>, - ); - - fn get_factors(&self) -> &RsrsFactors; -} - -impl< - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - > RsrsFactorsImpl for RsrsFactors -where - LuDecomposition, 2>>: - MatrixLuDecomposition, - TriangularMatrix: TriangularOperations, -{ - fn new(num_levels: usize, dim: usize) -> Self { - let mut id_factors = Vec::new(); - id_factors.resize_with(num_levels, Vec::new); - let mut lu_factors = Vec::new(); - lu_factors.resize_with(num_levels, Vec::new); - let mut near_field_inds = Vec::new(); - near_field_inds.resize_with(num_levels, Vec::new); - let orig_indices = Vec::new(); - let perm_indices = Vec::new(); - let perm_factor = PermFactor::new(orig_indices, perm_indices).unwrap(); - let diag_box_factors = DiagBoxFactors::new(); - Self { - num_levels, - near_field_inds, - id_factors, - lu_factors, - perm_factor, - diag_box_factors, - dim, - } - } - - fn dim(&self) -> usize { - self.dim - } - - fn apply_id_level< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - factor_options: &MulOptions, - level_it: usize, - ) { - let id_batch = &self.id_factors[level_it]; - id_batch.mul(target_arr, factor_options); - } - - fn apply_lu_level< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - - factor_options: &MulOptions, - dec: bool, - level_it: usize, - ) { - let num_lu_batches = self.lu_factors[level_it].len(); - - if dec { - (0..num_lu_batches).rev().for_each(|batch_ind| { - let lu_batch = &self.lu_factors[level_it][batch_ind]; - lu_batch.mul(target_arr, factor_options); - }); - } else { - (0..num_lu_batches).for_each(|batch_ind| { - let lu_batch = &self.lu_factors[level_it][batch_ind]; - lu_batch.mul(target_arr, factor_options); - }); - } - } - - fn el_factors_mul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &self, - target_arr: &mut Array, - mul_type: RsrsMulType, - factor_options: &MulOptions, - dec: bool, - ) { - let levels = (0..self.num_levels).collect::>(); - if matches!(mul_type.side, RsrsSide::Squeeze) { - let mut left_options = factor_options.clone(); - left_options.side = Side::Left; - left_options.factor_type = FactorType::F; - left_options.t_trans = mul_type.t_trans; - - let mut right_options = factor_options.clone(); - right_options.side = Side::Right; - right_options.factor_type = FactorType::S; - right_options.t_trans = mul_type.t_trans; - - levels.iter().for_each(|&level_it| { - self.apply_id_level(target_arr, &left_options, level_it); - self.apply_id_level(target_arr, &right_options, level_it); - self.apply_lu_level(target_arr, &left_options, dec, level_it); - self.apply_lu_level(target_arr, &right_options, dec, level_it); - }); - } else { - let mut factor_options_aux = factor_options.clone(); - factor_options_aux.factor_type = mul_type.factor_type.clone(); - factor_options_aux.t_trans = mul_type.t_trans; - - if matches!(mul_type.side, RsrsSide::Left) { - factor_options_aux.side = Side::Left; - } else { - factor_options_aux.side = Side::Right; - } - - if dec { - levels.iter().rev().for_each(|&level_it| { - self.apply_lu_level(target_arr, &factor_options_aux, dec, level_it); - self.apply_id_level(target_arr, &factor_options_aux, level_it); - }); - } else { - levels.iter().for_each(|&level_it| { - self.apply_id_level(target_arr, &factor_options_aux, level_it); - self.apply_lu_level(target_arr, &factor_options_aux, dec, level_it); - }); - } - } - } - - fn matmul< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Stride<2> - + RawAccessMut - + Shape<2> - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item> - + std::marker::Send - + std::marker::Sync, - >( - &mut self, - target_arr: &mut Array, - side: RsrsSide, - factor_options: &MulOptions, - ) { - match side { - RsrsSide::Squeeze => {} - RsrsSide::Left => { - let mul_type_1; - let mul_type_2; - if !factor_options.inv { - mul_type_1 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::S, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::F, - t_trans: false, - }; - } else { - mul_type_1 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::F, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::S, - t_trans: false, - }; - } - - let mut factor_options_aux = factor_options.clone(); - factor_options_aux.side = Side::Left; - factor_options_aux.factor_type = FactorType::F; - factor_options_aux.t_trans = false; //TODO: CHECK IF CORRECT - - self.el_factors_mul(target_arr, mul_type_1, factor_options, false); - self.diag_box_factors.mul(target_arr, &factor_options_aux); - self.el_factors_mul(target_arr, mul_type_2, factor_options, true); - } - RsrsSide::Right => { - let mul_type_1; - let mul_type_2; - if !factor_options.inv { - mul_type_1 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::F, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::S, - t_trans: false, - }; - } else { - mul_type_1 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::S, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::F, - t_trans: false, - }; - } - - let mut factor_options_aux = factor_options.clone(); - factor_options_aux.side = Side::Right; - factor_options_aux.factor_type = FactorType::F; - factor_options_aux.t_trans = false; //TODO: CHECK IF CORRECT - - self.el_factors_mul(target_arr, mul_type_1, factor_options, false); - self.diag_box_factors.mul(target_arr, &factor_options_aux); - self.el_factors_mul(target_arr, mul_type_2, factor_options, true); - } - } - } - - fn matvec(&self, x: &[Item], y: &mut [Item], side: RsrsSide, factor_options: &mut MulOptions) { - let target_arr = match side { - RsrsSide::Squeeze => empty_array(), - RsrsSide::Left => { - let mut target_arr = rlst_dynamic_array2!(Item, [x.len(), 1]); - for (i, val) in x.iter().enumerate() { - target_arr.r_mut()[[i, 0]] = *val; - } - let mul_type_1; - let mul_type_2; - if !factor_options.inv { - mul_type_1 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::S, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::F, - t_trans: false, - }; - } else { - mul_type_1 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::F, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Left, - factor_type: FactorType::S, - t_trans: false, - }; - } - - let mut factor_options_aux = factor_options.clone(); - factor_options_aux.side = Side::Left; - factor_options_aux.factor_type = FactorType::F; - factor_options_aux.t_trans = false; //TODO: CHECK IF CORRECT - - self.el_factors_mul(&mut target_arr, mul_type_1, factor_options, false); - self.diag_box_factors - .mul(&mut target_arr, &factor_options_aux); - self.el_factors_mul(&mut target_arr, mul_type_2, factor_options, true); - target_arr - } - RsrsSide::Right => { - let mut target_arr = rlst_dynamic_array2!(Item, [1, x.len()]); - - for (i, val) in x.iter().enumerate() { - target_arr.r_mut()[[0, i]] = *val; - } - let mul_type_1; - let mul_type_2; - if !factor_options.inv { - mul_type_1 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::F, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::S, - t_trans: false, - }; - } else { - mul_type_1 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::S, - t_trans: false, - }; - mul_type_2 = RsrsMulType { - side: RsrsSide::Right, - factor_type: FactorType::F, - t_trans: false, - }; - } - - let mut factor_options_aux = factor_options.clone(); - factor_options_aux.side = Side::Right; - factor_options_aux.factor_type = FactorType::F; - factor_options_aux.t_trans = false; //TODO: CHECK IF CORRECT - - self.el_factors_mul(&mut target_arr, mul_type_1, factor_options, false); - self.diag_box_factors - .mul(&mut target_arr, &factor_options_aux); - self.el_factors_mul(&mut target_arr, mul_type_2, factor_options, true); - target_arr - } - }; - - for (i, val) in target_arr.r().iter().enumerate() { - y[i] = val; - } - } - - fn perm_target_array< - ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> - + Shape<2> - + RawAccessMut - + UnsafeRandomAccessMut<2, Item = Item> - + UnsafeRandomAccessByRef<2, Item = Item>, - >( - &self, - target_arr: &mut Array, - ) { - self.perm_factor.left_mul( - target_arr, - &MulOptions { - inv: false, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, - }, - ); - self.perm_factor.right_mul( - target_arr, - &MulOptions { - inv: false, - trans: true, - side: Side::Right, - factor_type: FactorType::F, - t_trans: false, - }, - ); - } - - fn get_condition_numbers( - &self, - ) -> ( - Vec, Option>)>>, - Vec, Option>)>>, - Vec<(CondType, Option>)>, - ) { - let mut id_condition_numbers = Vec::new(); - let mut lu_condition_numbers = Vec::new(); - for id_batch in self.id_factors.iter() { - id_condition_numbers.push(id_batch.get_condition_numbers()); - } - for lu_level_batches in self.lu_factors.iter() { - let mut lu_level_condition_numbers = Vec::new(); - for lu_batch in lu_level_batches.iter() { - lu_level_condition_numbers.extend_from_slice(&lu_batch.get_condition_numbers()); - } - lu_condition_numbers.push(lu_level_condition_numbers); - } - let diag_condition_numbers = self.diag_box_factors.get_condition_numbers(); - - ( - id_condition_numbers, - lu_condition_numbers, - diag_condition_numbers, - ) - } - - fn get_factors(&self) -> &Self { - self - } -} - -impl Shape<2> for RsrsFactors { - fn shape(&self) -> [usize; 2] { - [self.dim, self.dim] - } -} - -pub struct RsrsOperator< - 'a, - Item: RlstScalar + MatrixInverse + MatrixId + MatrixPseudoInverse + MatrixLu + RandScalar + MatrixQr, - Space: SamplingSpace, - Op: RsrsFactorsImpl + Shape<2>, -> { - pub op: &'a Op, - domain: Rc, - range: Rc, - inv: bool, -} - -impl< - 'a, - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - Space: SamplingSpace + LinearSpace, - Op: RsrsFactorsImpl + Shape<2>, - > RsrsOperator<'a, Item, Space, Op> -{ - pub fn get_factors(&self) -> &RsrsFactors { - self.op.get_factors() - } - - #[allow(clippy::type_complexity)] - pub fn get_condition_numbers( - &self, - ) -> ( - Vec, Option>)>>, - Vec, Option>)>>, - Vec<(CondType, Option>)>, - ) { - self.op.get_condition_numbers() - } -} - -// Implement OperatorBase for RsrsOperator so it can be used with rlst::Operator -impl< - 'a, - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - Space: SamplingSpace + LinearSpace, - Op: RsrsFactorsImpl + Shape<2>, - > OperatorBase for RsrsOperator<'a, Item, Space, Op> -{ - type Domain = Space; - type Range = Space; - - fn domain(&self) -> Rc { - self.domain.clone() - } - - fn range(&self) -> Rc { - self.range.clone() - } -} - -impl< - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - Space: SamplingSpace, - Op: RsrsFactorsImpl + Shape<2>, - > std::fmt::Debug for RsrsOperator<'_, Item, Space, Op> -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let shape = self.op.shape(); - write!(f, "RsrsOperator: [{}x{}]", shape[0], shape[1]).unwrap(); - Ok(()) - } -} - -pub trait LocalFromSpaces< - 'a, - Item: RlstScalar + MatrixInverse + MatrixId + MatrixPseudoInverse + MatrixLu + RandScalar + MatrixQr, - Space, - Op, ->: Sized -{ - fn from_local_spaces(op: &'a Op, domain: Rc, range: Rc) -> Self; -} - -pub trait Inv { - fn inv(&mut self, inv: bool); -} - -impl< - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - Space: SamplingSpace, - Op: RsrsFactorsImpl + Shape<2>, - > Inv for RsrsOperator<'_, Item, Space, Op> -where - StandardNormal: Distribution<::Real>, - Standard: Distribution<::Real>, - ::Real: RandScalar, -{ - fn inv(&mut self, inv: bool) { - self.inv = inv; - } -} - -impl< - 'a, - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - Op: RsrsFactorsImpl + Shape<2>, - > LocalFromSpaces<'a, Item, ArrayVectorSpace, Op> - for RsrsOperator<'a, Item, ArrayVectorSpace, Op> -where - StandardNormal: Distribution<::Real>, - Standard: Distribution<::Real>, - ::Real: RandScalar, -{ - fn from_local_spaces( - op: &'a Op, - domain: Rc>, - range: Rc>, - ) -> Self { - RsrsOperator { - op, - domain: domain.clone(), - range: range.clone(), - inv: false, - } - } -} - -impl< - 'a, - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr - + Equivalence, - Op: RsrsFactorsImpl + Shape<2>, - > LocalFromSpaces<'a, Item, DistributedArrayVectorSpace<'a, SimpleCommunicator, Item>, Op> - for RsrsOperator<'a, Item, DistributedArrayVectorSpace<'a, SimpleCommunicator, Item>, Op> -where - StandardNormal: Distribution<::Real>, - Standard: Distribution<::Real>, - ::Real: RandScalar, -{ - fn from_local_spaces( - op: &'a Op, - domain: Rc>, - range: Rc>, - ) -> Self { - RsrsOperator { - op, - domain: domain.clone(), - range: range.clone(), - inv: false, - } - } -} - -impl< - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr, - Op: RsrsFactorsImpl + Shape<2>, - > AsApply for RsrsOperator<'_, Item, ArrayVectorSpace, Op> -where - ::Real: RandScalar, - StandardNormal: Distribution<::Real>, - Standard: Distribution<::Real>, -{ - fn apply_extended< - ContainerIn: ElementContainer::E>, - ContainerOut: ElementContainerMut::E>, - >( - &self, - _alpha: ::F, - x: Element, - _beta: ::F, - mut y: Element, - trans_mode: TransMode, - ) { - match trans_mode { - TransMode::NoTrans => { - let mut factor_options = MulOptions { - inv: self.inv, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, - }; - - // Reshape y to a 2D array before passing to mul - self.op.matvec( - x.imp().view().data(), - y.imp_mut().view_mut().data_mut(), - RsrsSide::Left, - &mut factor_options, - ); - } - TransMode::ConjNoTrans => { - panic!("TransMode::ConjNoTrans not supported for multiplication.") - } - TransMode::Trans => { - let mut factor_options = MulOptions { - inv: self.inv, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, - }; - - self.op.matvec( - x.imp().view().data(), - y.imp_mut().view_mut().data_mut(), - RsrsSide::Right, - &mut factor_options, - ); - } - TransMode::ConjTrans => { - panic!("TransMode::ConjTrans not supported for multiplication.") - } - } - } - - fn apply::E>>( - &self, - x: Element, - trans_mode: rlst::TransMode, - ) -> rlst::operator::ElementType<::E> { - let mut y = zero_element(self.range()); - self.apply_extended( - <::F as num::One>::one(), - x, - <::F as num::Zero>::zero(), - y.r_mut(), - trans_mode, - ); - y - } -} - -impl< - C: Communicator, - Item: RlstScalar - + MatrixInverse - + MatrixId - + MatrixPseudoInverse - + MatrixLu - + RandScalar - + MatrixQr - + Equivalence, - Op: RsrsFactorsImpl + Shape<2>, - > AsApply for RsrsOperator<'_, Item, DistributedArrayVectorSpace<'_, C, Item>, Op> -where - ::Real: RandScalar, - StandardNormal: Distribution<::Real>, - Standard: Distribution<::Real>, -{ - fn apply_extended< - ContainerIn: ElementContainer::E>, - ContainerOut: ElementContainerMut::E>, - >( - &self, - _alpha: ::F, - x: Element, - _beta: ::F, - mut y: Element, - trans_mode: TransMode, - ) { - match trans_mode { - TransMode::NoTrans => { - let mut factor_options = MulOptions { - inv: self.inv, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, - }; - - // Reshape y to a 2D array before passing to mul - self.op.matvec( - x.imp().view().local().data(), - y.imp_mut().view_mut().local_mut().data_mut(), - RsrsSide::Left, - &mut factor_options, - ); - } - TransMode::ConjNoTrans => { - panic!("TransMode::ConjNoTrans not supported for multiplication.") - } - TransMode::Trans => { - let mut factor_options = MulOptions { - inv: self.inv, - trans: false, - side: Side::Left, - factor_type: FactorType::F, - t_trans: false, - }; - - self.op.matvec( - x.imp().view().local().data(), - y.imp_mut().view_mut().local_mut().data_mut(), - RsrsSide::Right, - &mut factor_options, - ); - } - TransMode::ConjTrans => { - panic!("TransMode::ConjTrans not supported for multiplication.") - } - } - } - - fn apply::E>>( - &self, - x: Element, - trans_mode: rlst::TransMode, - ) -> rlst::operator::ElementType<::E> { - let mut y = zero_element(self.range()); - self.apply_extended( - <::F as num::One>::one(), - x, - <::F as num::Zero>::zero(), - y.r_mut(), - trans_mode, - ); - y - } -} diff --git a/src/rsrs/rsrs_factors/base_factors.rs b/src/rsrs/rsrs_factors/base_factors.rs new file mode 100644 index 0000000..61be0bb --- /dev/null +++ b/src/rsrs/rsrs_factors/base_factors.rs @@ -0,0 +1,616 @@ +use crate::utils::data_ins_ext::extract_axis_into; +use itertools::min; +use rlst::{ + dense::linalg::lu::{MatrixLu, SquareLuFactors}, + prelude::*, +}; + +type Real = ::Real; +type CNTuple = (Real, Real); //TODO: Remove before releasing +pub type CondType = (CNTuple, Option<(CNTuple, CNTuple)>); //TODO: Remove before releasing + +pub fn condition_number( + //TODO: Remove before releasing + mat: &DynamicArray, +) -> CNTuple { + let shape = mat.shape(); + let dim: usize = min(shape).unwrap(); + let mut singular_values: DynamicArray, 1> = rlst_dynamic_array1!(Real, [dim]); + let mode: SvdMode = SvdMode::Reduced; + let mut u: DynamicArray = rlst_dynamic_array2!(Item, [shape[0], dim]); + let mut vt: DynamicArray = rlst_dynamic_array2!(Item, [dim, shape[1]]); + + let mut aux_data = empty_array(); + aux_data.fill_from_resize(mat.r()); + + aux_data + .r_mut() + .into_svd_alloc(u.r_mut(), vt.r_mut(), singular_values.data_mut(), mode) + .unwrap(); + + let sigma_max = singular_values[[0]]; + let sigma_min = singular_values[[dim - 1]]; + + (sigma_max / sigma_min, sigma_max) +} + +/// Basic flags used when applying an elementary RSRS factor. +/// +/// `trans` describes the orientation of the factor itself, while `trans_target` +/// keeps track of whether the current input is viewed through a transposed +/// row/column layout. +#[derive(Clone)] +pub struct BaseFactorOptions { + /// Inverse operation + pub inv: bool, + /// Transpose operation + pub trans: TransMode, + /// Transpose vector or matrix when applying factor + pub trans_target: bool, +} + +pub(crate) struct FactorApplyScratch { + pub source: DynamicArray, + pub result: DynamicArray, +} + +pub(crate) struct FactorApplyLayout<'a> { + pub source_indices: &'a [usize], + pub target_indices: &'a [usize], + pub axis: usize, + pub transposed: bool, +} + +impl FactorApplyScratch { + pub(crate) fn new() -> Self { + Self { + source: empty_array(), + result: empty_array(), + } + } +} + +pub(crate) fn factor_apply_layout<'a>( + side: &Side, + factor_options: &BaseFactorOptions, + c_indices: &'a [usize], + r_indices: &'a [usize], +) -> FactorApplyLayout<'a> { + // The low-level delta kernels operate on "source" entries that are read + // and "target" entries that are updated. Left/right application, factor + // transpose, and transposed vector views all permute those roles, so we + // normalize the layout once here. + match side { + Side::Left => { + let (source_indices, target_indices) = if !factor_options.trans_val() { + (c_indices, r_indices) + } else { + (r_indices, c_indices) + }; + + if factor_options.trans_target { + FactorApplyLayout { + source_indices, + target_indices, + axis: 1, + transposed: true, + } + } else { + FactorApplyLayout { + source_indices, + target_indices, + axis: 0, + transposed: false, + } + } + } + Side::Right => { + let (source_indices, target_indices) = if !factor_options.trans_val() { + (r_indices, c_indices) + } else { + (c_indices, r_indices) + }; + + if factor_options.trans_target { + FactorApplyLayout { + source_indices, + target_indices, + axis: 0, + transposed: true, + } + } else { + FactorApplyLayout { + source_indices, + target_indices, + axis: 1, + transposed: false, + } + } + } + } +} + +pub(crate) fn conjugate_array_in_place(arr: &mut DynamicArray) { + arr.data_mut() + .iter_mut() + .for_each(|elem| *elem = elem.conj()); +} + +pub(crate) fn assert_transpose_only_mode(trans: TransMode, context: &str) { + match trans { + TransMode::NoTrans | TransMode::Trans => {} + TransMode::ConjNoTrans | TransMode::ConjTrans => { + panic!( + "{context} received a conjugating transposition mode. Conjugation must be handled above the low-level factor kernels." + ); + } + } +} + +/// Handler to manage the factor basic options +impl BaseFactorOptions { + /// Returns whether the factor is applied in transposed orientation. + /// + /// This only answers the layout question used to swap source and target + /// indices. Any conjugation required for complex or Hermitian factors is + /// handled separately in the concrete factor implementations. + pub fn trans_val(&self) -> bool { + match self.trans { + TransMode::NoTrans => false, + TransMode::ConjNoTrans => false, + TransMode::Trans => true, + TransMode::ConjTrans => true, + } + } + + /// Flip the factor orientation between `NoTrans` and `Trans`. + /// + /// Conjugating transpose modes are normalized before this helper is used, + /// so they are intentionally left unsupported here. + pub fn transpose(&self) -> Self { + let mut new_options = self.clone(); + match self.trans { + TransMode::NoTrans => new_options.trans = TransMode::Trans, + TransMode::ConjNoTrans => panic!( + "BaseFactorOptions::transpose cannot be used with conjugating modes. Conjugation must be normalized before reaching low-level factors." + ), + TransMode::Trans => new_options.trans = TransMode::NoTrans, + TransMode::ConjTrans => panic!( + "BaseFactorOptions::transpose cannot be used with conjugating modes. Conjugation must be normalized before reaching low-level factors." + ), + }; + + new_options + } + + /// Toggle between applying the factor and its inverse. + pub fn invert(&self) -> Self { + let mut new_options = self.clone(); + new_options.inv = !self.inv; + new_options + } +} + +/// Lu Decomposition of a square box. +pub struct LuSMat { + /// Reusable LU factors for trusted multiplications and solves. + pub square_factors: SquareLuFactors, +} + +/// Stores a square matrix in a regular array. +pub struct RegSMat { + /// Storage for A + pub arr: DynamicArray, + /// Storage for A^-1 + pub inv_arr: DynamicArray, +} +// A square matrix can either be stored in a LU factor or as a dense matrix with its inverse. +pub enum SquareArr { + Reg(RegSMat), + Lu(LuSMat), +} + +/// Likewise, diagonal blocks can either be stored in a LU factor or as a dense matrix with its inverse. +pub enum DiagBoxArr { + Reg(RegSMat), + Lu(LuSMat), +} + +/// Implementation of the multiplication operations of SquareArr. +impl SquareArr +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + /// Implementation of SquareArr * x (multiplication by the left). + fn left_mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + Stride<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &mut Array, + factor_options: &BaseFactorOptions, + ) { + assert_transpose_only_mode(factor_options.trans, "SquareArr::left_mul"); + match self { + SquareArr::Reg(ref reg) => { + let mut new_target_arr = empty_array(); + if factor_options.inv { + // Transposition is built in. + // A / x when A is stored as a dense matrix. + new_target_arr.r_mut().mult_into_resize( + factor_options.trans, + TransMode::NoTrans, + num::One::one(), + reg.inv_arr.r(), + target_arr.r(), + num::Zero::zero(), + ); + } else { + // A * x when A is stored as a dense matrix. + new_target_arr.r_mut().mult_into_resize( + factor_options.trans, + TransMode::NoTrans, + num::One::one(), + reg.arr.r(), + target_arr.r(), + num::Zero::zero(), + ); + } + target_arr.r_mut().fill_from(new_target_arr.r()); + } + SquareArr::Lu(ref lu) => { + if factor_options.inv { + lu.square_factors + .solve_mat(factor_options.trans, target_arr.r_mut()) + .unwrap(); + } else { + lu.square_factors + .mul_mat(factor_options.trans, target_arr.r_mut()) + .unwrap(); + } + } + } + } + + /// Multiplication by a square matrix. Right multiplication and inversion are defined as (A'*x')' and (A'\x')' + pub fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + Stride<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + arr: &mut Array, + side: Side, + factor_options: &BaseFactorOptions, + ) { + assert_transpose_only_mode(factor_options.trans, "SquareArr::mul"); + match side { + Side::Left => self.left_mul(arr, factor_options), + Side::Right => { + let mut aux_arr = empty_array(); + aux_arr.r_mut().fill_from_resize(arr.r().transpose()); + + let trans_factor_options = factor_options.transpose(); + self.left_mul(&mut aux_arr, &trans_factor_options); + + arr.fill_from(aux_arr.r().transpose()); + } + } + } + + /// Compute condition number of square matrix. + pub fn cond(&self) -> (CNTuple, CNTuple) { + //TODO: This function should not be included when releasing. + match self { + SquareArr::Reg(reg_dbox) => ( + condition_number(®_dbox.arr), + (num::Zero::zero(), num::Zero::zero()), + ), + SquareArr::Lu(_) => ( + (num::Zero::zero(), num::Zero::zero()), + (num::Zero::zero(), num::Zero::zero()), + ), + } + } +} + +/// When computing LU factorisation in RSRS, +/// we can either store A = P / R (where P +/// is the pivot extracted from the LU factorisation)or store +/// P and R independently, to then operate +/// them on the go. The second option should be better +/// if the pivot P is ill-conditioned. +/// +/// Composed elementary matrix stores `X_rr` and `X_rn` independently. +pub struct ComposedFactorData { + /// Pivot + pub sq: SquareArr, + /// Rectangular part of the factor + pub rectg: RectArr, +} + +/// RectArray store A = P / R +pub struct RectArr { + // TODO: Maybe change to just an array. + /// Rectangular array. + pub arr: Box>, +} + +/// Multiplication by a composed factor. +impl ComposedFactorData +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + fn left_mul_into( + &self, + target_arr: &mut DynamicArray, + factor_options: &BaseFactorOptions, + res_mul: &mut DynamicArray, + ) { + let mut sq_factor_options = factor_options.clone(); + sq_factor_options.inv = true; + + if !factor_options.trans_val() { + res_mul.r_mut().mult_into_resize( + TransMode::NoTrans, + TransMode::NoTrans, + num::One::one(), + self.rectg.arr.r(), + target_arr.r(), + num::Zero::zero(), + ); + self.sq.mul(res_mul, Side::Left, &sq_factor_options); + } else { + self.sq.mul(target_arr, Side::Left, &sq_factor_options); + res_mul.r_mut().mult_into_resize( + TransMode::Trans, + TransMode::NoTrans, + num::One::one(), + self.rectg.arr.r(), + target_arr.r(), + num::Zero::zero(), + ); + } + } + + /// Multiplication by Composed factor. Right multiplication is defined as (A'*x')' + fn mul_into( + &self, + target_arr: &mut DynamicArray, + side: &Side, + factor_options: &BaseFactorOptions, + res_mul: &mut DynamicArray, + ) { + match side { + Side::Left => self.left_mul_into(target_arr, factor_options, res_mul), + Side::Right => { + let mut aux_arr = empty_array(); + aux_arr.r_mut().fill_from_resize(target_arr.r().transpose()); + + let trans_factor_options = factor_options.transpose(); + let mut aux_res = empty_array(); + self.left_mul_into(&mut aux_arr, &trans_factor_options, &mut aux_res); + + res_mul.r_mut().fill_from_resize(aux_res.r().transpose()); + } + } + } + + /// Compute condition number of R and P + #[allow(clippy::type_complexity)] + pub fn cond(&self) -> CondType { + //TODO: Remove before releasing + (condition_number(&self.rectg.arr), Some(self.sq.cond())) + } +} + +/// Multiplication by a rectangular factor. +impl RectArr { + /// Implementation of b=A*x. Returns the result in a new array. + fn left_mul_into( + &self, + target_arr: &DynamicArray, + factor_options: &BaseFactorOptions, + res_mul: &mut DynamicArray, + ) { + assert_transpose_only_mode(factor_options.trans, "RectArr::left_mul_into"); + res_mul.r_mut().mult_into_resize( + factor_options.trans, + TransMode::NoTrans, + num::One::one(), + self.arr.r(), + target_arr.r(), + num::Zero::zero(), + ); + } + /// Multiplication by a Rectangular factor. Right multiplication is defined as (A'*x')' + fn mul_into( + &self, + target_arr: &DynamicArray, + side: &Side, + factor_options: &BaseFactorOptions, + res_mul: &mut DynamicArray, + ) { + assert_transpose_only_mode(factor_options.trans, "RectArr::mul_into"); + match side { + Side::Left => self.left_mul_into(target_arr, factor_options, res_mul), + Side::Right => { + let mut aux_arr = empty_array(); + aux_arr.r_mut().fill_from_resize(target_arr.r().transpose()); + + let trans_factor_options = factor_options.transpose(); + let mut aux_res = empty_array(); + self.left_mul_into(&aux_arr, &trans_factor_options, &mut aux_res); + + res_mul.r_mut().fill_from_resize(aux_res.r().transpose()); + } + } + } + + /// Compute condition number of the rectangular array. + #[allow(clippy::type_complexity)] + pub fn cond(&self) -> CondType { + //TODO: Remove before releasing. + (condition_number(&self.arr), None) + } +} + +/// An elementary factor can either be defined by a composed or a rectangular +/// factor depending on the user settings. Factor data stores F in the application I+/-F. +pub enum FactorData { + Comp(ComposedFactorData), + Reg(RectArr), +} + +/// Implementation of the Elementary Matrix multiplication and inversion. +impl FactorData +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + /// Multiplication (b=(I+/-F)*x). Instead of using full I and full F, + /// we only use and operate on the bits of data relevant for this operation. + /// + /// Arguments: + /// target_arr: x + /// side: factor can either multiply by the left or the right. + /// factor_options: A can be inverted or transposed and x can also be transposed. + /// c_indices: domain indices (rows to extract in x). + /// r_indices: range indices (rows to modify in x). + pub fn mul< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &Array, + side: &Side, + factor_options: &BaseFactorOptions, + c_indices: &[usize], + r_indices: &[usize], + ) -> DynamicArray { + let mut scratch = FactorApplyScratch::new(); + self.mul_with_scratch( + target_arr, + side, + factor_options, + c_indices, + r_indices, + &mut scratch, + ) + } + + pub(crate) fn mul_with_scratch< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &Array, + side: &Side, + factor_options: &BaseFactorOptions, + c_indices: &[usize], + r_indices: &[usize], + scratch: &mut FactorApplyScratch, + ) -> DynamicArray { + let layout = self.delta_with_scratch( + target_arr, + side, + factor_options, + c_indices, + r_indices, + false, + scratch, + ); + let mut subarr_target = empty_array(); + extract_axis_into( + &mut subarr_target, + target_arr, + layout.target_indices, + layout.axis, + layout.transposed, + ); + + if factor_options.inv { + // res = (I-F)*x + subarr_target.sub_into(scratch.result.r()); + } else { + // res = (I+F)*x + subarr_target.sum_into(scratch.result.r()); + } + + subarr_target + } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn delta_with_scratch< + 'a, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &Array, + side: &Side, + factor_options: &BaseFactorOptions, + c_indices: &'a [usize], + r_indices: &'a [usize], + conjugate_source: bool, + scratch: &mut FactorApplyScratch, + ) -> FactorApplyLayout<'a> { + let layout = factor_apply_layout(side, factor_options, c_indices, r_indices); + + extract_axis_into( + &mut scratch.source, + target_arr, + layout.source_indices, + layout.axis, + layout.transposed, + ); + + if conjugate_source { + conjugate_array_in_place(&mut scratch.source); + } + + match self { + FactorData::Comp(composed_factor_data) => composed_factor_data.mul_into( + &mut scratch.source, + side, + factor_options, + &mut scratch.result, + ), + FactorData::Reg(rectg) => { + rectg.mul_into(&scratch.source, side, factor_options, &mut scratch.result) + } + } + + layout + } + + /// Compute condition number of the application (I+/-F) + pub fn cond(&self) -> CondType { + //TODO: Remove before releasing + match self { + FactorData::Comp(composed_factor_data) => composed_factor_data.cond(), + FactorData::Reg(rectg) => rectg.cond(), + } + } +} diff --git a/src/rsrs/rsrs_factors/commutative_factors.rs b/src/rsrs/rsrs_factors/commutative_factors.rs new file mode 100644 index 0000000..7ff8e1a --- /dev/null +++ b/src/rsrs/rsrs_factors/commutative_factors.rs @@ -0,0 +1,2456 @@ +use crate::rsrs::args::Symmetry; +use crate::rsrs::rsrs_factors::base_factors::{ + assert_transpose_only_mode, condition_number, conjugate_array_in_place, factor_apply_layout, + BaseFactorOptions, CondType, DiagBoxArr, FactorApplyScratch, FactorData, LuSMat, RectArr, + RegSMat, SquareArr, +}; +use crate::rsrs::rsrs_factors::null_and_extract::{ + extract_lu_factor_from_blocks, near_box_extraction, null_near_field_into, ExtractOptions, + ExtractionScratch, IdOptions, PivotMethod, +}; +use crate::rsrs::rsrs_factors::rsrs_operator::FactType; +use crate::rsrs::sketch::SketchData; +use crate::rsrs::statistics::{IdTimes, LuTimes, Times}; +use crate::utils::linear_algebra::add_diagonal; +use crate::utils::{ + data_ins_ext::{extract_axis_into, raw_matrix_mut, RawMatrixMut}, + elementary_matrix::{ + col_delta, col_delta_raw, col_perm, col_subs, ext_cols, ext_rows, row_delta, row_delta_raw, + row_perm, row_subs, + }, + linear_algebra::{ + block_extraction_into, streaming_chunk_rows, BlockExtractionMethod, + NormalEquationAccumulator, + }, + memory::{matrix_bytes, trace_memory_event, trace_memory_growth}, +}; +use rand_distr::{Distribution, Standard, StandardNormal}; +use rayon::{ + iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, + ThreadPool, +}; +use rlst::{ + dense::{ + linalg::{ + interpolative_decomposition::{Accuracy, MatrixIdNoSkel}, + lu::{MatrixLu, SquareLuFactors}, + triangular_arrays::TriangularOperations, + }, + tools::RandScalar, + }, + prelude::*, +}; +use std::{ + collections::{HashMap, HashSet}, + mem::{size_of, size_of_val}, + time::{Duration, Instant}, +}; +type Real = ::Real; + +/// BoxType marks a box as merged if it is an +/// union of boxes that have been compressed in previous levels. +#[derive(Debug, Clone)] +pub enum BoxType { + Merged(usize), + Full(Real), +} + +/// Selects which half of an elementary RSRS update is being applied. +/// +/// The names stay abstract on purpose: an ID factor and an LU factor map `F` +/// and `S` to different stored arrays, but the higher-level elimination logic +/// can use one common ordering. +#[derive(Clone, PartialEq, Debug)] +pub enum FactorType { + F, + S, +} + +/// Application options shared by a batch of commutative factors. +#[derive(Clone)] +pub struct MulOptions { + /// Basic orientation and inversion flags for the factor application. + pub base_options: BaseFactorOptions, + /// Whether the factor is applied to a column-vector/matrix (`Left`) or to + /// a row-vector/matrix (`Right`). + pub side: Side, + /// Which half of the split elimination is currently being traversed. + pub factor_type: FactorType, +} + +/// IdFactor: Obtained from Interpolative decomposition. +/// It represents an elementary operation: (I+/-F) +pub struct IdFactor { + /// data: stores F + data: FactorData, + /// perm: stores the permutation induced by the ID + pub perm: Vec, + /// ind_r: stores the residual columns or rows in A + pub ind_r: Vec, //row_indices + /// ind_s: stores the skeleton columns or rows in A + pub ind_s: Vec, //col_indices + /// stores the far field indices when necessary. + pub ind_f: Vec, + /// type of symmetry of the factor + pub symmetry: Symmetry, +} + +/// LuFactor: Obtain from Block LU near field compression. +/// It represents an elementary operation: (I+/-F) +/// U and L are defined in the same factor. +pub struct LuFactor { + /// l_arr: L in (I+/-L) + l_arr: FactorData, + /// u_arr: U in (I+/-U) + pub u_arr: FactorData, + /// symmetric: indicates if LU was performed in a symmetric matrix. + /// (for a symmetric matrix we only store U) + symmetry: Symmetry, + /// ind_r: residual indices + pub ind_r: Vec, //cols + /// ind_t: target indices + pub ind_t: Vec, //rows +} + +/// Diagonal factor after ID and LU factorisation have been applied. +pub struct DiagBoxFactor { + /// arr: contains the information of the diagonal block matrix. + pub arr: DiagBoxArr, + /// inds: contains the indices of the rows/columns to apply the block factor. + pub inds: Vec, +} + +pub(crate) struct DiagExtractionScratch { + primary: DynamicArray, + secondary: DynamicArray, + tertiary: DynamicArray, + quaternary: DynamicArray, + quinary: DynamicArray, +} + +impl DiagExtractionScratch { + pub(crate) fn new() -> Self { + Self { + primary: empty_array(), + secondary: empty_array(), + tertiary: empty_array(), + quaternary: empty_array(), + quinary: empty_array(), + } + } +} + +/// Permutation factor: application that permutes selected rows/columns of a given matrix/vector. +pub struct PermFactor { + /// orig_indices: original indices + pub orig_indices: Vec, + /// perm_indices: permuted indices + pub perm_indices: Vec, +} + +/// A factor can either be an ID factor, a LU factor or a diagonal block factor. +pub enum Factor { + Lu(LuFactor), + Id(IdFactor), + Diag(DiagBoxFactor), +} + +/// RsrsFactors: collects the relevant information for the RSRS of a matrix. +pub struct RsrsFactors { + /// num_levels: depth of the tree + pub num_levels: usize, + /// id_factors: collection of ID factors sorted per level + pub id_factors: MultiLevelIdFactors, + /// lu_factors: collection of LU factors sorted per level + pub lu_factors: MultiLevelLuFactors, + /// perm_factor: permutation induced by RSRS + pub perm_factor: PermFactor, + /// diag_box_factors: block diagonal factors obtained from RSRS + pub diag_box_factors: DiagBoxFactors, + /// fact_type: either a split RSRS (first run all ID + /// factors and then LU) or joint (run ID and LU together). + /// split RSRS has better parallel properties, but introduces + /// larger errors. + pub fact_type: FactType, + /// dim: dimension of the compressed operator + pub dim: usize, + /// num_threads: number of threads used in the factorisation + /// and matrix-vector multiplication + pub num_threads: usize, +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct FactorMemoryBreakdown { + pub id_bytes: u64, + pub lu_bytes: u64, + pub diag_bytes: u64, + pub perm_bytes: u64, + pub id_count: usize, + pub lu_count: usize, + pub diag_count: usize, +} + +impl FactorMemoryBreakdown { + pub fn total_bytes(&self) -> u64 { + self.id_bytes + self.lu_bytes + self.diag_bytes + self.perm_bytes + } + + fn add_factor(&mut self, factor: &Factor) { + match factor { + Factor::Id(id_factor) => { + self.id_bytes += id_factor_bytes(id_factor); + self.id_count += 1; + } + Factor::Lu(lu_factor) => { + self.lu_bytes += lu_factor_bytes(lu_factor); + self.lu_count += 1; + } + Factor::Diag(diag_factor) => { + self.diag_bytes += diag_factor_bytes(diag_factor); + self.diag_count += 1; + } + } + } +} + +/// CommmutativeFactors: batcn of factors of any kind that +/// can be applied in parallel +pub type CommutativeFactors = Vec>; + +/// LevelFactors: batches inside of a level +type LevelFactors = Vec>; +/// MultiLevelFactors: collection of the results for all levels +type MultiLevelFactors = Vec>; + +/// MultiLevelIdFactors: storage for ID factors +pub enum MultiLevelIdFactors { + // Single: single batch + Single(LevelFactors), + // Batched: collection of batches associated to a level + Batched(MultiLevelFactors), +} + +/// MultilevelLuFactors: storage for LU factors +type MultiLevelLuFactors = MultiLevelFactors; + +/// DiagBoxFactors: storage for block diagonal factors +pub type DiagBoxFactors = CommutativeFactors; + +fn dynamic_array_bytes(arr: &DynamicArray) -> u64 { + (arr.shape()[0] as u64) * (arr.shape()[1] as u64) * (size_of::() as u64) +} + +fn usize_vec_bytes(values: &[usize]) -> u64 { + (values.len() as u64) * (size_of::() as u64) +} + +fn perm_factor_bytes(perm: &PermFactor) -> u64 { + usize_vec_bytes(&perm.orig_indices) + usize_vec_bytes(&perm.perm_indices) +} + +fn square_arr_bytes(arr: &SquareArr) -> u64 { + match arr { + SquareArr::Reg(reg) => dynamic_array_bytes(®.arr) + dynamic_array_bytes(®.inv_arr), + SquareArr::Lu(lu) => size_of_val(lu) as u64, + } +} + +fn diag_box_arr_bytes(arr: &DiagBoxArr) -> u64 { + match arr { + DiagBoxArr::Reg(reg) => dynamic_array_bytes(®.arr) + dynamic_array_bytes(®.inv_arr), + DiagBoxArr::Lu(lu) => size_of_val(lu) as u64, + } +} + +fn factor_data_bytes(data: &FactorData) -> u64 { + match data { + FactorData::Comp(comp) => square_arr_bytes(&comp.sq) + dynamic_array_bytes(&comp.rectg.arr), + FactorData::Reg(rect) => dynamic_array_bytes(&rect.arr), + } +} + +fn id_factor_bytes(factor: &IdFactor) -> u64 { + factor_data_bytes(&factor.data) + + usize_vec_bytes(&factor.perm) + + usize_vec_bytes(&factor.ind_r) + + usize_vec_bytes(&factor.ind_s) + + usize_vec_bytes(&factor.ind_f) +} + +fn lu_factor_bytes(factor: &LuFactor) -> u64 { + factor_data_bytes(&factor.l_arr) + + factor_data_bytes(&factor.u_arr) + + usize_vec_bytes(&factor.ind_r) + + usize_vec_bytes(&factor.ind_t) +} + +fn diag_factor_bytes(factor: &DiagBoxFactor) -> u64 { + diag_box_arr_bytes(&factor.arr) + usize_vec_bytes(&factor.inds) +} + +fn prefer_direct_diag_extraction( + rows_len: usize, + subs_sample_dim: usize, + nonsymmetric_buffers: usize, +) -> bool { + const DIRECT_DIAG_BYTES_BUDGET: u64 = 8 * 1024 * 1024; + + let extracted = + matrix_bytes::(subs_sample_dim, rows_len).saturating_mul(nonsymmetric_buffers as u64); + let square = matrix_bytes::(rows_len, rows_len) + .saturating_mul((nonsymmetric_buffers.saturating_sub(1)) as u64); + + extracted.saturating_add(square) <= DIRECT_DIAG_BYTES_BUDGET +} + +fn symmetrize_square_in_place(arr: &mut DynamicArray, adjoint: bool) { + let shape = arr.shape(); + debug_assert_eq!(shape[0], shape[1]); + let half = Item::real(0.5); + + for row in 0..shape[0] { + for col in row..shape[1] { + let mirror = if adjoint { + arr[[col, row]].conj() + } else { + arr[[col, row]] + }; + let avg = (arr[[row, col]] + mirror).mul_real(half); + arr[[row, col]] = avg; + arr[[col, row]] = if adjoint { avg.conj() } else { avg }; + } + } +} + +impl PermFactor { + pub fn new(orig_indices: Vec, perm_indices: Vec) -> RlstResult { + Ok(Self { + orig_indices, + perm_indices, + }) + } + + pub fn left_mul< + Item: RlstScalar, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + right_arr: &mut Array, + options: &BaseFactorOptions, + ) { + assert_transpose_only_mode(options.trans, "PermFactor::left_mul"); + let orig_indices: Vec<_> = (0..right_arr.shape()[0]).collect(); + assert_eq!(orig_indices.len(), self.perm_indices.len()); + let trans = if options.inv { + let aux_options = options.transpose(); + aux_options.trans_val() + } else { + options.trans_val() + }; + + row_perm( + orig_indices.clone(), + self.perm_indices.clone(), + right_arr, + trans, + ); + } + + pub fn right_mul< + Item: RlstScalar, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + left_arr: &mut Array, + options: &BaseFactorOptions, + ) { + assert_transpose_only_mode(options.trans, "PermFactor::right_mul"); + let orig_indices: Vec<_> = (0..left_arr.shape()[1]).collect(); + assert_eq!(orig_indices.len(), self.perm_indices.len()); + let trans = if options.inv { + let aux_options = options.transpose(); + aux_options.trans_val() + } else { + options.trans_val() + }; + + col_perm( + orig_indices.clone(), + self.perm_indices.clone(), + left_arr, + trans, + ); + } + + pub fn stored_bytes(&self) -> u64 { + perm_factor_bytes(self) + } +} + +impl RsrsFactors { + pub fn memory_breakdown(&self) -> FactorMemoryBreakdown { + let mut breakdown = FactorMemoryBreakdown { + perm_bytes: self.perm_factor.stored_bytes(), + ..FactorMemoryBreakdown::default() + }; + + match &self.id_factors { + MultiLevelIdFactors::Single(levels) => { + for batch in levels { + for factor in batch { + breakdown.add_factor(factor); + } + } + } + MultiLevelIdFactors::Batched(levels) => { + for level in levels { + for batch in level { + for factor in batch { + breakdown.add_factor(factor); + } + } + } + } + } + + for level in &self.lu_factors { + for batch in level { + for factor in batch { + breakdown.add_factor(factor); + } + } + } + + for factor in &self.diag_box_factors { + breakdown.add_factor(factor); + } + + breakdown + } +} + +/// Helper to extract far indices when needed +fn get_far_indices(n: usize, near_indices: Vec) -> Vec { + let near_set: HashSet = near_indices.into_iter().collect(); + (0..n).filter(|x| !near_set.contains(x)).collect() +} + +/// FactorOperations: multiplication and inversion of +/// factors (ID, LU, block diagonal). +pub trait FactorOperations: Sized { + type Item: RlstScalar; + /// mul: manages parameters to call mul_data and ins_data + fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &mut Array, + options: &MulOptions, + ); + /// mul_data: takes information from target_data and + /// modifies it without changing target_data + fn mul_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &Array, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + ) -> DynamicArray; + + /// ins_data: uses the result in mul_data to modify target_data + fn ins_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + source_arr: &DynamicArray, + target_arr: &mut Array, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + ); +} + +/// Constructor of ID factors +impl< + Item: RlstScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + RandScalar + + MatrixLu + + MatrixQr, + > IdFactor +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + /// new: nullifies the near field to perform ID decomposition and return ID factors. + /// Arguments: + /// - target_inds: indices of the target box in the octree + /// - near_field_inds: indices of the near field associated to a give box + /// - y_data: stores the test matrix Ω and the associated sketch Y=AΩ + /// - z_data: stores the test matrix ψ and the associated sketch Z=A'ψ + /// - subs_sample_dim: useful when using less than available samples to + /// perform the decomposition + /// - rank_par: indicates hot to pick the rank given that a box has been + /// merged or not + /// - id_options: see IdOptions + /// - symmetric: indicates if A' should also be sketched or not + #[allow(clippy::too_many_arguments)] + pub fn new( + scratch: &mut ExtractionScratch, + target_inds: &mut [usize], + near_field_inds: &mut [usize], + y_data: &SketchData, + z_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + rank_par: &BoxType>, + id_options: &IdOptions, + symmetry: &Symmetry, + ) -> (Option, Times) + where + StandardNormal: Distribution, + Standard: Distribution, + LuDecomposition, 2>>: + MatrixLuDecomposition, + QrDecomposition, 2>>: + MatrixQrDecomposition, + { + let start: Instant = Instant::now(); + let test_shape = [subs_sample_dim, near_field_inds.len()]; + let sketch_shape = [subs_sample_dim, target_inds.len()]; + let null_shape = [test_shape[0] - test_shape[1], sketch_shape[1]]; + + // nullification of the near field + null_near_field_into( + target_inds, + near_field_inds, + y_data, + z_data, + subs_sample_dim, + symmetry, + fixed_rank, + id_options, + &mut scratch.primary, + &mut scratch.secondary, + &mut scratch.tertiary, + &mut scratch.normal, + ); + + let nullification_time: Duration = start.elapsed(); + let start: Instant = Instant::now(); + let max_rank: usize = *scratch.primary.shape().iter().min().unwrap(); + if max_rank <= 1 { + let id_times = IdTimes { + nullification: nullification_time.as_millis(), + id: 0, + }; + return (None, Times::Id(id_times)); + } + let id_sketch = match rank_par { + BoxType::Full(tol) => { + // for a box that hasn't been merged yet it + // applies ID with rank-revealing QR if a + // rank has not been prescribed + if *tol < num::One::one() { + scratch + .primary + .r_mut() + .into_subview([0, 0], null_shape) + .into_id_alloc_no_skel( + Accuracy::Tol(*tol), + id_options.qr_method.clone(), + TransMode::Trans, + ) + .unwrap() + } else { + let loc_rank = max_rank.min(num::ToPrimitive::to_usize(tol).unwrap()); + scratch + .primary + .r_mut() + .into_subview([0, 0], null_shape) + .into_id_alloc_no_skel( + Accuracy::FixedRank(loc_rank), + id_options.qr_method.clone(), + TransMode::Trans, + ) + .unwrap() + } + } + // for a box that has been merged + // rank-revealing QR is not necessary + BoxType::Merged(rank) => scratch + .primary + .r_mut() + .into_subview([0, 0], null_shape) + .into_id_alloc_no_skel( + Accuracy::FixedRank((*rank).min(max_rank)), + id_options.qr_method.clone(), + TransMode::Trans, + ) + .unwrap(), + }; + + let k: usize = id_sketch.rank; + let mut ind_r = Vec::new(); + let mut ind_s = Vec::new(); + + let id_time = start.elapsed(); + let id_times = IdTimes { + nullification: nullification_time.as_millis(), + id: id_time.as_millis(), + }; + + let times = Times::Id(id_times); + let factor_symmetry = if symmetry.complex_symmetric_val::() { + Symmetry::NoSymm + } else { + symmetry.clone() + }; + + // we check if the interactions can be compressed or if they should pass to the next level + if id_sketch.rank < max_rank { + let mut ind_f = get_far_indices(y_data.dim, near_field_inds.to_vec()); + if !ind_f.is_empty() { + let aux_indices = target_inds.to_vec(); + + for (id, &elem) in id_sketch.perm.iter().enumerate() { + let val = aux_indices[elem]; + target_inds[id] = val; + near_field_inds[id] = val; + } + } + + // we store the residual and skeleton indices + ind_r.extend_from_slice(&target_inds[k..]); + ind_s.extend_from_slice(&target_inds[..k]); + + // we store the far field indices only for debugging purposes + if !id_options.store_far { + ind_f.clear(); + ind_f.shrink_to_fit(); + } + + ( + Some(Self { + data: FactorData::Reg(RectArr { + arr: Box::new(id_sketch.id_mat), + }), + perm: id_sketch.perm, + symmetry: factor_symmetry, + ind_r, + ind_s, + ind_f, + }), + times, + ) + } else { + (None, times) + } + } + + /// Returns whether this factor application needs an explicit conjugation. + /// + /// `BaseFactorOptions` only tracks whether the factor is transposed. For + /// Hermitian and complex nonsymmetric factors, a transpose also changes + /// whether the extracted sketch data must be conjugated. + pub fn conj_val(&self, trans_val: bool) -> bool { + match self.symmetry { + Symmetry::NoSymm => !trans_val, + // Symmetric means A^T = A, not A^H = A, so no conjugation is needed. + Symmetry::Symmetric => false, + Symmetry::Hermitian => trans_val, + } + } + + fn fill_delta_with_scratch< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &Array, + side: &Side, + options: &BaseFactorOptions, + scratch: &mut FactorApplyScratch, + ) { + let conj_target = self.conj_val(options.trans_val()); + if conj_target { + let layout = factor_apply_layout(side, options, &self.ind_s, &self.ind_r); + let shape = target_arr.shape(); + trace_memory_event( + &format!( + "id_factor conj copy target_arr (rows={}, cols={})", + shape[0], shape[1] + ), + Some(matrix_bytes::(shape[0], shape[1])), + ); + trace_memory_growth( + "id_factor conj source slice", + Some(matrix_bytes::( + if layout.axis == 0 { + layout.source_indices.len() + } else { + shape[0] + }, + if layout.axis == 0 { + shape[1] + } else { + layout.source_indices.len() + }, + )), + ); + } + + self.data.delta_with_scratch( + target_arr, + side, + options, + &self.ind_s, + &self.ind_r, + conj_target, + scratch, + ); + + if conj_target { + conjugate_array_in_place(&mut scratch.result); + } + } + + fn apply_delta_with_scratch< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &mut Array, + side: &Side, + options: &BaseFactorOptions, + scratch: &mut FactorApplyScratch, + ) { + self.fill_delta_with_scratch(target_arr, side, options, scratch); + match side { + Side::Left => row_delta( + &self.ind_s, + &self.ind_r, + &scratch.result, + target_arr, + options, + options.inv, + ), + Side::Right => col_delta( + &self.ind_s, + &self.ind_r, + &scratch.result, + target_arr, + options, + options.inv, + ), + } + } + + unsafe fn apply_delta_with_scratch_raw< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + read_target_arr: &Array, + raw_target_arr: RawMatrixMut, + side: &Side, + options: &BaseFactorOptions, + scratch: &mut FactorApplyScratch, + ) { + self.fill_delta_with_scratch(read_target_arr, side, options, scratch); + match side { + Side::Left => unsafe { + row_delta_raw( + &self.ind_s, + &self.ind_r, + &scratch.result, + raw_target_arr, + options, + options.inv, + ) + }, + Side::Right => unsafe { + col_delta_raw( + &self.ind_s, + &self.ind_r, + &scratch.result, + raw_target_arr, + options, + options.inv, + ) + }, + } + } + + /// Returns the largest entry of the ID factor + pub fn cond(&self) -> (CondType, Option>) { + let (dim, max_entry) = match &self.data { + FactorData::Comp(_composed_factor_data) => todo!(), + FactorData::Reg(rectg) => { + let [rows, cols] = rectg.arr.r().shape(); + let dim = rows.min(cols); + let max_entry = rectg + .arr + .r() + .data() + .iter() + .map(|&x| x.abs()) + .fold(0.0, |a, b| { + let a_c: f64 = num::NumCast::from(a).unwrap(); + let b_c: f64 = num::NumCast::from(b).unwrap(); + a_c.max(b_c) + }); + (Item::real(dim), Item::real(max_entry)) + } + }; + + (self.data.cond(), Some(((dim, max_entry), None))) + } +} + +/// Implementation of multiplication and inversion of ID factors +impl< + Item: RlstScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + RandScalar + + MatrixLu + + MatrixQr, + > FactorOperations for IdFactor +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + type Item = Item; + + /// mul: manages parameters to call mul_data and ins_data + fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &mut Array, + options: &MulOptions, + ) { + // Transpose accordingly + // See section 4.3 in Yesypenko, A., & Martinsson, P. G. (2026). Randomized Strong Recursive Skeletonization: + // Simultaneous Compression and LU Factorization of Hierarchical Matrices using Matrix–Vector Products: + // A. Yesypenko, P.-G. Martinsson. Journal of Scientific Computing, 106(3), 63. + // For ID factors the stored "second" half is the transpose partner of + // the "first" half, so applying `S` means toggling the factor + // orientation before we touch the low-level delta kernels. + let aux_options = match options.factor_type { + FactorType::F => options.base_options.clone(), + FactorType::S => options.base_options.transpose(), + }; + let mut scratch = FactorApplyScratch::new(); + self.apply_delta_with_scratch(target_arr, &options.side, &aux_options, &mut scratch); + } + + /// mul_data: performs the elementary operation without changing target_arr + fn mul_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &Array, + side: &Side, + _factor_type: Option, + options: &BaseFactorOptions, + ) -> DynamicArray { + let mut scratch = FactorApplyScratch::new(); + self.fill_delta_with_scratch(target_arr, side, options, &mut scratch); + let layout = factor_apply_layout(side, options, &self.ind_s, &self.ind_r); + let mut subarr_target = empty_array(); + extract_axis_into( + &mut subarr_target, + target_arr, + layout.target_indices, + layout.axis, + layout.transposed, + ); + if options.inv { + subarr_target.sub_into(scratch.result.r()); + } else { + subarr_target.sum_into(scratch.result.r()); + } + subarr_target + } + + /// ins_data: modifies the corresponding entries in target_arr + fn ins_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + source_arr: &DynamicArray, + target_arr: &mut Array, + side: &Side, + _factor_type: Option, + options: &BaseFactorOptions, + ) { + match side { + Side::Left => { + row_subs( + self.ind_s.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ); + } + Side::Right => { + col_subs( + self.ind_s.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ); + } + } + } +} + +impl LuFactor +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + _scratch: &mut ExtractionScratch, + ind_r: &[usize], + near_field_inds: &[usize], + inactive_inds: &[usize], + y_data: &SketchData, + z_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + lu_options: &ExtractOptions, + symmetry: &Symmetry, + ) -> (Option, Times) + where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, + { + let mut r_numbering: Vec = Vec::new(); + let mut t_numbering: Vec = Vec::new(); + let mut ind_t = Vec::new(); + + let near_field_ind_to_num: HashMap<_, _> = near_field_inds + .iter() + .enumerate() + .map(|(num, ind)| (ind, num)) + .collect(); + for &elem in ind_r.iter() { + r_numbering.push(*near_field_ind_to_num.get(&elem).unwrap()); + } + + for (pos, &elem) in near_field_inds.iter().enumerate() { + if !ind_r.contains(&elem) && inactive_inds.binary_search(&elem).is_err() { + t_numbering.push(pos); + ind_t.push(elem); + } + } + let (mut y_data_r, mut y_data_n, (_y_lu_io_time, y_lu_b_ext_time)) = near_box_extraction( + ind_r, + near_field_inds, + y_data, + subs_sample_dim, + fixed_rank, + false, + lu_options, + &r_numbering, + &t_numbering, + ); + let start = Instant::now(); + let u_arr = + extract_lu_factor_from_blocks(&mut y_data_r, &mut y_data_n, &lu_options.pivot_method); + let u_assembly = start.elapsed(); + + let lu_b_ext_time; + let lu_assembly_time; + + let l_arr = + if !symmetry.factor_symm_val::() || symmetry.complex_symmetric_val::() { + let (secondary_data, conjugate_data) = if symmetry.complex_symmetric_val::() { + (y_data, true) + } else { + (z_data, false) + }; + let (mut z_data_r, mut z_data_n, (_z_lu_io_time, z_lu_b_ext_time)) = + near_box_extraction( + ind_r, + near_field_inds, + secondary_data, + subs_sample_dim, + fixed_rank, + conjugate_data, + lu_options, + &r_numbering, + &t_numbering, + ); + + let start = Instant::now(); + + let l_arr = extract_lu_factor_from_blocks( + &mut z_data_r, + &mut z_data_n, + &lu_options.pivot_method, + ); + let l_assembly = start.elapsed(); + lu_b_ext_time = y_lu_b_ext_time + z_lu_b_ext_time; + lu_assembly_time = u_assembly + l_assembly; + l_arr + } else { + lu_b_ext_time = y_lu_b_ext_time; + lu_assembly_time = u_assembly; + FactorData::Reg(RectArr { + arr: Box::new(empty_array()), + }) + }; + + let lu_times = LuTimes { + extraction: lu_b_ext_time.as_millis(), + lu: lu_assembly_time.as_millis(), + }; + + let times = Times::Lu(lu_times); + let factor_symmetry = if symmetry.complex_symmetric_val::() { + Symmetry::NoSymm + } else { + symmetry.clone() + }; + ( + Some(Self { + l_arr, + u_arr, + symmetry: factor_symmetry, + ind_r: ind_r.to_vec(), + ind_t, + }), + times, + ) + } + + /// Returns whether this LU update needs explicit conjugation. + /// + /// Symmetric factors reuse the same triangular data without conjugation, + /// while Hermitian factors use the same storage path but must conjugate + /// whenever the transpose orientation is requested. + pub fn conj_val(&self, trans_val: bool) -> bool { + match self.symmetry { + Symmetry::NoSymm => trans_val, + Symmetry::Symmetric => false, + Symmetry::Hermitian => trans_val, + } + } + + fn fill_delta_with_scratch< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &Array, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + scratch: &mut FactorApplyScratch, + ) { + let conj_target = self.conj_val(options.trans_val()); + if conj_target { + let layout = factor_apply_layout(side, options, &self.ind_t, &self.ind_r); + let shape = target_arr.shape(); + trace_memory_event( + &format!( + "lu_factor conj copy target_arr (rows={}, cols={})", + shape[0], shape[1] + ), + Some(matrix_bytes::(shape[0], shape[1])), + ); + trace_memory_growth( + "lu_factor conj source slice", + Some(matrix_bytes::( + if layout.axis == 0 { + layout.source_indices.len() + } else { + shape[0] + }, + if layout.axis == 0 { + shape[1] + } else { + layout.source_indices.len() + }, + )), + ); + } + + if self.symmetry.factor_symm_val::() { + self.u_arr.delta_with_scratch( + target_arr, + side, + options, + &self.ind_t, + &self.ind_r, + conj_target, + scratch, + ); + } else { + match factor_type { + Some(FactorType::F) => self.l_arr.delta_with_scratch( + target_arr, + side, + options, + &self.ind_t, + &self.ind_r, + conj_target, + scratch, + ), + Some(FactorType::S) => self.u_arr.delta_with_scratch( + target_arr, + side, + options, + &self.ind_t, + &self.ind_r, + conj_target, + scratch, + ), + None => todo!(), + }; + } + + if conj_target { + conjugate_array_in_place(&mut scratch.result); + } + } + + fn apply_delta_with_scratch< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &mut Array, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + scratch: &mut FactorApplyScratch, + ) { + self.fill_delta_with_scratch(target_arr, side, factor_type, options, scratch); + match side { + Side::Left => row_delta( + &self.ind_t, + &self.ind_r, + &scratch.result, + target_arr, + options, + options.inv, + ), + Side::Right => col_delta( + &self.ind_t, + &self.ind_r, + &scratch.result, + target_arr, + options, + options.inv, + ), + } + } + + unsafe fn apply_delta_with_scratch_raw< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + read_target_arr: &Array, + raw_target_arr: RawMatrixMut, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + scratch: &mut FactorApplyScratch, + ) { + self.fill_delta_with_scratch(read_target_arr, side, factor_type, options, scratch); + match side { + Side::Left => unsafe { + row_delta_raw( + &self.ind_t, + &self.ind_r, + &scratch.result, + raw_target_arr, + options, + options.inv, + ) + }, + Side::Right => unsafe { + col_delta_raw( + &self.ind_t, + &self.ind_r, + &scratch.result, + raw_target_arr, + options, + options.inv, + ) + }, + } + } + + pub fn cond(&self) -> (CondType, Option>) { + if !self.symmetry.factor_symm_val::() { + (self.l_arr.cond(), Some(self.u_arr.cond())) + } else { + (self.u_arr.cond(), None) + } + } +} + +impl FactorOperations + for LuFactor +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + type Item = Item; + + fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &mut Array, + options: &MulOptions, + ) { + // LU factors store the transpose relationship on the opposite half from + // ID factors, so `F` is the branch that flips orientation here. + let aux_options = match options.factor_type { + FactorType::F => options.base_options.transpose(), + FactorType::S => options.base_options.clone(), + }; + let mut scratch = FactorApplyScratch::new(); + self.apply_delta_with_scratch( + target_arr, + &options.side, + Some(options.factor_type.clone()), + &aux_options, + &mut scratch, + ); + } + + fn mul_data< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &Array, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + ) -> DynamicArray { + let mut scratch = FactorApplyScratch::new(); + self.fill_delta_with_scratch(target_arr, side, factor_type, options, &mut scratch); + let layout = factor_apply_layout(side, options, &self.ind_t, &self.ind_r); + let mut subarr_target = empty_array(); + extract_axis_into( + &mut subarr_target, + target_arr, + layout.target_indices, + layout.axis, + layout.transposed, + ); + if options.inv { + subarr_target.sub_into(scratch.result.r()); + } else { + subarr_target.sum_into(scratch.result.r()); + } + subarr_target + } + + fn ins_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + source_arr: &DynamicArray, + target_arr: &mut Array, + side: &Side, + factor_type: Option, + options: &BaseFactorOptions, + ) { + if self.symmetry.factor_symm_val::() { + match side { + Side::Left => row_subs( + self.ind_t.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ), + Side::Right => col_subs( + self.ind_t.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ), + }; + } else { + match factor_type { + Some(FactorType::F) => { + match side { + Side::Left => row_subs( + self.ind_t.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ), + Side::Right => col_subs( + self.ind_t.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ), + }; + } + Some(FactorType::S) => { + match side { + Side::Left => row_subs( + self.ind_t.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ), + Side::Right => col_subs( + self.ind_t.clone(), + self.ind_r.clone(), + source_arr, + target_arr, + options, + ), + }; + } + None => todo!(), + } + } + } +} + +impl DiagBoxArr +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + fn from_extracted( + diag_box: &DynamicArray, + db_ext_options: &ExtractOptions, + ) -> Self { + let shape = diag_box.shape(); + let transposed_bytes = matrix_bytes::(shape[1], shape[0]); + + match db_ext_options.pivot_method { + PivotMethod::DirectInversion => { + let mut arr = empty_array(); + trace_memory_event( + &format!( + "diag_box_arr direct transpose diag_box -> arr (box={}, samples={})", + shape[0], shape[1] + ), + Some(transposed_bytes), + ); + arr.fill_from_resize(diag_box.r().transpose()); + + trace_memory_event( + &format!( + "diag_box_arr direct clone arr -> inv_arr (box={}, samples={})", + shape[1], shape[0] + ), + Some(transposed_bytes), + ); + let mut inv_arr = empty_array(); + inv_arr.fill_from_resize(arr.r()); + inv_arr.r_mut().into_inverse_alloc().unwrap(); + let reg_arr = RegSMat { arr, inv_arr }; + DiagBoxArr::Reg(reg_arr) + } + PivotMethod::Lu(alpha) => { + let mut lu_input = empty_array(); + trace_memory_event( + &format!( + "diag_box_arr lu transpose diag_box -> lu_input (box={}, samples={})", + shape[0], shape[1] + ), + Some(transposed_bytes), + ); + lu_input.fill_from_resize(diag_box.r().transpose()); + add_diagonal(&mut lu_input, Item::real(alpha)); + let lu = ::into_lu_alloc(lu_input).unwrap(); + let square_factors = SquareLuFactors::from_lu(&lu).unwrap(); + let lu_arr = LuSMat { square_factors }; + DiagBoxArr::Lu(lu_arr) + } + PivotMethod::LuHybrid(alpha) => { + let mut lu_input = empty_array(); + trace_memory_event( + &format!( + "diag_box_arr lu hybrid transpose diag_box -> lu_input (box={}, samples={})", + shape[0], shape[1] + ), + Some(transposed_bytes), + ); + lu_input.fill_from_resize(diag_box.r().transpose()); + add_diagonal(&mut lu_input, Item::real(alpha)); + let mut arr = empty_array(); + arr.fill_from_resize(lu_input.r()); + let lu = ::into_lu_alloc(lu_input).unwrap(); + let mut inv_arr = rlst_dynamic_array2!(Item, [shape[1], shape[1]]); + add_diagonal(&mut inv_arr, num::One::one()); + as MatrixLuDecomposition>::solve_mat( + &lu, + TransMode::NoTrans, + inv_arr.r_mut(), + ) + .unwrap(); + DiagBoxArr::Reg(RegSMat { arr, inv_arr }) + } + } + } + + fn streamed_extraction_from_data( + inds: &[usize], + tol_lstsq: Real, + sketch_data: &SketchData, + subs_sample_dim: usize, + conjugate_data: bool, + test_chunk: &mut DynamicArray, + sketch_chunk: &mut DynamicArray, + ) -> DynamicArray { + let capped_samples = subs_sample_dim.min(sketch_data.test.shape()[0]); + let chunk_rows = streaming_chunk_rows::(capped_samples, inds.len() * 2, 2).max(1); + let mut accumulator = NormalEquationAccumulator::::new(inds.len(), inds.len()); + + for chunk in sketch_data.chunk_iter(subs_sample_dim, inds.len() * 2, 2) { + extract_axis_into(test_chunk, &chunk.test, inds, 1, false); + extract_axis_into(sketch_chunk, &chunk.sketch, inds, 1, false); + if conjugate_data { + conjugate_array_in_place(test_chunk); + conjugate_array_in_place(sketch_chunk); + } + accumulator.add_chunk(test_chunk, sketch_chunk); + } + + trace_memory_growth( + &format!( + "diag_box_extraction streamed chunks (size={}, samples={}, chunk_rows={chunk_rows})", + inds.len(), + capped_samples, + ), + Some( + matrix_bytes::(chunk_rows, inds.len()) * 2 + + matrix_bytes::(inds.len(), inds.len()) * 2, + ), + ); + + accumulator.solve(tol_lstsq) + } + + #[allow(clippy::too_many_arguments)] + fn new_with_scratch< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + Stride<2> + + RawAccess + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + inds: &[usize], + db_ext_options: &ExtractOptions, + sub_test: &Array, + sub_sketch: &Array, + test_c: &mut DynamicArray, + sketch_r: &mut DynamicArray, + diag_box: &mut DynamicArray, + symmetrize_adjoint: Option, + ) -> Self { + extract_axis_into(sketch_r, sub_sketch, inds, 1, false); + extract_axis_into(test_c, sub_test, inds, 1, false); + block_extraction_into(test_c, sketch_r, db_ext_options, diag_box); + if let Some(adjoint) = symmetrize_adjoint { + symmetrize_square_in_place(diag_box, adjoint); + } + trace_memory_growth( + &format!( + "diag_box_extraction symmetric (size={}, samples={})", + inds.len(), + sub_test.shape()[0] + ), + Some( + matrix_bytes::(sub_test.shape()[0], inds.len()) * 2 + + matrix_bytes::(inds.len(), inds.len()), + ), + ); + Self::from_extracted(diag_box, db_ext_options) + } + + #[allow(clippy::too_many_arguments)] + fn new_no_symm_with_scratch< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + Stride<2> + + RawAccess + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + inds: &[usize], + db_ext_options: &ExtractOptions, + y_sub_test: &Array, + y_sub_sketch: &Array, + z_sub_test: &Array, + z_sub_sketch: &Array, + y_conjugate: bool, + z_conjugate: bool, + y_test_c: &mut DynamicArray, + y_sketch_r: &mut DynamicArray, + z_test_c: &mut DynamicArray, + z_sketch_r: &mut DynamicArray, + diag_box: &mut DynamicArray, + ) -> Self { + extract_axis_into(y_sketch_r, y_sub_sketch, inds, 1, false); + extract_axis_into(y_test_c, y_sub_test, inds, 1, false); + if y_conjugate { + conjugate_array_in_place(y_sketch_r); + conjugate_array_in_place(y_test_c); + } + block_extraction_into(y_test_c, y_sketch_r, db_ext_options, diag_box); + + extract_axis_into(z_sketch_r, z_sub_sketch, inds, 1, false); + extract_axis_into(z_test_c, z_sub_test, inds, 1, false); + if z_conjugate { + conjugate_array_in_place(z_sketch_r); + conjugate_array_in_place(z_test_c); + } + block_extraction_into(z_test_c, z_sketch_r, db_ext_options, y_test_c); + + diag_box.sum_into(y_test_c.r().transpose().conj()); + trace_memory_growth( + &format!( + "diag_box_extraction nonsymmetric (size={}, samples={})", + inds.len(), + y_sub_test.shape()[0] + ), + Some( + matrix_bytes::(y_sub_test.shape()[0], inds.len()) * 4 + + matrix_bytes::(inds.len(), inds.len()) * 3, + ), + ); + + diag_box + .r_mut() + .scale_inplace(num::NumCast::from(0.5).unwrap()); + + Self::from_extracted(diag_box, db_ext_options) + } + + fn left_mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + Stride<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + right_arr: &mut Array, + factor_options: &BaseFactorOptions, + ) { + assert_transpose_only_mode(factor_options.trans, "DiagBoxArr::left_mul"); + match self { + DiagBoxArr::Reg(ref reg) => { + let mut new_right_arr = empty_array(); + if factor_options.inv { + new_right_arr.r_mut().mult_into_resize( + factor_options.trans, + TransMode::NoTrans, + num::One::one(), + reg.inv_arr.r(), + right_arr.r(), + num::Zero::zero(), + ); + } else { + new_right_arr.r_mut().mult_into_resize( + factor_options.trans, + TransMode::NoTrans, + num::One::one(), + reg.arr.r(), + right_arr.r(), + num::Zero::zero(), + ); + } + right_arr.r_mut().fill_from(new_right_arr.r()); + } + DiagBoxArr::Lu(ref lu) => { + if factor_options.inv { + lu.square_factors + .solve_mat(factor_options.trans, right_arr.r_mut()) + .unwrap(); + } else { + lu.square_factors + .mul_mat(factor_options.trans, right_arr.r_mut()) + .unwrap(); + } + } + } + } + + pub fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + Stride<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + arr: &mut Array, + side: &Side, + factor_options: &BaseFactorOptions, + ) { + assert_transpose_only_mode(factor_options.trans, "DiagBoxArr::mul"); + match side { + Side::Left => self.left_mul(arr, factor_options), + Side::Right => { + let shape = arr.shape(); + trace_memory_event( + &format!( + "diag_factor right transpose copy (rows={}, cols={})", + shape[1], shape[0] + ), + Some(matrix_bytes::(shape[1], shape[0])), + ); + let mut aux_arr = empty_array(); + aux_arr.r_mut().fill_from_resize(arr.r().transpose()); + let aux_factor_options = factor_options.transpose(); + self.left_mul(&mut aux_arr, &aux_factor_options); + arr.fill_from(aux_arr.r().transpose()); + } + } + } +} + +impl DiagBoxFactor +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + pub fn new( + rows: &mut [usize], + y_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + ) -> (Option, Times) { + let mut scratch = DiagExtractionScratch::new(); + Self::new_with_scratch( + rows.to_vec(), + y_data, + subs_sample_dim, + fixed_rank, + options, + &mut scratch, + ) + } + + pub(crate) fn new_with_scratch( + rows: Vec, + y_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + scratch: &mut DiagExtractionScratch, + ) -> (Option, Times) { + Self::new_with_scratch_impl( + rows, + y_data, + subs_sample_dim, + fixed_rank, + options, + scratch, + None, + ) + } + + pub(crate) fn new_symm_with_scratch( + rows: Vec, + y_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + scratch: &mut DiagExtractionScratch, + adjoint: bool, + ) -> (Option, Times) { + Self::new_with_scratch_impl( + rows, + y_data, + subs_sample_dim, + fixed_rank, + options, + scratch, + Some(adjoint), + ) + } + + fn new_with_scratch_impl( + rows: Vec, + y_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + scratch: &mut DiagExtractionScratch, + symmetrize_adjoint: Option, + ) -> (Option, Times) { + let diag_times = LuTimes { + //TODO: change this to diag_times + extraction: 0_u128, + lu: 0_u128, + }; + + let times = Times::Lu(diag_times); + + ( + Some(Self { + arr: if fixed_rank + && matches!( + options.block_extraction_method, + BlockExtractionMethod::LuLstSq + ) + && !prefer_direct_diag_extraction::(rows.len(), subs_sample_dim, 2) + { + { + let mut diag_box = DiagBoxArr::streamed_extraction_from_data( + &rows, + options.tol_lstsq, + y_data, + subs_sample_dim, + false, + &mut scratch.primary, + &mut scratch.secondary, + ); + if let Some(adjoint) = symmetrize_adjoint { + symmetrize_square_in_place(&mut diag_box, adjoint); + } + DiagBoxArr::from_extracted(&diag_box, options) + } + } else { + let (sub_test, sub_sketch) = ( + y_data + .test + .r() + .into_subview([0, 0], [subs_sample_dim, y_data.dim]), + y_data + .sketch + .r() + .into_subview([0, 0], [subs_sample_dim, y_data.dim]), + ); + DiagBoxArr::new_with_scratch( + &rows, + options, + &sub_test, + &sub_sketch, + &mut scratch.primary, + &mut scratch.secondary, + &mut scratch.tertiary, + symmetrize_adjoint, + ) + }, + inds: rows, + }), + times, + ) + } + + pub fn new_no_symm( + rows: &mut [usize], + y_data: &SketchData, + z_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + ) -> (Option, Times) { + let mut scratch = DiagExtractionScratch::new(); + Self::new_no_symm_with_scratch( + rows.to_vec(), + y_data, + z_data, + subs_sample_dim, + fixed_rank, + options, + &mut scratch, + ) + } + + pub(crate) fn new_no_symm_with_scratch( + rows: Vec, + y_data: &SketchData, + z_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + scratch: &mut DiagExtractionScratch, + ) -> (Option, Times) { + let diag_times = LuTimes { + //TODO: change this to diag_times + extraction: 0_u128, + lu: 0_u128, + }; + + let times = Times::Lu(diag_times); + + ( + Some(Self { + arr: if fixed_rank + && matches!( + options.block_extraction_method, + BlockExtractionMethod::LuLstSq + ) + && !prefer_direct_diag_extraction::(rows.len(), subs_sample_dim, 4) + { + let y_diag_box = DiagBoxArr::streamed_extraction_from_data( + &rows, + options.tol_lstsq, + y_data, + subs_sample_dim, + false, + &mut scratch.primary, + &mut scratch.secondary, + ); + let z_diag_box = DiagBoxArr::streamed_extraction_from_data( + &rows, + options.tol_lstsq, + z_data, + subs_sample_dim, + false, + &mut scratch.tertiary, + &mut scratch.quaternary, + ); + let mut diag_box = y_diag_box; + diag_box.sum_into(z_diag_box.r().transpose().conj()); + diag_box + .r_mut() + .scale_inplace(num::NumCast::from(0.5).unwrap()); + DiagBoxArr::from_extracted(&diag_box, options) + } else { + let (y_sub_test, y_sub_sketch) = ( + y_data + .test + .r() + .into_subview([0, 0], [subs_sample_dim, y_data.dim]), + y_data + .sketch + .r() + .into_subview([0, 0], [subs_sample_dim, y_data.dim]), + ); + + let (z_sub_test, z_sub_sketch) = ( + z_data + .test + .r() + .into_subview([0, 0], [subs_sample_dim, z_data.dim]), + z_data + .sketch + .r() + .into_subview([0, 0], [subs_sample_dim, z_data.dim]), + ); + + DiagBoxArr::new_no_symm_with_scratch( + &rows, + options, + &y_sub_test, + &y_sub_sketch, + &z_sub_test, + &z_sub_sketch, + false, + false, + &mut scratch.primary, + &mut scratch.secondary, + &mut scratch.tertiary, + &mut scratch.quaternary, + &mut scratch.quinary, + ) + }, + inds: rows, + }), + times, + ) + } + + pub(crate) fn new_complex_symm_with_scratch( + rows: Vec, + y_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + options: &ExtractOptions, + scratch: &mut DiagExtractionScratch, + ) -> (Option, Times) { + let diag_times = LuTimes { + extraction: 0_u128, + lu: 0_u128, + }; + + let times = Times::Lu(diag_times); + + ( + Some(Self { + arr: if fixed_rank + && matches!( + options.block_extraction_method, + BlockExtractionMethod::LuLstSq + ) + && !prefer_direct_diag_extraction::(rows.len(), subs_sample_dim, 4) + { + let y_diag_box = DiagBoxArr::streamed_extraction_from_data( + &rows, + options.tol_lstsq, + y_data, + subs_sample_dim, + false, + &mut scratch.primary, + &mut scratch.secondary, + ); + let z_diag_box = DiagBoxArr::streamed_extraction_from_data( + &rows, + options.tol_lstsq, + y_data, + subs_sample_dim, + true, + &mut scratch.tertiary, + &mut scratch.quaternary, + ); + let mut diag_box = y_diag_box; + diag_box.sum_into(z_diag_box.r().transpose().conj()); + diag_box + .r_mut() + .scale_inplace(num::NumCast::from(0.5).unwrap()); + DiagBoxArr::from_extracted(&diag_box, options) + } else { + let (y_sub_test, y_sub_sketch) = ( + y_data + .test + .r() + .into_subview([0, 0], [subs_sample_dim, y_data.dim]), + y_data + .sketch + .r() + .into_subview([0, 0], [subs_sample_dim, y_data.dim]), + ); + + DiagBoxArr::new_no_symm_with_scratch( + &rows, + options, + &y_sub_test, + &y_sub_sketch, + &y_sub_test, + &y_sub_sketch, + false, + true, + &mut scratch.primary, + &mut scratch.secondary, + &mut scratch.tertiary, + &mut scratch.quaternary, + &mut scratch.quinary, + ) + }, + inds: rows, + }), + times, + ) + } + + pub fn cond(&self) -> (CondType, Option>) { + match &self.arr { + DiagBoxArr::Reg(reg_dbox) => ((condition_number(®_dbox.arr), None), None), + DiagBoxArr::Lu(_) => (((num::Zero::zero(), num::Zero::zero()), None), None), + } + } +} + +impl FactorOperations + for DiagBoxFactor +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + type Item = Item; + + fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &mut Array, + factor_options: &MulOptions, + ) { + let target_block = self.mul_data( + target_arr, + &factor_options.side, + None, + &factor_options.base_options, + ); + let t_arr_mutex = std::sync::Mutex::new(target_arr); + self.ins_data( + &target_block, + *t_arr_mutex.lock().unwrap(), + &factor_options.side, + None, + &factor_options.base_options, + ); + } + + fn mul_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + target_arr: &Array, + side: &Side, + _factor_type: Option, + factor_options: &BaseFactorOptions, + ) -> DynamicArray { + match side { + Side::Left => { + let mut target_rows = ext_rows( + self.inds.clone(), + self.inds.clone(), + target_arr, + factor_options, + ); + self.arr.mul(&mut target_rows, &Side::Left, factor_options); + target_rows + } + Side::Right => { + let mut target_cols = ext_cols( + self.inds.clone(), + self.inds.clone(), + target_arr, + factor_options, + ); + + self.arr.mul(&mut target_cols, &Side::Right, factor_options); + target_cols + } + } + } + + fn ins_data< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item>, + >( + &self, + source_arr: &DynamicArray, + target_arr: &mut Array, + side: &Side, + _factor_type: Option, + factor_options: &BaseFactorOptions, + ) { + match side { + Side::Left => { + row_subs( + self.inds.clone(), + self.inds.clone(), + source_arr, + target_arr, + factor_options, + ); + } + Side::Right => { + col_subs( + self.inds.clone(), + self.inds.clone(), + source_arr, + target_arr, + factor_options, + ); + } + } + } +} + +pub trait CommutativeFactorsOperations: Sized { + type Item: RlstScalar; + fn new() -> Self; + fn add_factor(&mut self, factor: Factor); + fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + thread_pool: &ThreadPool, + num_threads: usize, + factor_options: &MulOptions, + ); + #[allow(clippy::type_complexity)] + fn get_condition_numbers(&self) -> Vec<(CondType, Option>)>; + fn flush(&mut self); +} + +#[cfg(debug_assertions)] +fn indices_disjoint(a: &[usize], b: &[usize]) -> bool { + let a_set: HashSet = a.iter().copied().collect(); + b.iter().all(|idx| !a_set.contains(idx)) +} + +#[cfg(debug_assertions)] +fn factor_apply_options( + factor: &Factor, + factor_options: &MulOptions, +) -> BaseFactorOptions { + match factor { + Factor::Lu(_) => match factor_options.factor_type { + FactorType::F => factor_options.base_options.transpose(), + FactorType::S => factor_options.base_options.clone(), + }, + Factor::Id(_) => match factor_options.factor_type { + FactorType::F => factor_options.base_options.clone(), + FactorType::S => factor_options.base_options.transpose(), + }, + Factor::Diag(_) => factor_options.base_options.clone(), + } +} + +#[cfg(debug_assertions)] +fn factor_read_write_indices( + factor: &Factor, + side: &Side, + base_options: &BaseFactorOptions, +) -> (Vec, Vec) { + let layout = match factor { + Factor::Lu(lu_factor) => { + factor_apply_layout(side, base_options, &lu_factor.ind_t, &lu_factor.ind_r) + } + Factor::Id(id_factor) => { + factor_apply_layout(side, base_options, &id_factor.ind_s, &id_factor.ind_r) + } + Factor::Diag(diag_factor) => { + factor_apply_layout(side, base_options, &diag_factor.inds, &diag_factor.inds) + } + }; + ( + layout.source_indices.to_vec(), + layout.target_indices.to_vec(), + ) +} + +#[cfg(debug_assertions)] +fn assert_chunk_noninterfering( + factors: &[Factor], + factor_options: &MulOptions, +) { + let footprints: Vec<_> = factors + .iter() + .map(|factor| { + let base_options = factor_apply_options(factor, factor_options); + factor_read_write_indices(factor, &factor_options.side, &base_options) + }) + .collect(); + + for left in 0..footprints.len() { + for right in (left + 1)..footprints.len() { + let (left_reads, left_writes) = &footprints[left]; + let (right_reads, right_writes) = &footprints[right]; + debug_assert!( + indices_disjoint(left_writes, right_writes), + "factor write sets overlap in parallel chunk" + ); + debug_assert!( + indices_disjoint(left_reads, right_writes), + "factor read/write sets overlap in parallel chunk" + ); + debug_assert!( + indices_disjoint(left_writes, right_reads), + "factor write/read sets overlap in parallel chunk" + ); + } + } +} + +impl< + Item: RlstScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + RandScalar + + MatrixLu + + MatrixQr, + > CommutativeFactorsOperations for CommutativeFactors +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + type Item = Item; + + fn new() -> Self { + Vec::new() + } + fn add_factor(&mut self, factor: Factor) { + self.push(factor); + } + + fn mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Self::Item> + + UnsafeRandomAccessByRef<2, Item = Self::Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + thread_pool: &ThreadPool, + num_threads: usize, + factor_options: &MulOptions, + ) where + Self: Sized, + { + let chunk_size = num_threads.max(1); + + for chunk_start in (0..self.len()).step_by(chunk_size) { + let chunk_end = (chunk_start + chunk_size).min(self.len()); + let factors = &self[chunk_start..chunk_end]; + let direct_update_chunk = factors + .iter() + .all(|factor| !matches!(factor, Factor::Diag(_))); + + if direct_update_chunk { + #[cfg(debug_assertions)] + assert_chunk_noninterfering(factors, factor_options); + + let raw_target = raw_matrix_mut(target_arr); + let read_target: &Array = + unsafe { &*(target_arr as *const Array) }; + + thread_pool.install(|| { + factors.par_iter().for_each_init( + FactorApplyScratch::::new, + |scratch, factor| match factor { + Factor::Lu(lu_factor) => { + let base_options = match factor_options.factor_type { + FactorType::F => factor_options.base_options.transpose(), + FactorType::S => factor_options.base_options.clone(), + }; + unsafe { + lu_factor.apply_delta_with_scratch_raw( + read_target, + raw_target, + &factor_options.side, + Some(factor_options.factor_type.clone()), + &base_options, + scratch, + ) + }; + } + Factor::Id(id_factor) => { + let base_options = match factor_options.factor_type { + FactorType::F => factor_options.base_options.clone(), + FactorType::S => factor_options.base_options.transpose(), + }; + unsafe { + id_factor.apply_delta_with_scratch_raw( + read_target, + raw_target, + &factor_options.side, + &base_options, + scratch, + ) + }; + } + Factor::Diag(_) => unreachable!("diag factors use the fallback path"), + }, + ) + }); + } else { + let updated_t_arr_blocks: Vec<_> = thread_pool.install(|| { + factors + .par_iter() + .enumerate() + .map(|(offset, factor)| { + let factor_ind = chunk_start + offset; + let target_block = match factor { + Factor::Lu(lu_factor) => lu_factor.mul_data( + target_arr, + &factor_options.side, + Some(factor_options.factor_type.clone()), + &match factor_options.factor_type { + FactorType::F => factor_options.base_options.transpose(), + FactorType::S => factor_options.base_options.clone(), + }, + ), + Factor::Id(id_factor) => id_factor.mul_data( + target_arr, + &factor_options.side, + None, + &match factor_options.factor_type { + FactorType::F => factor_options.base_options.clone(), + FactorType::S => factor_options.base_options.transpose(), + }, + ), + Factor::Diag(diag_factor) => diag_factor.mul_data( + target_arr, + &factor_options.side, + Some(factor_options.factor_type.clone()), + &factor_options.base_options, + ), + }; + (factor_ind, target_block) + }) + .collect() + }); + + for (factor_ind, target_block) in updated_t_arr_blocks { + let factor = &self[factor_ind]; + match factor { + Factor::Lu(lu_factor) => { + let base_options = match factor_options.factor_type { + FactorType::F => factor_options.base_options.transpose(), + FactorType::S => factor_options.base_options.clone(), + }; + lu_factor.ins_data( + &target_block, + target_arr, + &factor_options.side, + Some(factor_options.factor_type.clone()), + &base_options, + ) + } + Factor::Id(id_factor) => { + let base_options = match factor_options.factor_type { + FactorType::F => factor_options.base_options.clone(), + FactorType::S => factor_options.base_options.transpose(), + }; + id_factor.ins_data( + &target_block, + target_arr, + &factor_options.side, + Some(factor_options.factor_type.clone()), + &base_options, + ) + } + Factor::Diag(diag_factor) => diag_factor.ins_data( + &target_block, + target_arr, + &factor_options.side, + Some(factor_options.factor_type.clone()), + &factor_options.base_options, + ), + }; + } + } + } + } + + fn get_condition_numbers(&self) -> Vec<(CondType, Option>)> { + let condition_numbers: Vec<_> = self + .par_iter() + .enumerate() + .map(|(_factor_ind, factor)| match factor { + Factor::Lu(lu_factor) => lu_factor.cond(), + Factor::Id(id_factor) => id_factor.cond(), + Factor::Diag(diag_factor) => diag_factor.cond(), + }) + .collect(); + + condition_numbers + } + + fn flush(&mut self) { + self.clear(); + self.shrink_to_fit(); + } +} diff --git a/src/rsrs/rsrs_factors/mod.rs b/src/rsrs/rsrs_factors/mod.rs new file mode 100644 index 0000000..feb72cd --- /dev/null +++ b/src/rsrs/rsrs_factors/mod.rs @@ -0,0 +1,6 @@ +//! This crate handles the factors used in RSRS. It implements multiplication and inversion when needed +//! +pub mod base_factors; +pub mod commutative_factors; +pub mod null_and_extract; +pub mod rsrs_operator; diff --git a/src/rsrs/rsrs_factors/null_and_extract.rs b/src/rsrs/rsrs_factors/null_and_extract.rs new file mode 100644 index 0000000..6b70e41 --- /dev/null +++ b/src/rsrs/rsrs_factors/null_and_extract.rs @@ -0,0 +1,668 @@ +use std::time::{Duration, Instant}; + +use crate::{ + rsrs::{ + args::Symmetry, + rsrs_factors::base_factors::{ + conjugate_array_in_place, ComposedFactorData, FactorData, LuSMat, RectArr, RegSMat, + SquareArr, + }, + sketch::SketchData, + }, + utils::{ + data_ins_ext::extract_axis_into, + linear_algebra::{ + add_diagonal, block_extraction_into, nullify_near_sketch, BlockExtractionMethod, + NormalEquationScratch, NullMethod, + }, + memory::{matrix_bytes, trace_memory_event, trace_memory_growth}, + }, +}; +use rand_distr::{Distribution, Standard, StandardNormal}; +use rlst::{ + dense::{ + linalg::{interpolative_decomposition::MatrixIdNoSkel, lu::MatrixLu}, + tools::RandScalar, + }, + prelude::*, +}; +use serde::{Deserialize, Serialize}; + +type Real = ::Real; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "value")] +pub enum PivotMethod { + DirectInversion, + Lu(f64), //TODO: Change to Item + LuHybrid(f64), //TODO: Change to Item +} + +#[derive(Debug, Clone)] +pub struct ExtractOptions { + pub block_extraction_method: BlockExtractionMethod, + pub pivot_method: PivotMethod, + pub tol_lstsq: Real, +} + +#[derive(Debug, Clone)] +pub struct IdOptions { + pub null_method: NullMethod, + pub qr_method: RankRevealingQrType>, + pub tol_null: Real, + pub tol_id: Real, + pub store_far: bool, +} + +pub struct ExtractionScratch { + pub primary: DynamicArray, + pub secondary: DynamicArray, + pub tertiary: DynamicArray, + pub normal: NormalEquationScratch, +} + +impl ExtractionScratch { + pub fn new() -> Self { + Self { + primary: empty_array(), + secondary: empty_array(), + tertiary: empty_array(), + normal: NormalEquationScratch::new(), + } + } +} + +impl Default for ExtractionScratch { + fn default() -> Self { + Self::new() + } +} + +#[allow(clippy::too_many_arguments)] +fn null_sketch_near_field_into< + Item: RlstScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + RandScalar + + MatrixLu + + MatrixQr, +>( + target_inds: &[usize], + near_field_inds: &[usize], + sketch: &DynamicArray, + test: &DynamicArray, + subs_sample_dim: usize, + conjugate_data: bool, + id_options: &IdOptions, + sketch_t: &mut DynamicArray, + test_n: &mut DynamicArray, + normal_scratch: &mut NormalEquationScratch, +) where + StandardNormal: Distribution, + Standard: Distribution, + QrDecomposition, 2>>: + MatrixQrDecomposition, + LuDecomposition, 2>>: + MatrixLuDecomposition, +{ + let dim = test.shape()[1]; + let sub_test = test.r().into_subview([0, 0], [subs_sample_dim, dim]); + let sub_sketch = sketch.r().into_subview([0, 0], [subs_sample_dim, dim]); + extract_axis_into(sketch_t, &sub_sketch, target_inds, 1, false); + extract_axis_into(test_n, &sub_test, near_field_inds, 1, false); + if conjugate_data { + conjugate_array_in_place(sketch_t); + conjugate_array_in_place(test_n); + } + trace_memory_growth( + &format!( + "null_sketch_near_field buffers (targets={}, near={}, samples={subs_sample_dim})", + target_inds.len(), + near_field_inds.len() + ), + Some( + matrix_bytes::(subs_sample_dim, target_inds.len()) + + matrix_bytes::(subs_sample_dim, near_field_inds.len()), + ), + ); + nullify_near_sketch(test_n, sketch_t, id_options, normal_scratch); +} + +#[allow(clippy::too_many_arguments)] +pub fn null_near_field_into< + Item: RlstScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + RandScalar + + MatrixLu + + MatrixQr, +>( + target_inds: &[usize], + near_field_inds: &[usize], + y_data: &SketchData, + z_data: &SketchData, + subs_sample_dim: usize, + symmetry: &Symmetry, + _fixed_rank: bool, + id_options: &IdOptions, + far_field_sketch: &mut DynamicArray, + test_scratch: &mut DynamicArray, + aux_sketch: &mut DynamicArray, + normal_scratch: &mut NormalEquationScratch, +) where + StandardNormal: Distribution, + Standard: Distribution, + QrDecomposition, 2>>: + MatrixQrDecomposition, + LuDecomposition, 2>>: + MatrixLuDecomposition, +{ + let complex_symmetric = symmetry.complex_symmetric_val::(); + + if symmetry.symm_val() { + null_sketch_near_field_into( + target_inds, + near_field_inds, + &y_data.sketch, + &y_data.test, + subs_sample_dim, + false, + id_options, + far_field_sketch, + test_scratch, + normal_scratch, + ); + + if complex_symmetric { + null_sketch_near_field_into( + target_inds, + near_field_inds, + &y_data.sketch, + &y_data.test, + subs_sample_dim, + true, + id_options, + aux_sketch, + test_scratch, + normal_scratch, + ); + far_field_sketch.sum_into(aux_sketch.r()); + } + } else { + null_sketch_near_field_into( + target_inds, + near_field_inds, + &y_data.sketch, + &y_data.test, + subs_sample_dim, + false, + id_options, + far_field_sketch, + test_scratch, + normal_scratch, + ); + null_sketch_near_field_into( + target_inds, + near_field_inds, + &z_data.sketch, + &z_data.test, + subs_sample_dim, + false, + id_options, + aux_sketch, + test_scratch, + normal_scratch, + ); + far_field_sketch.sum_into(aux_sketch.r()); + } +} + +#[allow(clippy::too_many_arguments)] +pub fn null_near_field< + Item: RlstScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + RandScalar + + MatrixLu + + MatrixQr, +>( + target_inds: &[usize], + near_field_inds: &[usize], + y_data: &SketchData, + z_data: &SketchData, + subs_sample_dim: usize, + symmetry: &Symmetry, + fixed_rank: bool, + id_options: &IdOptions, +) -> DynamicArray +where + StandardNormal: Distribution, + Standard: Distribution, + QrDecomposition, 2>>: + MatrixQrDecomposition, + LuDecomposition, 2>>: + MatrixLuDecomposition, +{ + let mut scratch = ExtractionScratch::new(); + null_near_field_into( + target_inds, + near_field_inds, + y_data, + z_data, + subs_sample_dim, + symmetry, + fixed_rank, + id_options, + &mut scratch.primary, + &mut scratch.secondary, + &mut scratch.tertiary, + &mut scratch.normal, + ); + scratch.primary +} + +#[allow(clippy::too_many_arguments)] +pub fn near_box_extraction_into( + ind_r: &[usize], + near_field_inds: &[usize], + sketch_data: &SketchData, + subs_sample_dim: usize, + _fixed_rank: bool, + conjugate_data: bool, + lu_options: &ExtractOptions, + sample_r: &mut DynamicArray, + sample_n: &mut DynamicArray, + near_box: &mut DynamicArray, +) -> (Duration, Duration) +where + LuDecomposition, 2>>: + MatrixLuDecomposition, +{ + let dim = sketch_data.test.shape()[1]; + let test_subview = sketch_data + .test + .r() + .into_subview([0, 0], [subs_sample_dim, dim]); + let sketch_subview = sketch_data + .sketch + .r() + .into_subview([0, 0], [subs_sample_dim, dim]); + let start = Instant::now(); + extract_axis_into(sample_r, &sketch_subview, ind_r, 1, false); + extract_axis_into(sample_n, &test_subview, near_field_inds, 1, false); + if conjugate_data { + conjugate_array_in_place(sample_r); + conjugate_array_in_place(sample_n); + } + + let lu_io_time = start.elapsed(); + let start = Instant::now(); + block_extraction_into(sample_n, sample_r, lu_options, near_box); + trace_memory_growth( + &format!( + "near_box_extraction block (|r|={}, |near|={}, samples={subs_sample_dim})", + ind_r.len(), + near_field_inds.len() + ), + Some( + matrix_bytes::(subs_sample_dim, ind_r.len()) + + matrix_bytes::(subs_sample_dim, near_field_inds.len()) + + matrix_bytes::(near_field_inds.len(), ind_r.len()), + ), + ); + let lu_b_ext_time = start.elapsed(); + (lu_io_time, lu_b_ext_time) +} + +#[allow(clippy::too_many_arguments)] +pub fn near_box_extraction( + ind_r: &[usize], + near_field_inds: &[usize], + sketch_data: &SketchData, + subs_sample_dim: usize, + fixed_rank: bool, + conjugate_data: bool, + lu_options: &ExtractOptions, + r_numbering: &[usize], + t_numbering: &[usize], +) -> ( + DynamicArray, + DynamicArray, + (Duration, Duration), +) +where + LuDecomposition, 2>>: + MatrixLuDecomposition, +{ + let mut sample_r = empty_array(); + let mut sample_n = empty_array(); + let mut near_box = empty_array(); + let mut data_r = empty_array(); + let mut data_n = empty_array(); + let timings = near_box_extraction_into( + ind_r, + near_field_inds, + sketch_data, + subs_sample_dim, + fixed_rank, + conjugate_data, + lu_options, + &mut sample_r, + &mut sample_n, + &mut near_box, + ); + extract_axis_into(&mut data_r, &near_box, r_numbering, 0, true); + extract_axis_into(&mut data_n, &near_box, t_numbering, 0, true); + (data_r, data_n, timings) +} + +pub fn extract_lu_factor_from_blocks( + pivot_block: &mut DynamicArray, + rect_block: &mut DynamicArray, + pivot_method: &PivotMethod, +) -> FactorData +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + let pivot_shape = pivot_block.shape(); + let rect_shape = rect_block.shape(); + let pivot_bytes = matrix_bytes::(pivot_shape[0], pivot_shape[1]); + let rect_bytes = matrix_bytes::(rect_shape[0], rect_shape[1]); + + match pivot_method { + PivotMethod::DirectInversion => { + trace_memory_event( + &format!( + "extract_lu_factor direct reuse pivot block (r={}, samples={})", + pivot_shape[0], pivot_shape[1] + ), + Some(pivot_bytes), + ); + let mut arr = empty_array(); + std::mem::swap(&mut arr, pivot_block); + + trace_memory_event( + &format!( + "extract_lu_factor direct clone pivot -> inv_arr (r={}, samples={})", + pivot_shape[0], pivot_shape[1] + ), + Some(pivot_bytes), + ); + let mut y_r_inv = empty_array(); + y_r_inv.fill_from_resize(arr.r()); + y_r_inv.r_mut().into_inverse_alloc().unwrap(); + + trace_memory_event( + &format!( + "extract_lu_factor direct reuse rect block (t={}, samples={})", + rect_shape[1], rect_shape[0] + ), + Some(rect_bytes), + ); + let mut rectg = empty_array(); + std::mem::swap(&mut rectg, rect_block); + + trace_memory_growth( + &format!( + "extract_lu_factor direct inversion (r={}, t={})", + pivot_shape[0], rect_shape[1] + ), + Some(pivot_bytes * 2 + rect_bytes), + ); + let sq = RegSMat { + arr, + inv_arr: y_r_inv, + }; + let factor = ComposedFactorData { + sq: SquareArr::Reg(sq), + rectg: RectArr { + arr: Box::new(rectg), + }, + }; + FactorData::Comp(factor) + } + PivotMethod::Lu(alpha) => { + trace_memory_event( + &format!( + "extract_lu_factor lu reuse pivot block (r={}, samples={})", + pivot_shape[0], pivot_shape[1] + ), + Some(pivot_bytes), + ); + let mut lu_input = empty_array(); + std::mem::swap(&mut lu_input, pivot_block); + add_diagonal(&mut lu_input, Item::real(*alpha)); + let lu: LuDecomposition, 2>> = + ::into_lu_alloc(lu_input).unwrap(); + let square_factors = SquareLuFactors::from_lu(&lu).unwrap(); + trace_memory_growth( + &format!("extract_lu_factor lu workspace (r={})", pivot_shape[0]), + Some(pivot_bytes * 3 + rect_bytes), + ); + let lu_arr = LuSMat { square_factors }; + trace_memory_event( + &format!( + "extract_lu_factor lu reuse rect block (t={}, samples={})", + rect_shape[1], rect_shape[0] + ), + Some(rect_bytes), + ); + let mut rectg = empty_array(); + std::mem::swap(&mut rectg, rect_block); + let factor = ComposedFactorData { + sq: SquareArr::Lu(lu_arr), + rectg: RectArr { + arr: Box::new(rectg), + }, + }; + FactorData::Comp(factor) + } + PivotMethod::LuHybrid(alpha) => { + trace_memory_event( + &format!( + "extract_lu_factor lu hybrid reuse pivot block (r={}, samples={})", + pivot_shape[0], pivot_shape[1] + ), + Some(pivot_bytes), + ); + let mut lu_input = empty_array(); + std::mem::swap(&mut lu_input, pivot_block); + add_diagonal(&mut lu_input, Item::real(*alpha)); + let lu: LuDecomposition, 2>> = + ::into_lu_alloc(lu_input).unwrap(); + trace_memory_event( + &format!( + "extract_lu_factor lu hybrid reuse rect block (t={}, samples={})", + rect_shape[1], rect_shape[0] + ), + Some(rect_bytes), + ); + let mut rectg = empty_array(); + std::mem::swap(&mut rectg, rect_block); + as MatrixLuDecomposition>::solve_mat( + &lu, + TransMode::NoTrans, + rectg.r_mut(), + ) + .unwrap(); + trace_memory_growth( + &format!( + "extract_lu_factor lu hybrid solve update block (r={})", + pivot_shape[0] + ), + Some(pivot_bytes + rect_bytes), + ); + FactorData::Reg(RectArr { + arr: Box::new(rectg), + }) + } + } +} + +pub fn extract_lu_factor( + data_r: &DynamicArray, + data_n: &DynamicArray, + pivot_method: &PivotMethod, +) -> FactorData +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + let data_r_shape = data_r.shape(); + let data_n_shape = data_n.shape(); + + match pivot_method { + PivotMethod::DirectInversion => { + let data_r_trans_bytes = matrix_bytes::(data_r_shape[1], data_r_shape[0]); + let data_n_trans_bytes = matrix_bytes::(data_n_shape[1], data_n_shape[0]); + + trace_memory_event( + &format!( + "extract_lu_factor direct transpose data_r -> y_r_inv (r={}, samples={})", + data_r_shape[0], data_r_shape[1] + ), + Some(data_r_trans_bytes), + ); + let mut y_r_inv = empty_array(); + y_r_inv.r_mut().fill_from_resize(data_r.r().transpose()); + y_r_inv.r_mut().into_inverse_alloc().unwrap(); + + trace_memory_event( + &format!( + "extract_lu_factor direct transpose data_n -> rectg (t={}, samples={})", + data_n_shape[0], data_n_shape[1] + ), + Some(data_n_trans_bytes), + ); + let mut rectg = empty_array(); + rectg.fill_from_resize(data_n.r().transpose()); + + trace_memory_event( + &format!( + "extract_lu_factor direct transpose data_r -> arr (r={}, samples={})", + data_r_shape[0], data_r_shape[1] + ), + Some(data_r_trans_bytes), + ); + let mut arr = empty_array(); + arr.r_mut().fill_from_resize(data_r.r().transpose()); + trace_memory_growth( + &format!( + "extract_lu_factor direct inversion (r={}, t={})", + data_r_shape[0], data_n_shape[0] + ), + Some( + matrix_bytes::(data_r_shape[1], data_r_shape[0]) * 2 + + matrix_bytes::(data_n_shape[1], data_n_shape[0]), + ), + ); + let sq = RegSMat { + arr, + inv_arr: y_r_inv, + }; + let factor = ComposedFactorData { + sq: SquareArr::Reg(sq), + rectg: RectArr { + arr: Box::new(rectg), + }, + }; + FactorData::Comp(factor) + } + PivotMethod::Lu(alpha) => { + let shape = data_r.shape(); + let data_r_trans_bytes = matrix_bytes::(shape[1], shape[0]); + let data_n_trans_bytes = matrix_bytes::(data_n_shape[1], data_n_shape[0]); + + trace_memory_event( + &format!( + "extract_lu_factor lu transpose data_r -> lu_input (r={}, samples={})", + shape[0], shape[1] + ), + Some(data_r_trans_bytes), + ); + let mut data_r_trans = empty_array(); + data_r_trans.fill_from_resize(data_r.r().transpose()); + add_diagonal(&mut data_r_trans, Item::real(*alpha)); + let lu: LuDecomposition, 2>> = + ::into_lu_alloc(data_r_trans).unwrap(); + let square_factors = SquareLuFactors::from_lu(&lu).unwrap(); + trace_memory_growth( + &format!("extract_lu_factor lu workspace (r={})", shape[0]), + Some( + matrix_bytes::(shape[0], shape[1]) * 3 + + matrix_bytes::(data_n_shape[1], data_n_shape[0]), + ), + ); + let lu_arr = LuSMat { square_factors }; + let sq = SquareArr::Lu(lu_arr); + trace_memory_event( + &format!( + "extract_lu_factor lu transpose data_n -> rectg (t={}, samples={})", + data_n_shape[0], data_n_shape[1] + ), + Some(data_n_trans_bytes), + ); + let mut rectg = empty_array(); + rectg.fill_from_resize(data_n.r().transpose()); + let factor = ComposedFactorData { + sq, + rectg: RectArr { + arr: Box::new(rectg), + }, + }; + FactorData::Comp(factor) + } + PivotMethod::LuHybrid(alpha) => { + let shape = data_r.shape(); + let data_r_trans_bytes = matrix_bytes::(shape[1], shape[0]); + let data_n_trans_bytes = matrix_bytes::(data_n_shape[1], data_n_shape[0]); + + trace_memory_event( + &format!( + "extract_lu_factor lu hybrid transpose data_r -> lu_input (r={}, samples={})", + shape[0], shape[1] + ), + Some(data_r_trans_bytes), + ); + let mut data_r_trans = empty_array(); + data_r_trans.fill_from_resize(data_r.r().transpose()); + add_diagonal(&mut data_r_trans, Item::real(*alpha)); + let lu: LuDecomposition, 2>> = + ::into_lu_alloc(data_r_trans).unwrap(); + trace_memory_event( + &format!( + "extract_lu_factor lu hybrid transpose data_n -> rectg (t={}, samples={})", + data_n_shape[0], data_n_shape[1] + ), + Some(data_n_trans_bytes), + ); + let mut rectg = empty_array(); + rectg.fill_from_resize(data_n.r().transpose()); + as MatrixLuDecomposition>::solve_mat( + &lu, + TransMode::NoTrans, + rectg.r_mut(), + ) + .unwrap(); + trace_memory_growth( + &format!( + "extract_lu_factor lu hybrid solve update block (r={})", + shape[0] + ), + Some(data_r_trans_bytes + data_n_trans_bytes), + ); + FactorData::Reg(RectArr { + arr: Box::new(rectg), + }) + } + } +} diff --git a/src/rsrs/rsrs_factors/rsrs_operator.rs b/src/rsrs/rsrs_factors/rsrs_operator.rs new file mode 100644 index 0000000..c24b5cb --- /dev/null +++ b/src/rsrs/rsrs_factors/rsrs_operator.rs @@ -0,0 +1,1055 @@ +use std::{rc::Rc, time::Instant}; + +use crate::rsrs::{ + rsrs_factors::{ + base_factors::{BaseFactorOptions, CondType}, + commutative_factors::{ + CommutativeFactorsOperations, DiagBoxFactors, FactorType, MulOptions, + MultiLevelIdFactors, PermFactor, RsrsFactors, + }, + }, + sketch::SamplingSpace, +}; +use crate::utils::memory::{matrix_bytes, trace_memory_event}; +use mpi::{ + topology::SimpleCommunicator, + traits::{Communicator, Equivalence}, +}; +use rand_distr::{Distribution, Standard, StandardNormal}; +use rayon::ThreadPoolBuilder; +use rlst::{ + dense::{ + linalg::{interpolative_decomposition::MatrixIdNoSkel, lu::MatrixLu}, + tools::RandScalar, + }, + prelude::*, +}; +use serde::Deserialize; + +/// Chooses whether each RSRS level stores joint factor batches or separate ID +/// and LU batches. +#[derive(Debug, Clone, Deserialize)] +pub enum FactType { + Joint, + Split, +} + +/// Internal description of how a sequence of elementary factors is traversed. +#[derive(PartialEq)] +pub enum RsrsApply { + Sandwich, + Left(FactorType), + Right(FactorType), +} + +impl RsrsApply { + fn get_level_factors_mult(&self, base_options: BaseFactorOptions) -> MulOptions { + match self { + RsrsApply::Sandwich => { + MulOptions { + side: Side::Left, + factor_type: FactorType::F, + base_options, // Originally trans_target = false + } + } + RsrsApply::Left(factor_type) => MulOptions { + side: Side::Left, + factor_type: factor_type.clone(), + base_options, + }, + RsrsApply::Right(factor_type) => MulOptions { + side: Side::Right, + factor_type: factor_type.clone(), + base_options, + }, + } + } +} + +pub struct MultiLevelFactorsMult { + pub side: Side, + pub factor_type: FactorType, + pub trans_target: bool, +} + +/// Backend trait for factor containers that can be exposed through +/// [`RsrsOperator`]. +pub trait RsrsFactorsImpl: Sized { + fn new( + num_levels: usize, + dim: usize, + factorisation_type: &FactType, + num_threads: usize, + ) -> Self; + + fn apply_level< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + level_options: &MulOptions, + dec: bool, + level_it: usize, + ) -> (u128, u128); + + fn apply_id_level< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + factor_options: &MulOptions, + level_it: usize, + ); + + fn apply_lu_level< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + factor_options: &MulOptions, + dec: bool, + level_it: usize, + ); + + fn el_factors_mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + mul_type: RsrsApply, + base_options: &BaseFactorOptions, + dec: bool, + ); + + fn matmul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &mut self, + target_arr: &mut Array, + side: Side, + base_options: &BaseFactorOptions, //inv: bool, + //trans_target: bool, + ); + + /// Apply the factored operator to a vector. + /// + /// `side` distinguishes the column-vector (`Left`) and row-vector (`Right`) + /// views used by the higher-level operator wrapper. `base_options.trans` + /// refers to the factor orientation, not to the vector layout. + fn matvec(&self, x: &[Item], y: &mut [Item], side: Side, base_options: &BaseFactorOptions); + + fn perm_target_array< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &mut Array, + ); + + fn dim(&self) -> usize; + + #[allow(clippy::type_complexity)] + fn get_condition_numbers( + &self, + ) -> ( + Vec, Option>)>>, + Vec, Option>)>>, + Vec<(CondType, Option>)>, + ); + + fn get_factors(&self) -> &RsrsFactors; +} + +impl< + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + > RsrsFactorsImpl for RsrsFactors +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + fn new(num_levels: usize, dim: usize, fact_type: &FactType, num_threads: usize) -> Self { + let id_factors = match fact_type { + FactType::Joint => { + let mut factors: MultiLevelIdFactors = + MultiLevelIdFactors::Batched(Vec::new()); + if let MultiLevelIdFactors::Batched(ref mut v) = factors { + v.resize_with(num_levels, Vec::new); + } + factors + } + FactType::Split => { + let mut factors: MultiLevelIdFactors = + MultiLevelIdFactors::Single(Vec::new()); + if let MultiLevelIdFactors::Single(ref mut v) = factors { + v.resize_with(num_levels, Vec::new); + } + factors + } + }; + + let mut lu_factors = Vec::new(); + lu_factors.resize_with(num_levels, Vec::new); + let mut near_field_inds: Vec>> = Vec::new(); + near_field_inds.resize_with(num_levels, Vec::new); + let orig_indices = Vec::new(); + let perm_indices = Vec::new(); + let perm_factor = PermFactor::new(orig_indices, perm_indices).unwrap(); + let diag_box_factors = DiagBoxFactors::new(); + Self { + num_levels, + //near_field_inds, + id_factors, + lu_factors, + perm_factor, + diag_box_factors, + dim, + fact_type: fact_type.clone(), + num_threads, + } + } + + fn dim(&self) -> usize { + self.dim + } + + fn apply_level< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + level_options: &MulOptions, + dec: bool, + level_it: usize, + ) -> (u128, u128) { + let thread_pool = ThreadPoolBuilder::new() + .num_threads(self.num_threads) + .build() + .unwrap(); + let mut id_time = 0; + let mut lu_time = 0; + match &self.id_factors { + MultiLevelIdFactors::Single(_id_batches) => { + panic!("Apply level is only for joint steps") + } + MultiLevelIdFactors::Batched(batched_factors) => { + if let Some(id_batches) = batched_factors.get(level_it) { + let num_id_batches = id_batches.len(); + if dec { + (0..num_id_batches).rev().for_each(|batch_ind| { + let start = Instant::now(); + let lu_batch = &self.lu_factors[level_it][batch_ind]; + lu_batch.mul(target_arr, &thread_pool, self.num_threads, level_options); + lu_time += start.elapsed().as_millis(); + + let start = Instant::now(); + let id_batch = &id_batches[batch_ind]; + id_batch.mul(target_arr, &thread_pool, self.num_threads, level_options); + id_time += start.elapsed().as_millis(); + }); + } else { + (0..num_id_batches).for_each(|batch_ind| { + let start = Instant::now(); + let id_batch = &id_batches[batch_ind]; + id_batch.mul(target_arr, &thread_pool, self.num_threads, level_options); + id_time += start.elapsed().as_millis(); + + let start = Instant::now(); + let lu_batch = &self.lu_factors[level_it][batch_ind]; + lu_batch.mul(target_arr, &thread_pool, self.num_threads, level_options); + lu_time += start.elapsed().as_millis(); + }); + } + } + } + } + (id_time, lu_time) + } + + fn apply_id_level< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + factor_options: &MulOptions, + level_it: usize, + ) { + let thread_pool = ThreadPoolBuilder::new() + .num_threads(self.num_threads) + .build() + .unwrap(); + match &self.id_factors { + MultiLevelIdFactors::Single(id_batches) => { + if let Some(id_batch) = id_batches.get(level_it) { + id_batch.mul(target_arr, &thread_pool, self.num_threads, factor_options); + } + } + MultiLevelIdFactors::Batched(_batched_factors) => { + panic!("Apply ID level is only for split steps") + } + } + } + + fn apply_lu_level< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + factor_options: &MulOptions, + dec: bool, + level_it: usize, + ) { + let thread_pool = ThreadPoolBuilder::new() + .num_threads(self.num_threads) + .build() + .unwrap(); + let num_lu_batches = self.lu_factors[level_it].len(); + + if dec { + (0..num_lu_batches).rev().for_each(|batch_ind| { + let lu_batch = &self.lu_factors[level_it][batch_ind]; + lu_batch.mul(target_arr, &thread_pool, self.num_threads, factor_options); + }); + } else { + (0..num_lu_batches).for_each(|batch_ind| { + let lu_batch = &self.lu_factors[level_it][batch_ind]; + lu_batch.mul(target_arr, &thread_pool, self.num_threads, factor_options); + }); + } + } + + fn el_factors_mul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &self, + target_arr: &mut Array, + mul_type: RsrsApply, + base_options: &BaseFactorOptions, + dec: bool, + ) { + let levels = (0..self.num_levels).collect::>(); + if matches!(mul_type, RsrsApply::Sandwich) { + let left_options = MulOptions { + base_options: base_options.clone(), + side: Side::Left, + factor_type: FactorType::F, + }; + let right_options = MulOptions { + base_options: base_options.clone(), + side: Side::Right, + factor_type: FactorType::S, + }; + + match self.fact_type { + FactType::Joint => levels.iter().for_each(|&level_it| { + self.apply_level(target_arr, &left_options, dec, level_it); + self.apply_level(target_arr, &right_options, dec, level_it); + }), + FactType::Split => levels.iter().for_each(|&level_it| { + self.apply_id_level(target_arr, &left_options, level_it); + self.apply_id_level(target_arr, &right_options, level_it); + self.apply_lu_level(target_arr, &left_options, dec, level_it); + self.apply_lu_level(target_arr, &right_options, dec, level_it); + }), + } + } else { + let mul_options = mul_type.get_level_factors_mult(base_options.clone()); + match self.fact_type { + FactType::Joint => { + if dec { + levels.iter().rev().for_each(|&level_it| { + self.apply_level(target_arr, &mul_options, dec, level_it); + }); + } else { + levels.iter().for_each(|&level_it| { + self.apply_level(target_arr, &mul_options, dec, level_it); + }); + } + } + FactType::Split => { + if dec { + levels.iter().rev().for_each(|&level_it| { + self.apply_lu_level(target_arr, &mul_options, dec, level_it); + self.apply_id_level(target_arr, &mul_options, level_it); + }); + } else { + levels.iter().for_each(|&level_it| { + self.apply_id_level(target_arr, &mul_options, level_it); + self.apply_lu_level(target_arr, &mul_options, dec, level_it); + }); + } + } + } + } + } + + fn matmul< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, + >( + &mut self, + target_arr: &mut Array, + side: Side, + base_options: &BaseFactorOptions, + ) { + let thread_pool = ThreadPoolBuilder::new() + .num_threads(self.num_threads) + .build() + .unwrap(); + let diag_mul = MulOptions { + base_options: base_options.clone(), + side, + factor_type: FactorType::F, //TODO: CHECK IF CORRECT + }; + // Transposing the full operator reverses the elimination order in the + // same way as inversion. The low-level kernels still receive the + // original `base_options`; this flag only chooses the traversal order. + let effective_inv = base_options.inv ^ base_options.trans_val(); + + match side { + Side::Left => { + let (mul_type_1, mul_type_2) = if !effective_inv { + ( + RsrsApply::Left(FactorType::S), //, trans_target), + RsrsApply::Left(FactorType::F), //, trans_target), + ) + } else { + ( + RsrsApply::Left(FactorType::F), //, trans_target), + RsrsApply::Left(FactorType::S), //, trans_target), + ) + }; + self.el_factors_mul(target_arr, mul_type_1, base_options, false); + self.diag_box_factors + .mul(target_arr, &thread_pool, self.num_threads, &diag_mul); + self.el_factors_mul(target_arr, mul_type_2, base_options, true); + } + Side::Right => { + let (mul_type_1, mul_type_2) = if !effective_inv { + ( + RsrsApply::Right(FactorType::F), //, trans_target), + RsrsApply::Right(FactorType::S), //, trans_target), + ) + } else { + ( + RsrsApply::Right(FactorType::S), //, trans_target), + RsrsApply::Right(FactorType::F), //, trans_target), + ) + }; + + self.el_factors_mul(target_arr, mul_type_1, base_options, false); + self.diag_box_factors + .mul(target_arr, &thread_pool, self.num_threads, &diag_mul); + self.el_factors_mul(target_arr, mul_type_2, base_options, true); + } + } + } + + fn matvec(&self, x: &[Item], y: &mut [Item], side: Side, base_options: &BaseFactorOptions) { + if base_options.trans_val() && !base_options.trans_target { + // For vectors, A^T x is (x^T A)^T, so reuse the opposite-side path + // instead of pushing a transpose through each elementary factor. + let mut normalized_options = base_options.clone(); + normalized_options.trans = TransMode::NoTrans; + let normalized_side = match side { + Side::Left => Side::Right, + Side::Right => Side::Left, + }; + self.matvec(x, y, normalized_side, &normalized_options); + return; + } + + let thread_pool = ThreadPoolBuilder::new() + .num_threads(self.num_threads) + .build() + .unwrap(); + let diag_mul = MulOptions { + base_options: base_options.clone(), + side, + factor_type: FactorType::F, //TODO: CHECK IF CORRECT + }; + // Direct factor matvec diagnostics exercise the transposed factor path. + // Just as in `matmul`, transposing the operator reverses the order of + // the elementary elimination steps. + let effective_inv = base_options.inv ^ base_options.trans_val(); + + let target_arr = match side { + Side::Left => { + let mut target_arr = rlst_dynamic_array2!(Item, [x.len(), 1]); + for (i, val) in x.iter().enumerate() { + target_arr.r_mut()[[i, 0]] = *val; + } + + let (mul_type_1, mul_type_2) = if !effective_inv { + ( + RsrsApply::Left(FactorType::S), //, false), + RsrsApply::Left(FactorType::F), //, false), + ) + } else { + ( + RsrsApply::Left(FactorType::F), //, false), + RsrsApply::Left(FactorType::S), //, false), + ) + }; + + self.el_factors_mul(&mut target_arr, mul_type_1, base_options, false); + self.diag_box_factors.mul( + &mut target_arr, + &thread_pool, + self.num_threads, + &diag_mul, + ); + self.el_factors_mul(&mut target_arr, mul_type_2, base_options, true); + target_arr + } + Side::Right => { + let mut target_arr = rlst_dynamic_array2!(Item, [1, x.len()]); + + for (i, val) in x.iter().enumerate() { + target_arr.r_mut()[[0, i]] = *val; + } + + let (mul_type_1, mul_type_2) = if !effective_inv { + ( + RsrsApply::Right(FactorType::F), //, false), + RsrsApply::Right(FactorType::S), //, false), + ) + } else { + ( + RsrsApply::Right(FactorType::S), //, false), + RsrsApply::Right(FactorType::F), //, false), + ) + }; + + self.el_factors_mul(&mut target_arr, mul_type_1, base_options, false); + self.diag_box_factors.mul( + &mut target_arr, + &thread_pool, + self.num_threads, + &diag_mul, + ); + self.el_factors_mul(&mut target_arr, mul_type_2, base_options, true); + target_arr + } + }; + + for (i, val) in target_arr.r().iter().enumerate() { + y[i] = val; + } + } + + fn perm_target_array< + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, + >( + &self, + target_arr: &mut Array, + ) { + let mul_options = BaseFactorOptions { + inv: false, + trans: TransMode::NoTrans, + trans_target: false, + }; + self.perm_factor.left_mul(target_arr, &mul_options); + let shape = target_arr.shape(); + trace_memory_event( + &format!( + "perm_target_array transpose copy (rows={}, cols={})", + shape[1], shape[0] + ), + Some(matrix_bytes::(shape[1], shape[0])), + ); + let mut aux_arr = empty_array(); + aux_arr.r_mut().fill_from_resize(target_arr.r().transpose()); + self.perm_factor.left_mul(&mut aux_arr, &mul_options); + target_arr.r_mut().fill_from(aux_arr.r().transpose()); + } + + fn get_condition_numbers( + &self, + ) -> ( + Vec, Option>)>>, + Vec, Option>)>>, + Vec<(CondType, Option>)>, + ) { + let mut id_condition_numbers = Vec::new(); + let mut lu_condition_numbers = Vec::new(); + + match &self.id_factors { + MultiLevelIdFactors::Single(id_batches) => { + for id_batch in id_batches.iter() { + id_condition_numbers.push(id_batch.get_condition_numbers()); + } + } + MultiLevelIdFactors::Batched(batched_factors) => { + for batch in batched_factors.iter() { + for id_batch in batch.iter() { + id_condition_numbers.push(id_batch.get_condition_numbers()); + } + } + } + } + + for lu_level_batches in self.lu_factors.iter() { + let mut lu_level_condition_numbers = Vec::new(); + for lu_batch in lu_level_batches.iter() { + lu_level_condition_numbers.extend_from_slice(&lu_batch.get_condition_numbers()); + } + lu_condition_numbers.push(lu_level_condition_numbers); + } + + let diag_condition_numbers = self.diag_box_factors.get_condition_numbers(); + + ( + id_condition_numbers, + lu_condition_numbers, + diag_condition_numbers, + ) + } + + fn get_factors(&self) -> &Self { + self + } +} + +impl Shape<2> for RsrsFactors { + fn shape(&self) -> [usize; 2] { + [self.dim, self.dim] + } +} + +/// Operator wrapper that exposes RSRS factors through the `rlst` operator +/// traits. +pub struct RsrsOperator< + 'a, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Space: SamplingSpace, + Op: RsrsFactorsImpl + Shape<2>, +> { + pub op: &'a Op, + domain: Rc, + range: Rc, + inv: bool, +} + +impl< + 'a, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Space: SamplingSpace + LinearSpace, + Op: RsrsFactorsImpl + Shape<2>, + > RsrsOperator<'a, Item, Space, Op> +{ + /// Returns the underlying RSRS factors. + pub fn get_factors(&self) -> &RsrsFactors { + self.op.get_factors() + } + + /// Returns condition-number diagnostics gathered while building the + /// factors. + #[allow(clippy::type_complexity)] + pub fn get_condition_numbers( + &self, + ) -> ( + Vec, Option>)>>, + Vec, Option>)>>, + Vec<(CondType, Option>)>, + ) { + self.op.get_condition_numbers() + } + + /// Normalizes `rlst::TransMode` into the lower-level factor interface. + /// + /// The factor kernels only need plain transpose orientation together with a + /// left/right application choice. Vector transpose and conjugate-transpose + /// products are therefore rewritten using equivalent left/right identities + /// plus explicit input/output conjugation when needed. + fn apply_vec_mode(&self, x: &[Item], y: &mut [Item], trans_mode: TransMode) { + let base_options = BaseFactorOptions { + inv: self.inv, + trans: TransMode::NoTrans, + trans_target: false, + }; + + match trans_mode { + TransMode::NoTrans => { + self.op.matvec(x, y, Side::Left, &base_options); + } + TransMode::Trans => { + // A^T x = (x^T A)^T, so vector transpose application can reuse + // the right-apply path without pushing `Trans` into each factor. + self.op.matvec(x, y, Side::Right, &base_options); + } + TransMode::ConjNoTrans => { + // conj(A) x = conj(A conj(x)) + let input = x.iter().map(|value| value.conj()).collect::>(); + self.op.matvec(&input, y, Side::Left, &base_options); + y.iter_mut().for_each(|value| *value = value.conj()); + } + TransMode::ConjTrans => { + // conj(A^T) x = conj((conj(x)^T A)^T) + let input = x.iter().map(|value| value.conj()).collect::>(); + self.op.matvec(&input, y, Side::Right, &base_options); + y.iter_mut().for_each(|value| *value = value.conj()); + } + } + } +} + +// Implement OperatorBase for RsrsOperator so it can be used with rlst::Operator +impl< + 'a, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Space: SamplingSpace + LinearSpace, + Op: RsrsFactorsImpl + Shape<2>, + > OperatorBase for RsrsOperator<'a, Item, Space, Op> +{ + type Domain = Space; + type Range = Space; + + fn domain(&self) -> Rc { + self.domain.clone() + } + + fn range(&self) -> Rc { + self.range.clone() + } +} + +impl< + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Space: SamplingSpace, + Op: RsrsFactorsImpl + Shape<2>, + > std::fmt::Debug for RsrsOperator<'_, Item, Space, Op> +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let shape = self.op.shape(); + write!(f, "RsrsOperator: [{}x{}]", shape[0], shape[1]).unwrap(); + Ok(()) + } +} + +/// Helper trait for constructing an [`RsrsOperator`] from already-allocated +/// domain/range spaces. +pub trait LocalFromSpaces< + 'a, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Space, + Op, +>: Sized +{ + fn from_local_spaces(op: &'a Op, domain: Rc, range: Rc) -> Self; +} + +/// Trait used by iterative solvers to toggle inverse application on an +/// operator. +pub trait Inv { + fn inv(&mut self, inv: bool); +} + +impl< + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Space: SamplingSpace, + Op: RsrsFactorsImpl + Shape<2>, + > Inv for RsrsOperator<'_, Item, Space, Op> +where + StandardNormal: Distribution<::Real>, + Standard: Distribution<::Real>, + ::Real: RandScalar, +{ + fn inv(&mut self, inv: bool) { + self.inv = inv; + } +} + +impl< + 'a, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Op: RsrsFactorsImpl + Shape<2>, + > LocalFromSpaces<'a, Item, ArrayVectorSpace, Op> + for RsrsOperator<'a, Item, ArrayVectorSpace, Op> +where + StandardNormal: Distribution<::Real>, + Standard: Distribution<::Real>, + ::Real: RandScalar, +{ + fn from_local_spaces( + op: &'a Op, + domain: Rc>, + range: Rc>, + ) -> Self { + RsrsOperator { + op, + domain: domain.clone(), + range: range.clone(), + inv: false, + } + } +} + +impl< + 'a, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr + + Equivalence, + Op: RsrsFactorsImpl + Shape<2>, + > LocalFromSpaces<'a, Item, DistributedArrayVectorSpace<'a, SimpleCommunicator, Item>, Op> + for RsrsOperator<'a, Item, DistributedArrayVectorSpace<'a, SimpleCommunicator, Item>, Op> +where + StandardNormal: Distribution<::Real>, + Standard: Distribution<::Real>, + ::Real: RandScalar, +{ + fn from_local_spaces( + op: &'a Op, + domain: Rc>, + range: Rc>, + ) -> Self { + RsrsOperator { + op, + domain: domain.clone(), + range: range.clone(), + inv: false, + } + } +} + +impl< + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr, + Op: RsrsFactorsImpl + Shape<2>, + > AsApply for RsrsOperator<'_, Item, ArrayVectorSpace, Op> +where + ::Real: RandScalar, + StandardNormal: Distribution<::Real>, + Standard: Distribution<::Real>, +{ + fn apply_extended< + ContainerIn: ElementContainer::E>, + ContainerOut: ElementContainerMut::E>, + >( + &self, + _alpha: ::F, + x: Element, + _beta: ::F, + mut y: Element, + trans_mode: TransMode, + ) { + self.apply_vec_mode( + x.imp().view().data(), + y.imp_mut().view_mut().data_mut(), + trans_mode, + ); + } + + fn apply::E>>( + &self, + x: Element, + trans_mode: rlst::TransMode, + ) -> rlst::operator::ElementType<::E> { + let mut y = zero_element(self.range()); + self.apply_extended( + <::F as num::One>::one(), + x, + <::F as num::Zero>::zero(), + y.r_mut(), + trans_mode, + ); + y + } +} + +impl< + C: Communicator, + Item: RlstScalar + + MatrixInverse + + MatrixId + + MatrixIdNoSkel + + MatrixPseudoInverse + + MatrixLu + + RandScalar + + MatrixQr + + Equivalence, + Op: RsrsFactorsImpl + Shape<2>, + > AsApply for RsrsOperator<'_, Item, DistributedArrayVectorSpace<'_, C, Item>, Op> +where + ::Real: RandScalar, + StandardNormal: Distribution<::Real>, + Standard: Distribution<::Real>, +{ + fn apply_extended< + ContainerIn: ElementContainer::E>, + ContainerOut: ElementContainerMut::E>, + >( + &self, + _alpha: ::F, + x: Element, + _beta: ::F, + mut y: Element, + trans_mode: TransMode, + ) { + self.apply_vec_mode( + x.imp().view().local().data(), + y.imp_mut().view_mut().local_mut().data_mut(), + trans_mode, + ); + } + + fn apply::E>>( + &self, + x: Element, + trans_mode: rlst::TransMode, + ) -> rlst::operator::ElementType<::E> { + let mut y = zero_element(self.range()); + self.apply_extended( + <::F as num::One>::one(), + x, + <::F as num::Zero>::zero(), + y.r_mut(), + trans_mode, + ); + y + } +} diff --git a/src/rsrs/sketch.rs b/src/rsrs/sketch.rs index 7017b88..787cc62 100644 --- a/src/rsrs/sketch.rs +++ b/src/rsrs/sketch.rs @@ -1,21 +1,29 @@ -use super::rsrs_factors::{ - CommutativeFactors, CommutativeFactorsOperations, FactorType, MulOptions, RsrsFactors, - RsrsFactorsImpl, -}; +use crate::rsrs::rsrs_factors::base_factors::BaseFactorOptions; +use crate::rsrs::rsrs_factors::commutative_factors::CommutativeFactors; +use crate::rsrs::rsrs_factors::commutative_factors::CommutativeFactorsOperations; +use crate::rsrs::rsrs_factors::commutative_factors::FactorType; +use crate::rsrs::rsrs_factors::commutative_factors::MulOptions; +use crate::rsrs::rsrs_factors::commutative_factors::RsrsFactors; +use crate::rsrs::rsrs_factors::rsrs_operator::FactType; +use crate::rsrs::rsrs_factors::rsrs_operator::RsrsFactorsImpl; +use crate::utils::io::IOData; +use crate::utils::linear_algebra::streaming_chunk_rows; use mpi::traits::Communicator; use mpi::traits::Equivalence; use rand::Rng; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; use rand_distr::{Distribution, Standard, StandardNormal}; -use rlst::dense::linalg::lu::MatrixLu; +use rayon::ThreadPool; +use rlst::dense::linalg::{interpolative_decomposition::MatrixIdNoSkel, lu::MatrixLu}; use rlst::operator::ConcreteElementContainer; +use rlst::{dense::array::reference::ArrayRef, dense::array::views::ArraySubView}; pub use rlst::{ dense::{array::empty_array, tools::RandScalar}, prelude::*, }; -use std::time::{SystemTime, UNIX_EPOCH}; -use std::{cell::RefCell, time::Instant}; +use serde::{Deserialize, Serialize}; +use std::{path::Path, time::Instant}; pub enum UpdateType<'a, Item: RlstScalar> { Lu(&'a CommutativeFactors), Id(&'a CommutativeFactors), @@ -32,7 +40,52 @@ pub struct SketchData { pub test: DynamicArray, pub dim: usize, pub num_samples: usize, - pub trans: bool, + pub trans: TransMode, +} + +type SketchArrayRef<'a, Item> = ArrayRef<'a, Item, BaseArray, 2>, 2>; +pub type SketchChunkView<'a, Item> = + Array, 2>, 2>; + +pub struct SampleChunk<'a, Item: RlstScalar> { + pub row_offset: usize, + pub test: SketchChunkView<'a, Item>, + pub sketch: SketchChunkView<'a, Item>, +} + +pub struct SampleChunkIter<'a, Item: RlstScalar> { + data: &'a SketchData, + subs_sample_dim: usize, + chunk_rows: usize, + next_row: usize, +} + +impl<'a, Item: RlstScalar> Iterator for SampleChunkIter<'a, Item> { + type Item = SampleChunk<'a, Item>; + + fn next(&mut self) -> Option { + if self.next_row >= self.subs_sample_dim { + return None; + } + + let row_offset = self.next_row; + let rows = (row_offset + self.chunk_rows).min(self.subs_sample_dim) - row_offset; + self.next_row += rows; + + Some(SampleChunk { + row_offset, + test: self + .data + .test + .r() + .into_subview([row_offset, 0], [rows, self.data.dim]), + sketch: self + .data + .sketch + .r() + .into_subview([row_offset, 0], [rows, self.data.dim]), + }) + } } pub struct FullBoxesData { @@ -40,7 +93,7 @@ pub struct FullBoxesData { pub z_data: SketchData, pub dim: usize, pub active_samples: usize, - pub hermitian: bool, + pub symmetric: bool, } pub enum SampleType { @@ -70,7 +123,18 @@ pub trait SamplingSpace: LinearSpace { x: &Element>, other: &mut Array, offset: usize, + trans: TransMode, ); + + fn clone_vec( + &self, + other: &Element>, + ) -> Element>; + + fn conj_vec( + &self, + other: &Element>, + ) -> Element>; } impl SamplingSpace for ArrayVectorSpace @@ -110,8 +174,43 @@ where x: &Element>, other: &mut Array, offset: usize, + trans: TransMode, ) { - other.r_mut().slice(0, offset).fill_from(x.view()); + match trans { + TransMode::NoTrans => other.r_mut().slice(0, offset).fill_from(x.view()), + TransMode::ConjNoTrans => todo!(), + TransMode::Trans => other.r_mut().slice(0, offset).fill_from(x.view().conj()), + TransMode::ConjTrans => todo!(), + }; + } + + fn clone_vec( + &self, + other: &Element>, + ) -> Element> { + let mut new = + Element::>::new(Self::E::new(other.space())); + new.view_mut().fill_from(other.view()); + + new + } + + fn conj_vec( + &self, + other: &Element>, + ) -> Element> { + let mut aux_array = rlst_dynamic_array2!(Item, [self.dimension(), 1]); + + aux_array.r_mut().slice(1, 0).fill_from(other.view()); + + let mut new = + Element::>::new(Self::E::new(other.space())); + + new.view_mut() + .iter_mut() + .enumerate() + .for_each(|(i, val)| *val = aux_array.r().data()[i].conj()); + new } } @@ -159,57 +258,146 @@ where x: &Element>, other: &mut Array, offset: usize, + trans: TransMode, ) { - other + match trans { + TransMode::NoTrans => other + .r_mut() + .slice(0, offset) + .fill_from(x.view().local().r()), + TransMode::ConjNoTrans => todo!(), + TransMode::Trans => other + .r_mut() + .slice(0, offset) + .fill_from(x.view().local().r().conj()), + TransMode::ConjTrans => todo!(), + }; + } + + fn clone_vec( + &self, + other: &Element>, + ) -> Element> { + let mut new = + Element::>::new(Self::E::new(other.space())); + new.view_mut() + .local_mut() + .fill_from(other.view().local().r()); + + new + } + + fn conj_vec( + &self, + other: &Element>, + ) -> Element> { + let mut aux_array = rlst_dynamic_array2!(Item, [self.dimension(), 1]); + + aux_array .r_mut() - .slice(0, offset) - .fill_from(x.view().local().r()); + .slice(1, 0) + .fill_from(other.view().local().r()); + + let mut new = + Element::>::new(Self::E::new(other.space())); + + new.view_mut() + .local_mut() + .iter_mut() + .enumerate() + .for_each(|(i, val)| *val = aux_array.r().data()[i].conj()); + new } } -thread_local! { - static THREAD_RNG: RefCell = RefCell::new(init_rng()); +pub(crate) fn mix_seed(mut seed: u64) -> u64 { + seed = seed.wrapping_add(0x9E3779B97F4A7C15); + seed = (seed ^ (seed >> 30)).wrapping_mul(0xBF58476D1CE4E5B9); + seed = (seed ^ (seed >> 27)).wrapping_mul(0x94D049BB133111EB); + seed ^ (seed >> 31) } -fn init_rng() -> ChaCha8Rng { - let time_seed = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_nanos(); - let seed = (time_seed as u64).wrapping_mul(0x9E3779B97F4A7C15); // or add thread ID if needed - ChaCha8Rng::seed_from_u64(seed) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "value")] +pub enum Shift { + True(f64), //TODO: Change to Item + False, } -pub fn with_thread_rng(f: F) -> R -where - F: FnOnce(&mut ChaCha8Rng) -> R, -{ - THREAD_RNG.with(|rng_cell| { - let mut rng = rng_cell.borrow_mut(); - f(&mut rng) - }) +pub(crate) fn shift_alpha(shift: &Shift) -> f64 { + match shift { + Shift::True(alpha) => *alpha, + Shift::False => 0.0, + } } -fn resize_rows< - Item: RlstScalar, - ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Stride<2> + RawAccessMut + Shape<2>, ->( - arr: &Array, - new_shape: [usize; 2], -) -> DynamicArray { - let mut new_arr = rlst_dynamic_array2!(Item, new_shape); - new_arr - .r_mut() - .into_subview([0, 0], arr.shape()) - .fill_from(arr.r()); - - new_arr +pub(crate) fn apply_shift_delta( + sketch: &mut DynamicArray, + test: &DynamicArray, + delta: f64, +) { + if delta.abs() <= f64::EPSILON { + return; + } + + let delta_item = Item::from(delta).unwrap(); + sketch + .data_mut() + .iter_mut() + .zip(test.data().iter()) + .for_each(|(sketch_val, test_val)| { + *sketch_val += delta_item * *test_val; + }); +} + +impl SketchData { + pub fn new(dim: usize, trans: TransMode) -> Self { + let test: Array, 2>, 2> = empty_array(); + let sketch: Array, 2>, 2> = empty_array(); + Self { + sketch, + test, + dim, + num_samples: 0, + trans, + } + } + + pub fn trans_val(&self) -> bool { + match self.trans { + TransMode::NoTrans => false, + TransMode::ConjNoTrans => false, + TransMode::Trans => true, + TransMode::ConjTrans => true, + } + } + + pub fn chunk_iter( + &self, + subs_sample_dim: usize, + cols_per_row: usize, + live_buffers: usize, + ) -> SampleChunkIter<'_, Item> { + let subs_sample_dim = subs_sample_dim + .min(self.test.shape()[0]) + .min(self.sketch.shape()[0]); + let chunk_rows = + streaming_chunk_rows::(subs_sample_dim, cols_per_row, live_buffers).max(1); + + SampleChunkIter { + data: self, + subs_sample_dim, + chunk_rows, + next_row: 0, + } + } } impl< Item: RlstScalar + RandScalar + MatrixId + + MatrixIdNoSkel + MatrixInverse + MatrixPseudoInverse + MatrixLu @@ -222,19 +410,8 @@ where MatrixLuDecomposition, TriangularMatrix: TriangularOperations, ::Real: RandScalar, + Item: IOData, { - pub fn new(dim: usize, trans: bool) -> Self { - let test: Array, 2>, 2> = empty_array(); - let sketch: Array, 2>, 2> = empty_array(); - Self { - sketch, - test, - dim, - num_samples: 0, - trans, - } - } - pub fn add_samples< Space: SamplingSpace, OpImpl: AsApply, @@ -242,19 +419,19 @@ where &mut self, extra_num_samples: usize, operator: Operator, - _seed: u64, + shift: &Shift, + save_samples: bool, + sample_storage_dir: Option<&Path>, + seed: u64, ) -> u128 { let sampling_start: Instant = Instant::now(); let test_shape = self.test.shape(); let total_samples = test_shape[0] + extra_num_samples; - let trans_mode = if self.trans { - TransMode::ConjTrans - } else { - TransMode::NoTrans - }; - self.test = resize_rows(&self.test, [total_samples, self.dim]); - self.sketch = resize_rows(&self.sketch, [total_samples, self.dim]); + // Preserve existing samples while letting the backing Vec grow amortized + // instead of rebuilding a fresh matrix on every resize. + self.test.resize_in_place([total_samples, self.dim]); + self.sketch.resize_in_place([total_samples, self.dim]); let mut sample_generation = std::time::Duration::ZERO; let mut multiplication = std::time::Duration::ZERO; @@ -263,29 +440,42 @@ where let start: Instant = Instant::now(); let offset = test_shape[0] + row; let mut chunk_test_vec = SamplingSpace::zero(operator.r().domain()); - - with_thread_rng(|rng| { - operator.domain().sampling( - &mut chunk_test_vec, - rng, - SampleType::RealStandardNormal, - ); - }); + let row_seed = mix_seed( + seed ^ (offset as u64).wrapping_mul(0x9E3779B97F4A7C15) + ^ (self.dim as u64).rotate_left(21) + ^ u64::from(self.trans_val()), + ); + let mut rng = ChaCha8Rng::seed_from_u64(row_seed); + operator.domain().sampling( + &mut chunk_test_vec, + &mut rng, + SampleType::RealStandardNormal, + ); sample_generation += start.elapsed(); let start: Instant = Instant::now(); - let chunk_sketch_vec: Element> = - operator.apply(chunk_test_vec.r(), trans_mode); + + let chunk_sketch_vec = match shift { + Shift::True(alpha) => { + let mut chunk_sketch_vec_stab = operator.domain().clone_vec(&chunk_test_vec); + chunk_sketch_vec_stab.scale_inplace(Item::from(*alpha).unwrap()); + chunk_sketch_vec_stab + .sum_inplace(operator.apply(chunk_test_vec.r(), self.trans)); + chunk_sketch_vec_stab + } + Shift::False => operator.apply(chunk_test_vec.r(), self.trans), + }; + multiplication += start.elapsed(); let start: Instant = Instant::now(); operator .domain() - .fill_array(&chunk_test_vec, &mut self.test, offset); + .fill_array(&chunk_test_vec, &mut self.test, offset, self.trans); operator .domain() - .fill_array(&chunk_sketch_vec, &mut self.sketch, offset); + .fill_array(&chunk_sketch_vec, &mut self.sketch, offset, self.trans); filling += start.elapsed(); if (row + 1) % 30 == 0 { @@ -302,6 +492,62 @@ where filling = std::time::Duration::ZERO; } }); + + if save_samples { + let save_start = Instant::now(); + let (test_base, sketch_base) = if self.trans_val() { + ("z_test_file", "z_sketch_file") + } else { + ("y_test_file", "y_sketch_file") + }; + // Persist canonical unshifted sketches on disk so saved samples can be + // reused across runs with different operator shifts. + let current_shift = shift_alpha(shift); + let test_view = self + .test + .r() + .into_subview([test_shape[0], 0], [extra_num_samples, self.dim]); + let _ = + >::append_in_dir(&test_view, test_base, sample_storage_dir); + + if current_shift.abs() > f64::EPSILON { + let mut test_sv = empty_array(); + test_sv.r_mut().fill_from_resize( + self.test + .r() + .into_subview([test_shape[0], 0], [extra_num_samples, self.dim]), + ); + let mut sketch_sv: Array, 2>, 2> = + empty_array(); + sketch_sv.r_mut().fill_from_resize( + self.sketch + .r() + .into_subview([test_shape[0], 0], [extra_num_samples, self.dim]), + ); + apply_shift_delta(&mut sketch_sv, &test_sv, -current_shift); + let _ = >::append_in_dir( + &sketch_sv, + sketch_base, + sample_storage_dir, + ); + } else { + let sketch_view = self + .sketch + .r() + .into_subview([test_shape[0], 0], [extra_num_samples, self.dim]); + let _ = >::append_in_dir( + &sketch_view, + sketch_base, + sample_storage_dir, + ); + } + + println!( + "{} samples saved in {:.3}s", + extra_num_samples, + save_start.elapsed().as_secs_f64() + ) + } let duration = sampling_start.elapsed(); self.num_samples = test_shape[0] + extra_num_samples; //TODO: Change this to total_samples @@ -309,13 +555,23 @@ where duration.as_millis() } + #[allow(clippy::too_many_arguments)] pub fn update_samples( &mut self, update_start: usize, samples_to_update: usize, level: usize, update_type: &UpdateType, + fact_type: &FactType, + thread_pool: &ThreadPool, + num_threads: usize, ) -> (u128, u128) { + let (factor_1, factor_2) = if self.trans_val() { + (FactorType::S, FactorType::F) + } else { + (FactorType::F, FactorType::S) + }; + let (mut sub_test, mut sub_sketch) = ( self.test .r_mut() @@ -328,12 +584,6 @@ where let mut id_time = 0_u128; let mut lu_time = 0_u128; - let (factor_1, factor_2) = if !self.trans { - (FactorType::F, FactorType::S) - } else { - (FactorType::S, FactorType::F) - }; - match update_type { UpdateType::Lu(lu_batch) => { lu_time += update_lu_level( @@ -344,6 +594,8 @@ where &factor_1, &factor_2, self.trans, + thread_pool, + num_threads, ); } UpdateType::Id(id_batch) => { @@ -355,10 +607,25 @@ where &factor_1, &factor_2, self.trans, + thread_pool, + num_threads, ); } - UpdateType::Both(rsrs_factors) => { - (0..level).for_each(|level_it| { + UpdateType::Both(rsrs_factors) => match fact_type { + FactType::Joint => (0..level).for_each(|level_it| { + let (loc_id_time, loc_lu_time) = update_level( + &mut sub_sketch, + &mut sub_test, + level_it, + BatchUpdateType::Multi(rsrs_factors), + &factor_1, + &factor_2, + self.trans, + ); + id_time += loc_id_time; + lu_time += loc_lu_time; + }), + FactType::Split => (0..level).for_each(|level_it| { id_time += update_id_level( &mut sub_sketch, &mut sub_test, @@ -367,6 +634,8 @@ where &factor_1, &factor_2, self.trans, + thread_pool, + num_threads, ); lu_time += update_lu_level( @@ -377,17 +646,27 @@ where &factor_1, &factor_2, self.trans, + thread_pool, + num_threads, ); - }); - } + }), + }, } (id_time, lu_time) } } +#[allow(clippy::too_many_arguments)] pub fn update_id_level< - Item: RlstScalar + RandScalar + MatrixId + MatrixInverse + MatrixPseudoInverse + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + Stride<2> + RawAccessMut @@ -403,7 +682,9 @@ pub fn update_id_level< update_type: BatchUpdateType, factor_1: &FactorType, factor_2: &FactorType, - trans: bool, + trans: TransMode, + thread_pool: &ThreadPool, + num_threads: usize, ) -> u128 where LuDecomposition, 2>>: @@ -411,25 +692,32 @@ where TriangularMatrix: TriangularOperations, { let start = Instant::now(); - let sketch_factor_options = MulOptions { + let sketch_base_options = BaseFactorOptions { inv: true, trans, + trans_target: true, + }; + let test_base_options = BaseFactorOptions { + inv: false, + trans, + trans_target: true, + }; + + let sketch_factor_options = MulOptions { + base_options: sketch_base_options, side: Side::Left, factor_type: factor_1.clone(), - t_trans: true, }; let test_factor_options = MulOptions { - inv: false, - trans, + base_options: test_base_options, side: Side::Left, factor_type: factor_2.clone(), - t_trans: true, }; match update_type { BatchUpdateType::Single(id_batch) => { - id_batch.mul(sketch, &sketch_factor_options); - id_batch.mul(test, &test_factor_options); + id_batch.mul(sketch, thread_pool, num_threads, &sketch_factor_options); + id_batch.mul(test, thread_pool, num_threads, &test_factor_options); } BatchUpdateType::Multi(rsrs_factors) => { rsrs_factors.apply_id_level(sketch, &sketch_factor_options, level_it); @@ -440,8 +728,16 @@ where start.elapsed().as_millis() } +#[allow(clippy::too_many_arguments)] pub fn update_lu_level< - Item: RlstScalar + RandScalar + MatrixId + MatrixInverse + MatrixPseudoInverse + MatrixLu + MatrixQr, + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + Stride<2> + RawAccessMut @@ -457,7 +753,9 @@ pub fn update_lu_level< update_type: BatchUpdateType, factor_1: &FactorType, factor_2: &FactorType, - trans: bool, + trans: TransMode, + thread_pool: &ThreadPool, + num_threads: usize, ) -> u128 where LuDecomposition, 2>>: @@ -465,25 +763,32 @@ where TriangularMatrix: TriangularOperations, { let start = Instant::now(); - let sketch_factor_options = MulOptions { + let sketch_base_options = BaseFactorOptions { inv: true, trans, + trans_target: true, + }; + let test_base_options = BaseFactorOptions { + inv: false, + trans, + trans_target: true, + }; + + let sketch_factor_options = MulOptions { side: Side::Left, factor_type: factor_1.clone(), - t_trans: true, + base_options: sketch_base_options, }; let test_factor_options = MulOptions { - inv: false, - trans, side: Side::Left, factor_type: factor_2.clone(), - t_trans: true, + base_options: test_base_options, }; match update_type { BatchUpdateType::Single(lu_batch) => { - lu_batch.mul(sketch, &sketch_factor_options); - lu_batch.mul(test, &test_factor_options); + lu_batch.mul(sketch, thread_pool, num_threads, &sketch_factor_options); + lu_batch.mul(test, thread_pool, num_threads, &test_factor_options); } BatchUpdateType::Multi(rsrs_factors) => { rsrs_factors.apply_lu_level(sketch, &sketch_factor_options, false, level_it); @@ -493,3 +798,76 @@ where start.elapsed().as_millis() } + +pub fn update_level< + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2> + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item> + + std::marker::Send + + std::marker::Sync, +>( + sketch: &mut Array, + test: &mut Array, + level_it: usize, + update_type: BatchUpdateType, + factor_1: &FactorType, + factor_2: &FactorType, + trans: TransMode, +) -> (u128, u128) +where + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + let sketch_base_options = BaseFactorOptions { + inv: true, + trans, + trans_target: true, + }; + let test_base_options = BaseFactorOptions { + inv: false, + trans, + trans_target: true, + }; + let sketch_factor_options = MulOptions { + side: Side::Left, + factor_type: factor_1.clone(), + base_options: sketch_base_options, + }; + let test_factor_options = MulOptions { + side: Side::Left, + factor_type: factor_2.clone(), + base_options: test_base_options, + }; + + let mut id_update_time = 0; + let mut lu_update_time = 0; + + match update_type { + BatchUpdateType::Single(_lu_batch) => panic!("Only implemented for multi batch updates"), + BatchUpdateType::Multi(rsrs_factors) => { + let (id_time, lu_time) = + rsrs_factors.apply_level(sketch, &sketch_factor_options, false, level_it); + id_update_time += id_time; + lu_update_time += lu_time; + + let (id_time, lu_time) = + rsrs_factors.apply_level(test, &test_factor_options, false, level_it); + id_update_time += id_time; + lu_update_time += lu_time; + } + } + + (id_update_time, lu_update_time) +} diff --git a/src/rsrs/statistics.rs b/src/rsrs/statistics.rs new file mode 100644 index 0000000..ec92e79 --- /dev/null +++ b/src/rsrs/statistics.rs @@ -0,0 +1,134 @@ +use serde::Serialize; + +#[derive(Debug, Serialize, Clone)] +pub struct LuTimes { + pub extraction: u128, + pub lu: u128, +} + +#[derive(Debug, Serialize, Clone)] +pub struct LevelEffort { + pub time: u128, + pub num_boxes: usize, + pub num_batches: usize, + pub effective_dofs: usize, + pub sketch_len: usize, + pub residual_len: usize, +} + +#[derive(Debug, Serialize, Clone)] +pub struct IdTimes { + pub nullification: u128, + pub id: u128, +} + +pub enum Times { + Lu(LuTimes), + Id(IdTimes), +} + +#[derive(Debug, Serialize, Clone)] +pub struct UpdateTimes { + pub id: u128, + pub lu: u128, +} + +macro_rules! impl_times_operations { + ($struct_name:ident, $trait_name:ident, $arg_1:ident, $arg_2:ident) => { + pub trait $trait_name { + fn new() -> Self; + fn sum(&mut self, $arg_1: u128, $arg_2: u128); + } + + impl $trait_name for $struct_name { + fn new() -> Self { + Self { + $arg_1: 0_u128, + $arg_2: 0_u128, + } + } + + fn sum(&mut self, $arg_1: u128, $arg_2: u128) { + self.$arg_1 += $arg_1; + self.$arg_2 += $arg_2; + } + } + }; +} + +impl_times_operations!(IdTimes, IdTimesOperations, nullification, id); +impl_times_operations!(LuTimes, LuTimesOperations, extraction, lu); +impl_times_operations!(UpdateTimes, UpdateTimesOperations, id, lu); + +#[derive(Debug)] +pub struct LimitingLevel { + pub level: usize, + pub num_boxes: usize, + pub active_points: usize, + pub elapsed_time: u128, +} + +#[derive(Debug)] +pub struct LimitingFactors { + pub min_samples: usize, + pub max_level: usize, + pub limiting_level: LimitingLevel, + pub leaf_count: usize, +} + +#[derive(Debug, Serialize, Clone, Default)] +pub struct FactorMemoryStats { + pub total_bytes: u64, + pub id_bytes: u64, + pub lu_bytes: u64, + pub diag_bytes: u64, + pub perm_bytes: u64, + pub id_count: usize, + pub lu_count: usize, + pub diag_count: usize, +} + +#[derive(Debug, Serialize, Clone, Default)] +pub struct MemorySnapshot { + pub label: String, + pub rss_bytes: Option, + pub peak_rss_bytes: Option, + pub baseline_rss_bytes: Option, + pub sample_buffer_bytes: u64, + pub factor_memory: FactorMemoryStats, + pub accounted_factorization_bytes: u64, + pub estimated_temporary_runtime_bytes: Option, +} + +#[derive(Debug)] +pub struct Stats { + pub sampling_time: Vec, + pub sample_loading_time: u128, + pub sampling_extraction_time: u128, + pub id_times: Vec, + pub tot_id_time: u128, + pub lu_times: Vec, + pub tot_lu_time: u128, + pub update_times: Vec, + pub total_elapsed_time: u128, + pub total_elapsed_time_wo_sampling: u128, + pub dim: usize, + pub extraction_time: u128, + pub residual_size: usize, + pub ranks: Vec, + pub box_sizes: Vec, + pub near_field_sizes: Vec, + pub dec_boxes_per_level: Vec, + pub index_calculation: u128, + pub sorting_near_field: u128, + pub residual_calculation: u128, + pub limiting_factors: LimitingFactors, + pub level_effort: Vec, + pub mv_avg_time: Vec, + pub memory_snapshots: Vec, + pub run_start_rss_bytes: Option, + pub max_sample_buffer_bytes: u64, + pub max_factor_bytes: u64, + pub max_accounted_factorization_bytes: u64, + pub max_estimated_temporary_runtime_bytes: Option, +} diff --git a/src/rsrs/tree_indexing.rs b/src/rsrs/tree_indexing.rs index c1be291..9bae03d 100644 --- a/src/rsrs/tree_indexing.rs +++ b/src/rsrs/tree_indexing.rs @@ -1,9 +1,10 @@ use bempp_octree::{morton::MortonKey, octree::Octree}; use mpi::traits::CommunicatorCollectives; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap}; +#[derive(Clone)] pub struct TreeData { - pub level_keys: HashSet, + pub level_keys: BTreeSet, pub boxes_map: HashMap>, pub max_level: usize, pub current_level: usize, @@ -16,16 +17,9 @@ pub trait TreeIndexing: Sized { fn update_level_keys(&mut self); - fn next_level_keys(&mut self) -> HashSet; + fn next_level_keys(&mut self) -> BTreeSet; - fn get_box_near_field_keys(&self, box_key: &MortonKey, level: usize) -> HashSet; - - //Function to get indices of points in a box - fn get_box_indices(&self, box_key: &MortonKey) -> Option<&Vec>; - - fn get_box_far_field_keys(&self, box_key: &MortonKey) -> HashSet; - - fn get_neighbouring_indices(&self, box_key: &MortonKey) -> Option>; + fn get_box_near_field_keys(&self, box_key: &MortonKey, level: usize) -> BTreeSet; } impl TreeIndexing for TreeData { @@ -35,7 +29,7 @@ impl TreeIndexing for TreeData { octree_data.leaf_keys_to_local_point_indices().clone(); let leaf_tree_keys = octree_data.leaf_keys().iter().cloned(); let max_level = octree_data.global_max_level(); - let level_keys = leaf_tree_keys.collect::>(); + let level_keys = leaf_tree_keys.collect::>(); let current_level = max_level; Self { level_keys, @@ -51,7 +45,7 @@ impl TreeIndexing for TreeData { self.current_level -= 1; } - fn next_level_keys(&mut self) -> HashSet { + fn next_level_keys(&mut self) -> BTreeSet { let next_level_keys = self .level_keys .iter() @@ -71,12 +65,12 @@ impl TreeIndexing for TreeData { .filter(|key| { self.current_level == self.max_level || key.level() == self.current_level - 1 }) - .collect::>(); + .collect::>(); next_level_keys } - fn get_box_near_field_keys(&self, box_key: &MortonKey, level: usize) -> HashSet { + fn get_box_near_field_keys(&self, box_key: &MortonKey, level: usize) -> BTreeSet { if level == self.max_level { self.neighbour_map .get(box_key) @@ -94,38 +88,4 @@ impl TreeIndexing for TreeData { .collect() } } - - fn get_box_indices(&self, box_key: &MortonKey) -> Option<&Vec> { - self.boxes_map.get(box_key) - } - - fn get_box_far_field_keys(&self, box_key: &MortonKey) -> HashSet { - let level_keys: &HashSet = &self.level_keys; - let near_keys: HashSet = - self.get_box_near_field_keys(box_key, self.current_level); - let far_keys: HashSet = level_keys - .difference(&near_keys) - .cloned() - .collect::>(); - far_keys - } - - fn get_neighbouring_indices(&self, box_key: &MortonKey) -> Option> { - match self.boxes_map.get(box_key) { - Some(indices) => { - let mut neighbour_indices: Vec = Vec::new(); - neighbour_indices.extend_from_slice(indices); - let neighbour_keys: std::collections::hash_set::IntoIter = self - .get_box_near_field_keys(box_key, self.current_level) - .into_iter(); - for neighbour_key in neighbour_keys { - if let Some(indices) = self.boxes_map.get(&neighbour_key) { - neighbour_indices.extend_from_slice(indices); - } - } - Some(neighbour_indices) - } - None => None, - } - } } diff --git a/src/utils.rs b/src/utils.rs index e0ce315..d5b75ee 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -2,7 +2,9 @@ pub mod data_ins_ext; pub mod elementary_matrix; -pub mod least_squares_and_null; +pub mod io; +pub mod linear_algebra; +pub mod memory; pub mod operator_templates; pub mod print; diff --git a/src/utils/data_ins_ext.rs b/src/utils/data_ins_ext.rs index 033b734..8f56c98 100644 --- a/src/utils/data_ins_ext.rs +++ b/src/utils/data_ins_ext.rs @@ -1,4 +1,5 @@ pub use rlst::prelude::*; +use std::marker::PhantomData; pub struct Extraction { pub ext: DynamicArray, @@ -24,6 +25,45 @@ pub enum ExtInsType { Cross(Vec, Vec), } +#[derive(Clone, Copy)] +pub struct RawMatrixMut { + ptr: *mut T, + offset: usize, + stride: [usize; 2], + shape: [usize; 2], + _marker: PhantomData, +} + +unsafe impl Send for RawMatrixMut {} +unsafe impl Sync for RawMatrixMut {} + +impl RawMatrixMut { + #[inline] + unsafe fn elem_ptr(&self, row: usize, col: usize) -> *mut T { + debug_assert!(row < self.shape[0]); + debug_assert!(col < self.shape[1]); + unsafe { + self.ptr + .add(self.offset + row * self.stride[0] + col * self.stride[1]) + } + } +} + +pub fn raw_matrix_mut< + T: RlstScalar, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = T> + Shape<2> + Stride<2> + RawAccessMut, +>( + target_arr: &mut Array, +) -> RawMatrixMut { + RawMatrixMut { + ptr: target_arr.buff_ptr_mut(), + offset: target_arr.offset(), + stride: target_arr.stride(), + shape: target_arr.shape(), + _marker: PhantomData, + } +} + impl MatrixExtraction for Extraction { type Item = T; fn new< @@ -64,6 +104,57 @@ impl MatrixExtraction for Extraction { } } +fn axis_shape(num_rows: usize, num_cols: usize, exchange_axis: bool) -> [usize; 2] { + if exchange_axis { + [num_cols, num_rows] + } else { + [num_rows, num_cols] + } +} + +pub fn extract_axis_into< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccess + + UnsafeRandomAccessByRef<2, Item = T>, +>( + target_arr: &mut DynamicArray, + source_arr: &Array, + inds: &[usize], + axis: usize, + exchange_axis: bool, +) { + if axis == 0 { + fill_rows(target_arr, inds, source_arr, exchange_axis); + } else { + fill_cols(target_arr, inds, source_arr, exchange_axis); + } +} + +pub fn extract_cross_into< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccess + + UnsafeRandomAccessByRef<2, Item = T>, +>( + target_arr: &mut DynamicArray, + source_arr: &Array, + rows: &[usize], + cols: &[usize], +) { + target_arr.resize_in_place([rows.len(), cols.len()]); + let mut target_view = target_arr.r_mut(); + let source_view = source_arr.r(); + + for (col_ind, col) in cols.iter().enumerate() { + for (row_ind, row) in rows.iter().enumerate() { + target_view[[row_ind, col_ind]] = source_view[[*row, *col]]; + } + } +} + fn get_rows< T: RlstScalar, ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> @@ -75,34 +166,45 @@ fn get_rows< source_arr: &Array, exchange_axis: bool, ) -> DynamicArray { + let mut target_arr = empty_array(); + fill_rows(&mut target_arr, &inds, source_arr, exchange_axis); + target_arr +} + +fn fill_rows< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccess + + UnsafeRandomAccessByRef<2, Item = T>, +>( + target_arr: &mut DynamicArray, + inds: &[usize], + source_arr: &Array, + exchange_axis: bool, +) { let num_cols = source_arr.shape()[1]; let num_rows = inds.len(); + target_arr.resize_in_place(axis_shape(num_rows, num_cols, exchange_axis)); + let view_2 = source_arr.r(); - let mut target_arr = if exchange_axis { - rlst_dynamic_array2!(T, [num_cols, num_rows]) - } else { - rlst_dynamic_array2!(T, [num_rows, num_cols]) - }; let mut view_1 = target_arr.r_mut(); + if exchange_axis { for col_ind in 0..num_cols { let col_slice = view_2.r().slice(1, col_ind); for (row_ind, &row) in inds.iter().enumerate() { - let val = col_slice[[row]]; - view_1[[col_ind, row_ind]] = val; + view_1[[col_ind, row_ind]] = col_slice[[row]]; } } } else { for col_ind in 0..num_cols { let col_slice = view_2.r().slice(1, col_ind); for (row_ind, &row) in inds.iter().enumerate() { - let val = col_slice[[row]]; - view_1[[row_ind, col_ind]] = val; + view_1[[row_ind, col_ind]] = col_slice[[row]]; } } } - - target_arr } fn insert_rows< @@ -146,6 +248,54 @@ fn insert_rows< } } +fn accumulate_rows< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccess + + UnsafeRandomAccessMut<2, Item = T> + + UnsafeRandomAccessByRef<2, Item = T>, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = T> + + UnsafeRandomAccessByRef<2, Item = T>, +>( + inds: &[usize], + source_arr: &Array, + target_arr: &mut Array, + exchange_axis: bool, + subtract: bool, +) { + let num_cols = source_arr.shape()[1]; + let mut view_1 = target_arr.r_mut(); + let view_2 = source_arr.r(); + + if exchange_axis { + for col_ind in 0..num_cols { + let col_slice = view_2.r().slice(1, col_ind); + for (row_ind, &row) in inds.iter().enumerate() { + if subtract { + view_1[[col_ind, row]] -= col_slice[[row_ind]]; + } else { + view_1[[col_ind, row]] += col_slice[[row_ind]]; + } + } + } + } else { + for col_ind in 0..num_cols { + let col_slice = view_2.r().slice(1, col_ind); + for (row_ind, &row) in inds.iter().enumerate() { + if subtract { + view_1[[row, col_ind]] -= col_slice[[row_ind]]; + } else { + view_1[[row, col_ind]] += col_slice[[row_ind]]; + } + } + } + } +} + fn get_cols< T: RlstScalar, ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> @@ -157,37 +307,45 @@ fn get_cols< source_arr: &Array, exchange_axis: bool, ) -> DynamicArray { + let mut target_arr = empty_array(); + fill_cols(&mut target_arr, &inds, source_arr, exchange_axis); + target_arr +} + +fn fill_cols< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccess + + UnsafeRandomAccessByRef<2, Item = T>, +>( + target_arr: &mut DynamicArray, + inds: &[usize], + source_arr: &Array, + exchange_axis: bool, +) { let num_rows = source_arr.shape()[0]; let num_cols = inds.len(); - let view_2 = source_arr.r(); - - let mut target_arr = if exchange_axis { - rlst_dynamic_array2!(T, [num_cols, num_rows]) - } else { - rlst_dynamic_array2!(T, [num_rows, num_cols]) - }; + target_arr.resize_in_place(axis_shape(num_rows, num_cols, exchange_axis)); + let view_2 = source_arr.r(); let mut view_1 = target_arr.r_mut(); if exchange_axis { for (col_ind, &col) in inds.iter().enumerate() { let col_slice = view_2.r().slice(1, col); for row in 0..num_rows { - let val = col_slice[[row]]; - view_1[[col_ind, row]] = val; + view_1[[col_ind, row]] = col_slice[[row]]; } } } else { for (col_ind, &col) in inds.iter().enumerate() { let col_slice = view_2.r().slice(1, col); for row in 0..num_rows { - let val = col_slice[[row]]; - view_1[[row, col_ind]] = val; + view_1[[row, col_ind]] = col_slice[[row]]; } } } - - target_arr } fn insert_cols< @@ -232,6 +390,202 @@ fn insert_cols< } } +fn accumulate_cols< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccess + + UnsafeRandomAccessMut<2, Item = T> + + UnsafeRandomAccessByRef<2, Item = T>, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = T> + + UnsafeRandomAccessByRef<2, Item = T>, +>( + inds: &[usize], + source_arr: &Array, + target_arr: &mut Array, + exchange_axis: bool, + subtract: bool, +) { + let num_rows = source_arr.shape()[0]; + let mut view_1 = target_arr.r_mut(); + let view_2 = source_arr.r(); + + if exchange_axis { + for (col_ind, &col) in inds.iter().enumerate() { + let col_slice = view_2.r().slice(1, col_ind); + for row in 0..num_rows { + if subtract { + view_1[[col, row]] -= col_slice[[row]]; + } else { + view_1[[col, row]] += col_slice[[row]]; + } + } + } + } else { + for (col_ind, &col) in inds.iter().enumerate() { + let col_slice = view_2.r().slice(1, col_ind); + for row in 0..num_rows { + if subtract { + view_1[[row, col]] -= col_slice[[row]]; + } else { + view_1[[row, col]] += col_slice[[row]]; + } + } + } + } +} + +pub fn matrix_accumulation< + T: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + Stride<2> + + UnsafeRandomAccessMut<2, Item = T> + + UnsafeRandomAccessByRef<2, Item = T> + + RawAccessMut + + Shape<2>, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = T> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = T> + + UnsafeRandomAccessByRef<2, Item = T>, +>( + target_arr: &mut Array, + source_arr: &Array, + indices: ExtInsType, + subtract: bool, +) { + match indices { + ExtInsType::Axis(inds, axis, exchange_axis) => { + if axis == 0 { + accumulate_rows(&inds, source_arr, target_arr, exchange_axis, subtract); + } else { + accumulate_cols(&inds, source_arr, target_arr, exchange_axis, subtract); + } + } + ExtInsType::Cross(rows, cols) => { + let mut view_1 = target_arr.r_mut(); + let view_2 = source_arr.r(); + for (col_ind, col) in cols.iter().enumerate() { + for (row_ind, row) in rows.iter().enumerate() { + if subtract { + view_1[[row_ind, col_ind]] -= view_2[[*row, *col]]; + } else { + view_1[[row_ind, col_ind]] += view_2[[*row, *col]]; + } + } + } + } + } +} + +/// Accumulate a dense source block into a raw target matrix view. +/// +/// # Safety +/// +/// `target_arr` must point to a valid writable matrix region large enough for +/// every location addressed through `indices`, and those writes must not alias +/// any mutable references held elsewhere. +pub unsafe fn matrix_accumulation_raw( + target_arr: RawMatrixMut, + source_arr: &DynamicArray, + indices: ExtInsType, + subtract: bool, +) { + match indices { + ExtInsType::Axis(inds, axis, exchange_axis) => { + if axis == 0 { + let num_cols = source_arr.shape()[1]; + let source_view = source_arr.r(); + if exchange_axis { + for col_ind in 0..num_cols { + let col_slice = source_view.r().slice(1, col_ind); + for (row_ind, &row) in inds.iter().enumerate() { + let ptr = unsafe { target_arr.elem_ptr(col_ind, row) }; + let val = col_slice[[row_ind]]; + unsafe { + if subtract { + *ptr -= val; + } else { + *ptr += val; + } + } + } + } + } else { + for col_ind in 0..num_cols { + let col_slice = source_view.r().slice(1, col_ind); + for (row_ind, &row) in inds.iter().enumerate() { + let ptr = unsafe { target_arr.elem_ptr(row, col_ind) }; + let val = col_slice[[row_ind]]; + unsafe { + if subtract { + *ptr -= val; + } else { + *ptr += val; + } + } + } + } + } + } else { + let num_rows = source_arr.shape()[0]; + let source_view = source_arr.r(); + if exchange_axis { + for (col_ind, &col) in inds.iter().enumerate() { + let col_slice = source_view.r().slice(1, col_ind); + for row in 0..num_rows { + let ptr = unsafe { target_arr.elem_ptr(col, row) }; + let val = col_slice[[row]]; + unsafe { + if subtract { + *ptr -= val; + } else { + *ptr += val; + } + } + } + } + } else { + for (col_ind, &col) in inds.iter().enumerate() { + let col_slice = source_view.r().slice(1, col_ind); + for row in 0..num_rows { + let ptr = unsafe { target_arr.elem_ptr(row, col) }; + let val = col_slice[[row]]; + unsafe { + if subtract { + *ptr -= val; + } else { + *ptr += val; + } + } + } + } + } + } + } + ExtInsType::Cross(rows, cols) => { + let source_view = source_arr.r(); + for (col_ind, col) in cols.iter().enumerate() { + for (row_ind, row) in rows.iter().enumerate() { + let ptr = unsafe { target_arr.elem_ptr(*row, *col) }; + let val = source_view[[row_ind, col_ind]]; + unsafe { + if subtract { + *ptr -= val; + } else { + *ptr += val; + } + } + } + } + } + } +} + pub fn matrix_insertion< T: RlstScalar, ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> @@ -284,9 +638,9 @@ pub fn extract_axis< axis: usize, trans: bool, ) -> DynamicArray { - as MatrixExtraction>::new(mat, ExtInsType::Axis(inds.to_vec(), axis, trans)) - .unwrap() - .ext + let mut extracted = empty_array(); + extract_axis_into(&mut extracted, mat, inds, axis, trans); + extracted } /* pub enum SubArrType { diff --git a/src/utils/elementary_matrix.rs b/src/utils/elementary_matrix.rs index fa5fd13..124ed83 100644 --- a/src/utils/elementary_matrix.rs +++ b/src/utils/elementary_matrix.rs @@ -1,5 +1,10 @@ //! Elementary matrices (row swapping, row multiplication and row addition) -use super::data_ins_ext::{matrix_insertion, ExtInsType, Extraction, MatrixExtraction}; +use crate::rsrs::rsrs_factors::base_factors::BaseFactorOptions; + +use super::data_ins_ext::{ + matrix_accumulation, matrix_accumulation_raw, matrix_insertion, ExtInsType, Extraction, + MatrixExtraction, RawMatrixMut, +}; use num::One; use rlst::{ dense::{ @@ -255,7 +260,7 @@ pub fn row_ops< } else { subarr_rows.sum_into(res_mul.r()); } - + println!("1"); matrix_insertion( right_arr, &subarr_rows, @@ -330,7 +335,7 @@ pub fn col_ops< } else { subarr_cols.sum_into(res_mul.r()); } - + println!("2"); matrix_insertion( right_arr, &subarr_cols, @@ -348,16 +353,15 @@ pub fn ext_rows< c_indices: Vec, r_indices: Vec, right_arr: &Array, - trans: bool, - trans_right_arr: bool, + base_options: &BaseFactorOptions, ) -> DynamicArray { - let row_indices = if trans { + let row_indices = if base_options.trans_val() { c_indices.clone() } else { r_indices.clone() }; - let (axis, transposed) = if trans_right_arr { + let (axis, transposed) = if base_options.trans_target { (1, true) } else { (0, false) @@ -383,12 +387,15 @@ pub fn ext_cols< c_indices: Vec, r_indices: Vec, right_arr: &Array, - trans: bool, - trans_right_arr: bool, + base_options: &BaseFactorOptions, ) -> DynamicArray { - let col_indices = if trans { r_indices } else { c_indices }; + let col_indices = if base_options.trans_val() { + r_indices + } else { + c_indices + }; - let (axis, transposed) = if trans_right_arr { + let (axis, transposed) = if base_options.trans_target { (0, true) } else { (1, false) @@ -493,12 +500,14 @@ pub fn row_subs< r_indices: Vec, source_arr: &DynamicArray, target_arr: &mut Array, - trans: bool, - trans_subs: bool, + base_options: &BaseFactorOptions, ) { - let row_indices = if trans { c_indices } else { r_indices }; - - if trans_subs { + let row_indices = if base_options.trans_val() { + c_indices + } else { + r_indices + }; + if base_options.trans_target { matrix_insertion( target_arr, source_arr, @@ -513,6 +522,83 @@ pub fn row_subs< } } +pub fn row_delta< + Item: RlstScalar, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, +>( + c_indices: &[usize], + r_indices: &[usize], + source_arr: &DynamicArray, + target_arr: &mut Array, + base_options: &BaseFactorOptions, + subtract: bool, +) { + let row_indices = if base_options.trans_val() { + c_indices + } else { + r_indices + }; + if base_options.trans_target { + matrix_accumulation( + target_arr, + source_arr, + ExtInsType::Axis(row_indices.to_vec(), 0, true), + subtract, + ); + } else { + matrix_accumulation( + target_arr, + source_arr, + ExtInsType::Axis(row_indices.to_vec(), 0, false), + subtract, + ); + } +} + +/// Apply a row-wise delta update into a raw target matrix. +/// +/// # Safety +/// +/// `target_arr` must be valid for every indexed write induced by +/// `c_indices`/`r_indices` and must not alias any other mutable access. +pub unsafe fn row_delta_raw( + c_indices: &[usize], + r_indices: &[usize], + source_arr: &DynamicArray, + target_arr: RawMatrixMut, + base_options: &BaseFactorOptions, + subtract: bool, +) { + let row_indices = if base_options.trans_val() { + c_indices + } else { + r_indices + }; + if base_options.trans_target { + unsafe { + matrix_accumulation_raw( + target_arr, + source_arr, + ExtInsType::Axis(row_indices.to_vec(), 0, true), + subtract, + ) + }; + } else { + unsafe { + matrix_accumulation_raw( + target_arr, + source_arr, + ExtInsType::Axis(row_indices.to_vec(), 0, false), + subtract, + ) + }; + } +} + ///This method implements the row addition/substraction pub fn col_ops_no_sub< Item: RlstScalar, @@ -598,12 +684,15 @@ pub fn col_subs< r_indices: Vec, source_arr: &DynamicArray, target_arr: &mut Array, - trans: bool, - trans_subs: bool, + base_options: &BaseFactorOptions, ) { - let col_indices = if trans { r_indices } else { c_indices }; + let col_indices = if base_options.trans_val() { + r_indices + } else { + c_indices + }; - if trans_subs { + if base_options.trans_target { matrix_insertion( target_arr, source_arr, @@ -618,6 +707,85 @@ pub fn col_subs< } } +pub fn col_delta< + Item: RlstScalar, + ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Item> + + Shape<2> + + RawAccessMut + + UnsafeRandomAccessMut<2, Item = Item> + + UnsafeRandomAccessByRef<2, Item = Item>, +>( + c_indices: &[usize], + r_indices: &[usize], + source_arr: &DynamicArray, + target_arr: &mut Array, + base_options: &BaseFactorOptions, + subtract: bool, +) { + let col_indices = if base_options.trans_val() { + r_indices + } else { + c_indices + }; + + if base_options.trans_target { + matrix_accumulation( + target_arr, + source_arr, + ExtInsType::Axis(col_indices.to_vec(), 1, true), + subtract, + ); + } else { + matrix_accumulation( + target_arr, + source_arr, + ExtInsType::Axis(col_indices.to_vec(), 1, false), + subtract, + ); + } +} + +/// Apply a column-wise delta update into a raw target matrix. +/// +/// # Safety +/// +/// `target_arr` must be valid for every indexed write induced by +/// `c_indices`/`r_indices` and must not alias any other mutable access. +pub unsafe fn col_delta_raw( + c_indices: &[usize], + r_indices: &[usize], + source_arr: &DynamicArray, + target_arr: RawMatrixMut, + base_options: &BaseFactorOptions, + subtract: bool, +) { + let col_indices = if base_options.trans_val() { + r_indices + } else { + c_indices + }; + + if base_options.trans_target { + unsafe { + matrix_accumulation_raw( + target_arr, + source_arr, + ExtInsType::Axis(col_indices.to_vec(), 1, true), + subtract, + ) + }; + } else { + unsafe { + matrix_accumulation_raw( + target_arr, + source_arr, + ExtInsType::Axis(col_indices.to_vec(), 1, false), + subtract, + ) + }; + } +} + ///This method implements the row permutation pub fn row_perm< Item: RlstScalar, diff --git a/src/utils/io.rs b/src/utils/io.rs new file mode 100644 index 0000000..a3dfa69 --- /dev/null +++ b/src/utils/io.rs @@ -0,0 +1,1033 @@ +use hdf5::File; +use num_complex::Complex; +pub use rlst::prelude::*; +use std::path::{Path, PathBuf}; + +// ---------------------- +// Chunking / blocking parameters +// ---------------------- +// BLOCK_COLS: number of columns (across N) stored per PART FILE. +const BLOCK_COLS: usize = 4096; +// CHUNK_ROWS: HDF5 internal chunk length ~ CHUNK_ROWS * block_width elements. +const CHUNK_ROWS: usize = 256; + +pub const DEFAULT_SAMPLING_DIR: &str = "sampling"; + +pub fn preferred_sampling_dir(configured_sampling_dir: Option<&str>) -> PathBuf { + configured_sampling_dir + .filter(|dir| !dir.trim().is_empty()) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from(DEFAULT_SAMPLING_DIR)) +} + +fn candidate_sampling_dirs(configured_sampling_dir: Option<&Path>) -> Vec { + let preferred = configured_sampling_dir + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from(DEFAULT_SAMPLING_DIR)); + let legacy = PathBuf::from(DEFAULT_SAMPLING_DIR); + + if preferred == legacy { + vec![preferred] + } else { + vec![preferred, legacy] + } +} + +fn ensure_sampling_dir(sampling_dir: &Path) -> hdf5::Result<()> { + std::fs::create_dir_all(sampling_dir).map_err(|e| { + hdf5::Error::Internal(format!( + "failed to create '{}' directory: {e}", + sampling_dir.display() + )) + }) +} + +// ---------------------- +// Helpers +// ---------------------- +fn nblocks(ncols: usize) -> usize { + ncols.div_ceil(BLOCK_COLS) +} + +fn block_width(ncols: usize, b: usize) -> usize { + let col0 = b * BLOCK_COLS; + (ncols - col0).min(BLOCK_COLS) +} + +fn chunk_len(total_len: usize, wk: usize) -> usize { + total_len.min((CHUNK_ROWS * wk).max(1)).max(1) +} + +fn append_rows_column_major( + old_flat: &[T], + extra_flat: &[T], + m_old: usize, + m_add: usize, + wk: usize, +) -> Vec { + let m_new = m_old + m_add; + let mut merged = vec![T::default(); m_new * wk]; + + for j in 0..wk { + let old_src = &old_flat[j * m_old..(j + 1) * m_old]; + let extra_src = &extra_flat[j * m_add..(j + 1) * m_add]; + let dst_col = &mut merged[j * m_new..(j + 1) * m_new]; + let (dst_old, dst_extra) = dst_col.split_at_mut(m_old); + dst_old.copy_from_slice(old_src); + dst_extra.copy_from_slice(extra_src); + } + + merged +} + +fn gather_block_column_major< + T: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = T> + Shape<2>, +>( + arr: &Array, + col0: usize, + wk: usize, +) -> Vec { + let shape = arr.shape(); + let rows = shape[0]; + let mut flat = Vec::with_capacity(rows * wk); + + for j in 0..wk { + let col = col0 + j; + for i in 0..rows { + flat.push(arr.get_value([i, col]).unwrap()); + } + } + + flat +} + +/// Accepts base names like: +/// "y_sketch_file" +/// "y_sketch_file.h5" +/// "y_sketch_file.00000.h5" +/// and returns the canonical base: +/// "y_sketch_file" +fn canonical_base(input: &str) -> String { + let s = input.to_string(); + + // Strip suffix ".00000.h5" if present + if s.len() >= 9 && s.ends_with(".h5") { + let tail = &s[s.len() - 9..]; // ".00000.h5" + if tail.as_bytes()[0] == b'.' + && tail.as_bytes()[6] == b'.' + && &tail[7..] == "h5" + && tail[1..6].chars().all(|c| c.is_ascii_digit()) + { + return s[..s.len() - 9].to_string(); + } + } + + // Strip plain ".h5" + if let Some(stripped) = s.strip_suffix(".h5") { + return stripped.to_string(); + } + + s +} + +/// Part file naming: "{dir}/{base}.00000.h5" +fn part_path(sampling_dir: &Path, base: &str, b: usize) -> PathBuf { + let b0 = canonical_base(base); + sampling_dir.join(format!("{b0}.{:05}.h5", b)) +} + +/// Discover all part files in the given directory that match "{base}.00000.h5", sort by index. +fn find_part_files_in_dir(sampling_dir: &Path, base: &str) -> hdf5::Result> { + let stem = canonical_base(base); + + if !sampling_dir.exists() { + return Ok(Vec::new()); + } + + let mut parts: Vec<(usize, PathBuf)> = Vec::new(); + + let rd = std::fs::read_dir(sampling_dir).map_err(|e| { + hdf5::Error::Internal(format!( + "read_dir failed for '{}': {e}", + sampling_dir.display() + )) + })?; + + for entry in rd { + let entry = + entry.map_err(|e| hdf5::Error::Internal(format!("read_dir entry error: {e}")))?; + let path = entry.path(); + if !path.is_file() { + continue; + } + let fname = match path.file_name() { + Some(x) => x.to_string_lossy(), + None => continue, + }; + + let prefix = format!("{stem}."); + if !fname.starts_with(&prefix) || !fname.ends_with(".h5") { + continue; + } + + let middle = &fname[prefix.len()..fname.len() - 3]; + if middle.len() != 5 || !middle.chars().all(|c| c.is_ascii_digit()) { + continue; + } + + let idx: usize = middle.parse().unwrap(); + parts.push((idx, path)); + } + + parts.sort_by_key(|(i, _)| *i); + + for (expected, (idx, _)) in parts.iter().enumerate() { + if *idx != expected { + return Err(hdf5::Error::Internal(format!( + "missing part file index {} (found {}) in {}", + expected, + idx, + sampling_dir.display() + ))); + } + } + + Ok(parts) +} + +type LocatedPartFiles = (PathBuf, Vec<(usize, PathBuf)>); + +fn find_part_files( + base: &str, + sampling_dir: Option<&Path>, +) -> hdf5::Result> { + for dir in candidate_sampling_dirs(sampling_dir) { + let parts = find_part_files_in_dir(&dir, base)?; + if !parts.is_empty() { + return Ok(Some((dir, parts))); + } + } + Ok(None) +} + +pub fn resolve_sampling_dir( + configured_sampling_dir: Option<&str>, + bases: &[&str], +) -> hdf5::Result> { + for dir in candidate_sampling_dirs(configured_sampling_dir.map(Path::new)) { + let mut all_present = true; + + for base in bases { + let parts = find_part_files_in_dir(&dir, base)?; + if parts.is_empty() { + all_present = false; + break; + } + } + + if all_present { + return Ok(Some(dir)); + } + } + + Ok(None) +} + +// Robust shape attr IO (avoids ndarray conversion issues) +fn write_shape(file: &File, shape: [usize; 2]) -> hdf5::Result<()> { + let sh: [u64; 2] = [shape[0] as u64, shape[1] as u64]; + if file.attr("shape").is_ok() { + file.attr("shape")?.write_raw(&sh)?; + } else { + file.new_attr::() + .shape([2]) + .create("shape")? + .write_raw(&sh)?; + } + Ok(()) +} + +fn read_shape(file: &File) -> hdf5::Result<[usize; 2]> { + let sh: Vec = file.attr("shape")?.read_raw()?; + if sh.len() != 2 { + return Err(hdf5::Error::Internal( + "shape attr must have length 2".into(), + )); + } + Ok([sh[0] as usize, sh[1] as usize]) +} + +// Your existing helper (unchanged) +pub fn resize_rows< + Item: RlstScalar, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + Stride<2> + RawAccessMut + Shape<2>, +>( + arr: &Array, + new_shape: [usize; 2], +) -> DynamicArray { + let mut new_arr = rlst_dynamic_array2!(Item, new_shape); + new_arr + .r_mut() + .into_subview([0, 0], arr.shape()) + .fill_from(arr.r()); + + new_arr +} + +// ---------------------- +// Trait +// ---------------------- +pub trait IOData { + type Item: RlstScalar; + + fn load(path: &str) -> hdf5::Result> { + Self::load_in_dir(path, None) + } + + fn load_in_dir(path: &str, sampling_dir: Option<&Path>) -> hdf5::Result>; + + fn load_into_in_dir( + target: &mut DynamicArray, + dim: usize, + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result<()>; + + fn append + Shape<2>>( + data: &Array, + path: &str, + ) -> hdf5::Result<()> { + Self::append_in_dir(data, path, None) + } + + fn append_in_dir + Shape<2>>( + data: &Array, + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result<()>; +} + +// ---------------------- +// Real implementation +// ---------------------- +macro_rules! implement_io_data_real { + ($scalar:ty) => { + impl IOData<$scalar> for $scalar { + type Item = $scalar; + + fn load_in_dir( + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result> { + if Path::new(path).exists() { + let file = File::open(path)?; + if file.dataset("real").is_ok() { + let ds = file.dataset("real")?; + let array = ds.read()?; + return Ok(array.to_vec()); + } + } + + let Some((_dir, parts)) = find_part_files(path, sampling_dir)? else { + return Err(hdf5::Error::Internal(format!( + "no '{}' (old format) and no multipart parts found for base '{}'", + path, + canonical_base(path) + ))); + }; + + let file0 = File::open(&parts[0].1)?; + let [m, ncols] = read_shape(&file0)?; + + let mut out = vec![<$scalar as Default>::default(); m * ncols]; + + for (b, p) in parts.iter() { + let wk = block_width(ncols, *b); + let col0 = (*b) * BLOCK_COLS; + + let fb = File::open(p)?; + let [m2, n2] = read_shape(&fb)?; + if m2 != m || n2 != ncols { + return Err(hdf5::Error::Internal(format!( + "shape mismatch in part '{}': got [{m2},{n2}] expected [{m},{ncols}]", + p.display() + ))); + } + + let ds = fb.dataset("real")?; + let flat: Vec<$scalar> = ds.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + + if flat.len() != m * wk { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}::real': got {}, expected {}", + p.display(), + flat.len(), + m * wk + ))); + } + + for j in 0..wk { + let gcol = col0 + j; + let src = &flat[j * m..(j + 1) * m]; + let dst = &mut out[gcol * m..(gcol + 1) * m]; + dst.copy_from_slice(src); + } + } + + Ok(out) + } + + fn load_into_in_dir( + target: &mut DynamicArray, + dim: usize, + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result<()> { + if Path::new(path).exists() { + let file = File::open(path)?; + if file.dataset("real").is_ok() { + let ds = file.dataset("real")?; + let flat: Vec<$scalar> = ds.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + path, + stringify!($scalar) + )) + })?; + if dim == 0 { + target.resize_in_place([0, 0]); + return Ok(()); + } + if flat.len() % dim != 0 { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}::real': got {}, not divisible by dim {}", + path, + flat.len(), + dim + ))); + } + let rows = flat.len() / dim; + target.resize_in_place([rows, dim]); + target.data_mut().copy_from_slice(&flat); + return Ok(()); + } + } + + let Some((_dir, parts)) = find_part_files(path, sampling_dir)? else { + return Err(hdf5::Error::Internal(format!( + "no '{}' (old format) and no multipart parts found for base '{}'", + path, + canonical_base(path) + ))); + }; + + let file0 = File::open(&parts[0].1)?; + let [m, ncols] = read_shape(&file0)?; + if dim != ncols { + return Err(hdf5::Error::Internal(format!( + "column mismatch for base '{}': stored {}, expected {}", + canonical_base(path), + ncols, + dim + ))); + } + + target.resize_in_place([m, ncols]); + let out = target.data_mut(); + + for (b, p) in parts.iter() { + let wk = block_width(ncols, *b); + let col0 = (*b) * BLOCK_COLS; + + let fb = File::open(p)?; + let [m2, n2] = read_shape(&fb)?; + if m2 != m || n2 != ncols { + return Err(hdf5::Error::Internal(format!( + "shape mismatch in part '{}': got [{m2},{n2}] expected [{m},{ncols}]", + p.display() + ))); + } + + let ds = fb.dataset("real")?; + let flat: Vec<$scalar> = ds.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + + if flat.len() != m * wk { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}::real': got {}, expected {}", + p.display(), + flat.len(), + m * wk + ))); + } + + for j in 0..wk { + let gcol = col0 + j; + let src = &flat[j * m..(j + 1) * m]; + let dst = &mut out[gcol * m..(gcol + 1) * m]; + dst.copy_from_slice(src); + } + } + + Ok(()) + } + + fn append_in_dir + Shape<2>>( + extra_arr: &Array<$scalar, ArrayImpl, 2>, + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result<()> { + let sampling_dir = sampling_dir.unwrap_or_else(|| Path::new(DEFAULT_SAMPLING_DIR)); + ensure_sampling_dir(sampling_dir)?; + + let shape = extra_arr.shape(); + let m_add = shape[0]; + let ncols = shape[1]; + if m_add == 0 { + return Ok(()); + } + + let existing_parts = find_part_files_in_dir(sampling_dir, path)?; + let m_old = if existing_parts.is_empty() { + 0 + } else { + if existing_parts.len() != nblocks(ncols) { + return Err(hdf5::Error::Internal(format!( + "part count mismatch for base '{}': found {}, expected {}", + canonical_base(path), + existing_parts.len(), + nblocks(ncols) + ))); + } + + let file0 = File::open(&existing_parts[0].1)?; + let [stored_rows, stored_ncols] = read_shape(&file0)?; + if stored_ncols != ncols { + return Err(hdf5::Error::Internal(format!( + "column mismatch for base '{}': stored {}, incoming {}", + canonical_base(path), + stored_ncols, + ncols + ))); + } + stored_rows + }; + + let total_rows = m_old + m_add; + + for b in 0..nblocks(ncols) { + let wk = block_width(ncols, b); + let col0 = b * BLOCK_COLS; + let extra_block = gather_block_column_major(extra_arr, col0, wk); + + let merged = if m_old == 0 { + extra_block + } else { + let p = &existing_parts[b].1; + let fb = File::open(p)?; + let flat: Vec<$scalar> = fb.dataset("real")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + + if flat.len() != m_old * wk { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}::real': got {}, expected {}", + p.display(), + flat.len(), + m_old * wk + ))); + } + + append_rows_column_major(&flat, &extra_block, m_old, m_add, wk) + }; + + let p = part_path(sampling_dir, path, b); + let file = File::create(&p)?; + write_shape(&file, [total_rows, ncols])?; + + file.new_dataset::<$scalar>() + .shape((merged.len(),)) + .chunk((chunk_len(merged.len(), wk),)) + .create("real")? + .write(&merged)?; + } + + Ok(()) + } + } + }; +} + +// ---------------------- +// Complex implementation +// ---------------------- +macro_rules! implement_io_data_complex { + ($scalar:ty) => { + impl IOData> for Complex<$scalar> { + type Item = Complex<$scalar>; + + fn load_in_dir( + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result> { + if Path::new(path).exists() { + let file = File::open(path)?; + if file.dataset("real").is_ok() && file.dataset("imag").is_ok() { + let re_array = file.dataset("real")?.read()?; + let im_array = file.dataset("imag")?.read()?; + let re: Vec<$scalar> = re_array.to_vec(); + let im: Vec<$scalar> = im_array.to_vec(); + + if re.len() != im.len() { + return Err(hdf5::Error::Internal( + "mismatched real/imag lengths".into(), + )); + } + + return Ok(re + .into_iter() + .zip(im.into_iter()) + .map(|(r, i)| Complex::new(r, i)) + .collect()); + } + } + + let Some((_dir, parts)) = find_part_files(path, sampling_dir)? else { + return Err(hdf5::Error::Internal(format!( + "no '{}' (old format) and no multipart parts found for base '{}'", + path, + canonical_base(path) + ))); + }; + + let file0 = File::open(&parts[0].1)?; + let [m, ncols] = read_shape(&file0)?; + + let mut re_full = vec![<$scalar as Default>::default(); m * ncols]; + let mut im_full = vec![<$scalar as Default>::default(); m * ncols]; + + for (b, p) in parts.iter() { + let wk = block_width(ncols, *b); + let col0 = (*b) * BLOCK_COLS; + + let fb = File::open(p)?; + let [m2, n2] = read_shape(&fb)?; + if m2 != m || n2 != ncols { + return Err(hdf5::Error::Internal(format!( + "shape mismatch in part '{}': got [{m2},{n2}] expected [{m},{ncols}]", + p.display() + ))); + } + + let re_blk: Vec<$scalar> = fb.dataset("real")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + let im_blk: Vec<$scalar> = fb.dataset("imag")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::imag' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + + if re_blk.len() != m * wk || im_blk.len() != m * wk { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}': re={}, im={}, expected={}", + p.display(), + re_blk.len(), + im_blk.len(), + m * wk + ))); + } + + for j in 0..wk { + let gcol = col0 + j; + re_full[gcol * m..(gcol + 1) * m] + .copy_from_slice(&re_blk[j * m..(j + 1) * m]); + im_full[gcol * m..(gcol + 1) * m] + .copy_from_slice(&im_blk[j * m..(j + 1) * m]); + } + } + + Ok(re_full + .into_iter() + .zip(im_full.into_iter()) + .map(|(r, i)| Complex::new(r, i)) + .collect()) + } + + fn load_into_in_dir( + target: &mut DynamicArray, + dim: usize, + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result<()> { + if Path::new(path).exists() { + let file = File::open(path)?; + if file.dataset("real").is_ok() && file.dataset("imag").is_ok() { + let re: Vec<$scalar> = file.dataset("real")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + path, + stringify!($scalar) + )) + })?; + let im: Vec<$scalar> = file.dataset("imag")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::imag' as {}: {e}", + path, + stringify!($scalar) + )) + })?; + if re.len() != im.len() { + return Err(hdf5::Error::Internal( + "mismatched real/imag lengths".into(), + )); + } + if dim == 0 { + target.resize_in_place([0, 0]); + return Ok(()); + } + if re.len() % dim != 0 { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}': got {}, not divisible by dim {}", + path, + re.len(), + dim + ))); + } + let rows = re.len() / dim; + target.resize_in_place([rows, dim]); + for (dst, (r, i)) in + target.data_mut().iter_mut().zip(re.into_iter().zip(im)) + { + *dst = Complex::new(r, i); + } + return Ok(()); + } + } + + let Some((_dir, parts)) = find_part_files(path, sampling_dir)? else { + return Err(hdf5::Error::Internal(format!( + "no '{}' (old format) and no multipart parts found for base '{}'", + path, + canonical_base(path) + ))); + }; + + let file0 = File::open(&parts[0].1)?; + let [m, ncols] = read_shape(&file0)?; + if dim != ncols { + return Err(hdf5::Error::Internal(format!( + "column mismatch for base '{}': stored {}, expected {}", + canonical_base(path), + ncols, + dim + ))); + } + + target.resize_in_place([m, ncols]); + let out = target.data_mut(); + + for (b, p) in parts.iter() { + let wk = block_width(ncols, *b); + let col0 = (*b) * BLOCK_COLS; + + let fb = File::open(p)?; + let [m2, n2] = read_shape(&fb)?; + if m2 != m || n2 != ncols { + return Err(hdf5::Error::Internal(format!( + "shape mismatch in part '{}': got [{m2},{n2}] expected [{m},{ncols}]", + p.display() + ))); + } + + let re_blk: Vec<$scalar> = fb.dataset("real")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + let im_blk: Vec<$scalar> = fb.dataset("imag")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::imag' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + + if re_blk.len() != m * wk || im_blk.len() != m * wk { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}': real {}, imag {}, expected {}", + p.display(), + re_blk.len(), + im_blk.len(), + m * wk + ))); + } + + for j in 0..wk { + let gcol = col0 + j; + let re_src = &re_blk[j * m..(j + 1) * m]; + let im_src = &im_blk[j * m..(j + 1) * m]; + let dst = &mut out[gcol * m..(gcol + 1) * m]; + for (dst_val, (r, i)) in + dst.iter_mut().zip(re_src.iter().zip(im_src.iter())) + { + *dst_val = Complex::new(*r, *i); + } + } + } + + Ok(()) + } + + fn append_in_dir< + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Complex<$scalar>> + Shape<2>, + >( + extra_arr: &Array, ArrayImpl, 2>, + path: &str, + sampling_dir: Option<&Path>, + ) -> hdf5::Result<()> { + let sampling_dir = sampling_dir.unwrap_or_else(|| Path::new(DEFAULT_SAMPLING_DIR)); + ensure_sampling_dir(sampling_dir)?; + + let shape = extra_arr.shape(); + let m_add = shape[0]; + let ncols = shape[1]; + if m_add == 0 { + return Ok(()); + } + + let existing_parts = find_part_files_in_dir(sampling_dir, path)?; + let m_old = if existing_parts.is_empty() { + 0 + } else { + if existing_parts.len() != nblocks(ncols) { + return Err(hdf5::Error::Internal(format!( + "part count mismatch for base '{}': found {}, expected {}", + canonical_base(path), + existing_parts.len(), + nblocks(ncols) + ))); + } + + let file0 = File::open(&existing_parts[0].1)?; + let [stored_rows, stored_ncols] = read_shape(&file0)?; + if stored_ncols != ncols { + return Err(hdf5::Error::Internal(format!( + "column mismatch for base '{}': stored {}, incoming {}", + canonical_base(path), + stored_ncols, + ncols + ))); + } + stored_rows + }; + + let total_rows = m_old + m_add; + + for b in 0..nblocks(ncols) { + let wk = block_width(ncols, b); + let col0 = b * BLOCK_COLS; + let extra_block = gather_block_column_major(extra_arr, col0, wk); + + let mut extra_re: Vec<$scalar> = Vec::with_capacity(extra_block.len()); + let mut extra_im: Vec<$scalar> = Vec::with_capacity(extra_block.len()); + for z in extra_block.iter() { + extra_re.push(z.re); + extra_im.push(z.im); + } + + let (merged_re, merged_im) = if m_old == 0 { + (extra_re, extra_im) + } else { + let p = &existing_parts[b].1; + let fb = File::open(p)?; + let old_re: Vec<$scalar> = fb.dataset("real")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::real' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + let old_im: Vec<$scalar> = fb.dataset("imag")?.read_raw().map_err(|e| { + hdf5::Error::Internal(format!( + "failed reading '{}::imag' as {}: {e}", + p.display(), + stringify!($scalar) + )) + })?; + + if old_re.len() != m_old * wk || old_im.len() != m_old * wk { + return Err(hdf5::Error::Internal(format!( + "length mismatch in '{}': re={}, im={}, expected={}", + p.display(), + old_re.len(), + old_im.len(), + m_old * wk + ))); + } + + ( + append_rows_column_major(&old_re, &extra_re, m_old, m_add, wk), + append_rows_column_major(&old_im, &extra_im, m_old, m_add, wk), + ) + }; + + let p = part_path(sampling_dir, path, b); + let file = File::create(&p)?; + write_shape(&file, [total_rows, ncols])?; + + file.new_dataset::<$scalar>() + .shape((merged_re.len(),)) + .chunk((chunk_len(merged_re.len(), wk),)) + .create("real")? + .write(&merged_re)?; + file.new_dataset::<$scalar>() + .shape((merged_im.len(),)) + .chunk((chunk_len(merged_im.len(), wk),)) + .create("imag")? + .write(&merged_im)?; + } + + Ok(()) + } + } + }; +} + +implement_io_data_real!(f64); +implement_io_data_real!(f32); +implement_io_data_complex!(f64); +implement_io_data_complex!(f32); + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + fs, + time::{SystemTime, UNIX_EPOCH}, + }; + + fn unique_temp_dir(tag: &str) -> PathBuf { + let mut dir = std::env::temp_dir(); + let stamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + dir.push(format!("rsrs_{tag}_{stamp}_{}", std::process::id())); + fs::create_dir_all(&dir).unwrap(); + dir + } + + fn array_from_rows(rows: &[&[Item]]) -> DynamicArray { + let m = rows.len(); + let n = rows.first().map_or(0, |row| row.len()); + let mut arr = rlst_dynamic_array2!(Item, [m, n]); + let mut view = arr.r_mut(); + + for i in 0..m { + for j in 0..n { + view[[i, j]] = rows[i][j]; + } + } + + arr + } + + fn flatten_column_major(rows: &[&[T]]) -> Vec { + let m = rows.len(); + let n = rows.first().map_or(0, |row| row.len()); + let mut flat = Vec::with_capacity(m * n); + + for j in 0..n { + for row in rows.iter().take(m) { + flat.push(row[j]); + } + } + + flat + } + + #[test] + fn real_append_preserves_column_major_blocks() { + let dir = unique_temp_dir("io_real_append"); + + let initial = array_from_rows::(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]); + >::append_in_dir(&initial, "y_test_file", Some(dir.as_path())).unwrap(); + + let extra = array_from_rows::(&[&[7.0, 8.0, 9.0]]); + >::append_in_dir(&extra, "y_test_file", Some(dir.as_path())).unwrap(); + + let loaded = >::load_in_dir("y_test_file", Some(dir.as_path())).unwrap(); + assert_eq!( + loaded, + flatten_column_major(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0],]) + ); + + fs::remove_dir_all(dir).unwrap(); + } + + #[test] + fn complex_append_preserves_column_major_blocks() { + let dir = unique_temp_dir("io_complex_append"); + + let initial = array_from_rows::>(&[ + &[Complex::new(1.0, 1.5), Complex::new(2.0, 2.5)], + &[Complex::new(3.0, 3.5), Complex::new(4.0, 4.5)], + ]); + as IOData>>::append_in_dir( + &initial, + "z_sketch_file", + Some(dir.as_path()), + ) + .unwrap(); + + let extra = + array_from_rows::>(&[&[Complex::new(5.0, 5.5), Complex::new(6.0, 6.5)]]); + as IOData>>::append_in_dir( + &extra, + "z_sketch_file", + Some(dir.as_path()), + ) + .unwrap(); + + let loaded = as IOData>>::load_in_dir( + "z_sketch_file", + Some(dir.as_path()), + ) + .unwrap(); + assert_eq!( + loaded, + flatten_column_major(&[ + &[Complex::new(1.0, 1.5), Complex::new(2.0, 2.5)], + &[Complex::new(3.0, 3.5), Complex::new(4.0, 4.5)], + &[Complex::new(5.0, 5.5), Complex::new(6.0, 6.5)], + ]) + ); + + fs::remove_dir_all(dir).unwrap(); + } +} diff --git a/src/utils/least_squares_and_null.rs b/src/utils/linear_algebra.rs similarity index 60% rename from src/utils/least_squares_and_null.rs rename to src/utils/linear_algebra.rs index a656d9c..d95cfee 100644 --- a/src/utils/least_squares_and_null.rs +++ b/src/utils/linear_algebra.rs @@ -1,8 +1,9 @@ -use crate::rsrs::rsrs_cycle::{ExtractOptions, IdOptions}; use rlst::dense::linalg::{lu::MatrixLu, null_space::Method}; pub use rlst::prelude::*; use serde::Deserialize; +use crate::rsrs::rsrs_factors::null_and_extract::{ExtractOptions, IdOptions}; + fn solve_svd< Item: RlstScalar + MatrixPseudoInverse, ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> @@ -34,7 +35,6 @@ fn solve_svd< sketch_mat.r(), num::Zero::zero(), ); - sol } @@ -47,6 +47,38 @@ pub struct NormalEquations< pub normal: LuDecomposition, 2>>, } +pub struct NormalEquationAccumulator { + normal: DynamicArray, + rhs: DynamicArray, +} + +pub struct NormalEquationScratch { + solution: DynamicArray, + projected_rhs: DynamicArray, +} + +pub fn streaming_chunk_rows( + max_rows: usize, + cols_per_row: usize, + live_buffers: usize, +) -> usize { + const STREAMING_TARGET_BYTES: usize = 2 * 1024 * 1024; + + if max_rows == 0 { + return 0; + } + + let bytes_per_row = cols_per_row + .saturating_mul(live_buffers) + .saturating_mul(std::mem::size_of::()); + + if bytes_per_row == 0 { + return max_rows; + } + + (STREAMING_TARGET_BYTES / bytes_per_row).clamp(1, max_rows) +} + pub fn add_diagonal( arr: &mut DynamicArray, val: ::Real, @@ -94,9 +126,20 @@ impl< LuDecomposition, 2>>: MatrixLuDecomposition, { - let arr_shape = self.arr.shape(); - let mut new_rhs = rlst_dynamic_array2!(Item, [arr_shape[1], arr_shape[1]]); - new_rhs.r_mut().mult_into_resize( + let mut solution = empty_array(); + self.solve_normal_equations_into(rhs, &mut solution); + solution + } + + pub fn solve_normal_equations_into( + &self, + rhs: &Array, + solution: &mut DynamicArray, + ) where + LuDecomposition, 2>>: + MatrixLuDecomposition, + { + solution.r_mut().mult_into_resize( TransMode::ConjTrans, TransMode::NoTrans, num::One::one(), @@ -107,27 +150,90 @@ impl< let _ = as MatrixLuDecomposition>::solve_mat( &self.normal, TransMode::NoTrans, - new_rhs.r_mut(), + solution.r_mut(), ); - new_rhs } - fn apply_null_projector(&self, rhs: &mut Array) - where + pub fn apply_null_projector_with_scratch( + &self, + rhs: &mut Array, + scratch: &mut NormalEquationScratch, + ) where LuDecomposition, 2>>: MatrixLuDecomposition, { - let proj = self.solve_normal_equations(rhs); - let mut proj_rhs = rlst_dynamic_array2!(Item, rhs.shape()); - proj_rhs.r_mut().mult_into_resize( + self.solve_normal_equations_into(rhs, &mut scratch.solution); + scratch.projected_rhs.r_mut().mult_into_resize( TransMode::NoTrans, TransMode::NoTrans, num::One::one(), self.arr.r(), - proj.r(), + scratch.solution.r(), num::Zero::zero(), ); - rhs.r_mut().sub_into(proj_rhs.r()); + rhs.r_mut().sub_into(scratch.projected_rhs.r()); + } +} + +impl NormalEquationAccumulator { + pub fn new(lhs_cols: usize, rhs_cols: usize) -> Self { + let mut normal = rlst_dynamic_array2!(Item, [lhs_cols, lhs_cols]); + let mut rhs = rlst_dynamic_array2!(Item, [lhs_cols, rhs_cols]); + normal.r_mut().set_zero(); + rhs.r_mut().set_zero(); + + Self { normal, rhs } + } + + pub fn add_chunk(&mut self, lhs: &DynamicArray, rhs: &DynamicArray) { + self.normal.r_mut().mult_into( + TransMode::ConjTrans, + TransMode::NoTrans, + num::One::one(), + lhs.r(), + lhs.r(), + num::One::one(), + ); + self.rhs.r_mut().mult_into( + TransMode::ConjTrans, + TransMode::NoTrans, + num::One::one(), + lhs.r(), + rhs.r(), + num::One::one(), + ); + } +} + +impl NormalEquationAccumulator +where + LuDecomposition, 2>>: + MatrixLuDecomposition, +{ + pub fn solve(mut self, tol_lstq: ::Real) -> DynamicArray { + add_diagonal(&mut self.normal, tol_lstq); + let lu = ::into_lu_alloc(self.normal).unwrap(); + let _ = as MatrixLuDecomposition>::solve_mat( + &lu, + TransMode::NoTrans, + self.rhs.r_mut(), + ); + self.rhs + } +} + +impl NormalEquationScratch { + pub fn new() -> Self { + Self { + solution: empty_array(), + projected_rhs: empty_array(), + } + } +} + +impl Default for NormalEquationScratch { + fn default() -> Self { + Self::new() } } @@ -154,12 +260,35 @@ pub fn block_extraction< where LuDecomposition, 2>>: MatrixLuDecomposition, +{ + let mut out = empty_array(); + block_extraction_into(test_mat, sketch_mat, ext_options, &mut out); + out +} + +pub fn block_extraction_into< + Item: RlstScalar + MatrixPseudoInverse + MatrixLu, + ArrayImpl: UnsafeRandomAccessByValue<2, Item = Item> + + UnsafeRandomAccessMut<2, Item = Item> + + Stride<2> + + RawAccessMut + + Shape<2>, +>( + test_mat: &mut Array, + sketch_mat: &Array, + ext_options: &ExtractOptions, + out: &mut DynamicArray, +) where + LuDecomposition, 2>>: + MatrixLuDecomposition, { match ext_options.block_extraction_method { - BlockExtractionMethod::Svd => solve_svd(test_mat, sketch_mat, ext_options.tol_lstsq), + BlockExtractionMethod::Svd => { + *out = solve_svd(test_mat, sketch_mat, ext_options.tol_lstsq); + } BlockExtractionMethod::LuLstSq => { let normal = NormalEquations::new(test_mat, ext_options.tol_lstsq); - normal.solve_normal_equations(sketch_mat) + normal.solve_normal_equations_into(sketch_mat, out); } } } @@ -187,6 +316,7 @@ pub fn nullify_near_sketch< sub_test: &Array, sub_sketch: &mut Array, id_options: &IdOptions, + normal_scratch: &mut NormalEquationScratch, ) where QrDecomposition, 2>>: MatrixQrDecomposition, @@ -229,7 +359,7 @@ pub fn nullify_near_sketch< } NullMethod::Projection => { let normal = NormalEquations::new(sub_test, id_options.tol_null); - normal.apply_null_projector(sub_sketch); + normal.apply_null_projector_with_scratch(sub_sketch, normal_scratch); } }; } diff --git a/src/utils/memory.rs b/src/utils/memory.rs new file mode 100644 index 0000000..d632961 --- /dev/null +++ b/src/utils/memory.rs @@ -0,0 +1,217 @@ +//! Lightweight process memory reporting helpers. + +use std::sync::{ + atomic::{AtomicU64, Ordering}, + OnceLock, +}; + +#[derive(Debug, Clone, Copy, Default)] +pub struct ProcessMemoryUsage { + pub resident_bytes: Option, + pub peak_resident_bytes: Option, +} + +static TRACE_LAST_RSS_BYTES: AtomicU64 = AtomicU64::new(0); +static TRACE_ENABLED: OnceLock = OnceLock::new(); +static TRACE_DELTA_BYTES: OnceLock = OnceLock::new(); + +pub fn format_bytes(bytes: u64) -> String { + const KIB: f64 = 1024.0; + const MIB: f64 = 1024.0 * 1024.0; + const GIB: f64 = 1024.0 * 1024.0 * 1024.0; + + let bytes_f64 = bytes as f64; + if bytes_f64 >= GIB { + format!("{:.2} GiB", bytes_f64 / GIB) + } else if bytes_f64 >= MIB { + format!("{:.2} MiB", bytes_f64 / MIB) + } else if bytes_f64 >= KIB { + format!("{:.2} KiB", bytes_f64 / KIB) + } else { + format!("{bytes} B") + } +} + +pub fn process_memory_usage() -> ProcessMemoryUsage { + platform::process_memory_usage() +} + +pub fn matrix_bytes(rows: usize, cols: usize) -> u64 { + (rows as u64) * (cols as u64) * (std::mem::size_of::() as u64) +} + +pub fn memory_trace_enabled() -> bool { + *TRACE_ENABLED.get_or_init(|| { + std::env::var("RSRS_TRACE_MEMORY") + .map(|value| value != "0") + .unwrap_or(false) + }) +} + +fn trace_delta_bytes() -> u64 { + *TRACE_DELTA_BYTES.get_or_init(|| { + std::env::var("RSRS_TRACE_MEMORY_DELTA_MB") + .ok() + .and_then(|value| value.parse::().ok()) + .map(|mb| mb * 1024 * 1024) + .unwrap_or(4 * 1024 * 1024) + }) +} + +pub fn trace_memory_growth(label: &str, estimated_bytes: Option) { + if !memory_trace_enabled() { + return; + } + + let usage = process_memory_usage(); + let Some(rss) = usage.resident_bytes else { + return; + }; + + let last_rss = TRACE_LAST_RSS_BYTES.load(Ordering::Relaxed); + if last_rss != 0 && rss < last_rss.saturating_add(trace_delta_bytes()) { + return; + } + + TRACE_LAST_RSS_BYTES.store(rss, Ordering::Relaxed); + + let peak = usage + .peak_resident_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()); + + match estimated_bytes { + Some(bytes) => println!( + "Memory trace [{label}]: rss = {}, peak = {peak}, est tmp ~= {}", + format_bytes(rss), + format_bytes(bytes) + ), + None => println!( + "Memory trace [{label}]: rss = {}, peak = {peak}", + format_bytes(rss) + ), + } +} + +pub fn trace_memory_event(label: &str, estimated_bytes: Option) { + if !memory_trace_enabled() { + return; + } + + let usage = process_memory_usage(); + let rss = usage + .resident_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()); + let peak = usage + .peak_resident_bytes + .map(format_bytes) + .unwrap_or_else(|| "unavailable".to_string()); + + match estimated_bytes { + Some(bytes) => println!( + "Memory event [{label}]: rss = {rss}, peak = {peak}, est tmp ~= {}", + format_bytes(bytes) + ), + None => println!("Memory event [{label}]: rss = {rss}, peak = {peak}"), + } +} + +#[cfg(target_os = "macos")] +mod platform { + use super::ProcessMemoryUsage; + use std::{mem::size_of, os::raw::c_int}; + + type KernReturn = c_int; + type MachPort = u32; + type TaskFlavor = u32; + type TaskInfoCount = u32; + + const KERN_SUCCESS: KernReturn = 0; + const MACH_TASK_BASIC_INFO: TaskFlavor = 20; + + #[repr(C)] + #[derive(Clone, Copy, Default)] + struct TimeValue { + seconds: c_int, + microseconds: c_int, + } + + #[repr(C)] + #[derive(Clone, Copy, Default)] + struct MachTaskBasicInfo { + virtual_size: u64, + resident_size: u64, + resident_size_max: u64, + user_time: TimeValue, + system_time: TimeValue, + policy: c_int, + suspend_count: c_int, + } + + unsafe extern "C" { + fn mach_task_self() -> MachPort; + fn task_info( + target_task: MachPort, + flavor: TaskFlavor, + task_info_out: *mut c_int, + task_info_out_count: *mut TaskInfoCount, + ) -> KernReturn; + } + + pub fn process_memory_usage() -> ProcessMemoryUsage { + let mut info = MachTaskBasicInfo::default(); + let mut count = (size_of::() / size_of::()) as TaskInfoCount; + + let result = unsafe { + task_info( + mach_task_self(), + MACH_TASK_BASIC_INFO, + &mut info as *mut _ as *mut c_int, + &mut count, + ) + }; + + if result == KERN_SUCCESS { + ProcessMemoryUsage { + resident_bytes: Some(info.resident_size), + peak_resident_bytes: Some(info.resident_size_max), + } + } else { + ProcessMemoryUsage::default() + } + } +} + +#[cfg(target_os = "linux")] +mod platform { + use super::ProcessMemoryUsage; + + fn parse_kib(contents: &str, key: &str) -> Option { + contents.lines().find_map(|line| { + let rest = line.strip_prefix(key)?.trim(); + let value = rest.split_whitespace().next()?.parse::().ok()?; + Some(value * 1024) + }) + } + + pub fn process_memory_usage() -> ProcessMemoryUsage { + let Ok(contents) = std::fs::read_to_string("/proc/self/status") else { + return ProcessMemoryUsage::default(); + }; + + ProcessMemoryUsage { + resident_bytes: parse_kib(&contents, "VmRSS:"), + peak_resident_bytes: parse_kib(&contents, "VmHWM:"), + } + } +} + +#[cfg(not(any(target_os = "macos", target_os = "linux")))] +mod platform { + use super::ProcessMemoryUsage; + + pub fn process_memory_usage() -> ProcessMemoryUsage { + ProcessMemoryUsage::default() + } +} diff --git a/tests/fixtures/biegrid_perturbed/cells22_scale1e-2_seed12345.bin b/tests/fixtures/biegrid_perturbed/cells22_scale1e-2_seed12345.bin new file mode 100644 index 0000000..3a4a7bb Binary files /dev/null and b/tests/fixtures/biegrid_perturbed/cells22_scale1e-2_seed12345.bin differ diff --git a/tests/fixtures/biegrid_perturbed/generate.py b/tests/fixtures/biegrid_perturbed/generate.py new file mode 100644 index 0000000..2e5b119 --- /dev/null +++ b/tests/fixtures/biegrid_perturbed/generate.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +"""Generate perturbed BIEGrid fixtures for tests/rsrs_operator_mv.rs.""" + +from __future__ import annotations + +import argparse +import struct +import sys +from pathlib import Path + +import numpy as np + +MAGIC = b"RSRS_BIEGRID_FIXTURE_V1" + + +def rel_fro_defect(lhs: np.ndarray, rhs: np.ndarray) -> float: + return float(np.linalg.norm(lhs - rhs) / max(np.linalg.norm(rhs), 1.0e-14)) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--cells-per-axis", type=int, default=22) + parser.add_argument("--scale", type=float, default=1.0e-2) + parser.add_argument("--seed", type=int, default=12345) + parser.add_argument("--rsrs-exps", type=Path) + parser.add_argument( + "--out", + type=Path, + ) + args = parser.parse_args() + + script_dir = Path(__file__).resolve().parent + rsrs_exps = args.rsrs_exps or (script_dir.parents[3] / "rsrs-exps") + out = args.out or (script_dir / "cells22_scale1e-2_seed12345.bin") + + sys.path.insert(0, str(rsrs_exps.resolve())) + from python.bie_grid import ( # pylint: disable=import-outside-toplevel + BIEGrid, + build_rank_one_box_perturbations, + build_real_rank_one_box_perturbations, + make_complex_wrapped_operator, + make_real_wrapped_operator, + ) + + ndim = 2 + experiment = BIEGrid(args.cells_per_axis**ndim, ndim) + n = experiment.N + raw_points = np.asarray(experiment.XX, dtype=np.float64) + if raw_points.shape[1] == 2: + points = np.column_stack([raw_points, np.zeros(raw_points.shape[0], dtype=np.float64)]) + else: + points = raw_points + + real_perturbation = build_real_rank_one_box_perturbations( + n, ndim, scale=args.scale, seed=args.seed + ) + complex_perturbation = build_rank_one_box_perturbations( + n, ndim, scale=args.scale, seed=args.seed + ) + + base_op = experiment.fast_apply_op + identity_real = np.eye(n, dtype=np.float64) + identity_complex = np.eye(n, dtype=np.complex128) + + real_symmetric_op = make_real_wrapped_operator( + base_op, n, perturbation=real_perturbation, symmetry_mode="real_symmetric" + ) + real_nonsymmetric_op = make_real_wrapped_operator( + base_op, n, perturbation=real_perturbation, symmetry_mode="none" + ) + complex_symmetric_op, _ = make_complex_wrapped_operator( + base_op, n, perturbation=complex_perturbation, symmetry_mode="complex_symmetric" + ) + complex_nonsymmetric_op, _ = make_complex_wrapped_operator( + base_op, n, perturbation=complex_perturbation, symmetry_mode="none" + ) + + real_symmetric = np.ascontiguousarray( + real_symmetric_op.matmat(identity_real), dtype=np.float64 + ) + real_nonsymmetric = np.ascontiguousarray( + real_nonsymmetric_op.matmat(identity_real), dtype=np.float64 + ) + complex_symmetric = np.ascontiguousarray( + complex_symmetric_op.matmat(identity_complex), dtype=np.complex128 + ) + complex_nonsymmetric = np.ascontiguousarray( + complex_nonsymmetric_op.matmat(identity_complex), dtype=np.complex128 + ) + + out.parent.mkdir(parents=True, exist_ok=True) + with out.open("wb") as handle: + handle.write(MAGIC) + handle.write(struct.pack(", + cells_per_axis: usize, + perturbation_scale: f64, + perturbation_seed: u64, + real_symmetric: DynamicArray, + real_nonsymmetric: DynamicArray, + complex_symmetric: DynamicArray, 2>, + complex_nonsymmetric: DynamicArray, 2>, +} + +struct FixtureReader<'a> { + bytes: &'a [u8], + offset: usize, +} + +impl<'a> FixtureReader<'a> { + fn new(bytes: &'a [u8]) -> Self { + Self { bytes, offset: 0 } + } + + fn read_exact(&mut self, len: usize) -> &'a [u8] { + let end = self + .offset + .checked_add(len) + .expect("BIEGrid fixture offset overflow"); + let out = self + .bytes + .get(self.offset..end) + .expect("BIEGrid fixture ended unexpectedly"); + self.offset = end; + out + } + + fn read_u64(&mut self) -> u64 { + u64::from_le_bytes( + self.read_exact(8) + .try_into() + .expect("BIEGrid fixture u64 has wrong width"), + ) + } + + fn read_f64(&mut self) -> f64 { + f64::from_le_bytes( + self.read_exact(8) + .try_into() + .expect("BIEGrid fixture f64 has wrong width"), + ) + } + + fn finish(self) { + assert_eq!( + self.offset, + self.bytes.len(), + "BIEGrid fixture has trailing bytes" + ); + } +} + +fn load_real_fixture_matrix(reader: &mut FixtureReader<'_>, dim: usize) -> DynamicArray { + let mut matrix = rlst_dynamic_array2!(f64, [dim, dim]); + for row in 0..dim { + for col in 0..dim { + matrix[[row, col]] = reader.read_f64(); + } + } + matrix +} + +fn load_complex_fixture_matrix( + reader: &mut FixtureReader<'_>, + dim: usize, +) -> DynamicArray, 2> { + let mut matrix = rlst_dynamic_array2!(Complex, [dim, dim]); + for row in 0..dim { + for col in 0..dim { + matrix[[row, col]] = Complex::new(reader.read_f64(), reader.read_f64()); + } + } + matrix +} + +fn load_biegrid_perturbed_fixture() -> BiegridPerturbedFixture { + let mut reader = FixtureReader::new(BIEGRID_FIXTURE_BYTES); + assert_eq!( + reader.read_exact(BIEGRID_FIXTURE_MAGIC.len()), + BIEGRID_FIXTURE_MAGIC, + "BIEGrid fixture has an unexpected magic header" + ); + + let dim = reader.read_u64() as usize; + let cells_per_axis = reader.read_u64() as usize; + let perturbation_scale = reader.read_f64(); + let perturbation_seed = reader.read_u64(); + let mut points = Vec::with_capacity(dim); + for _ in 0..dim { + points.push(Point::new( + [reader.read_f64(), reader.read_f64(), reader.read_f64()], + 0, + )); + } + + let real_symmetric = load_real_fixture_matrix(&mut reader, dim); + let real_nonsymmetric = load_real_fixture_matrix(&mut reader, dim); + let complex_symmetric = load_complex_fixture_matrix(&mut reader, dim); + let complex_nonsymmetric = load_complex_fixture_matrix(&mut reader, dim); + reader.finish(); + + BiegridPerturbedFixture { + points, + cells_per_axis, + perturbation_scale, + perturbation_seed, + real_symmetric, + real_nonsymmetric, + complex_symmetric, + complex_nonsymmetric, + } +} + +fn run_dense_diagnostics() -> bool { + std::env::var(RUN_DENSE_DIAGNOSTICS_ENV) + .map(|value| !matches!(value.as_str(), "" | "0" | "false" | "False" | "FALSE")) + .unwrap_or(false) +} + +/// Generates deterministic points and projects them back onto an approximate +/// sphere to keep the geometry smooth and reproducible. +#[allow(dead_code)] +fn sphere_surface(npoints: usize, comm: &C) -> Vec { + let mut rng = ChaCha8Rng::seed_from_u64(0); + let mut points = generate_random_points(npoints, &mut rng, comm); + + let x: Vec = points.iter().map(|point| point.coords()[0]).collect(); + let y: Vec = points.iter().map(|point| point.coords()[1]).collect(); + let z: Vec = points.iter().map(|point| point.coords()[2]).collect(); + let centre = Point::new( + [ + (x.iter().copied().fold(f64::INFINITY, f64::min) + + x.iter().copied().fold(f64::NEG_INFINITY, f64::max)) + * 0.5, + (y.iter().copied().fold(f64::INFINITY, f64::min) + + y.iter().copied().fold(f64::NEG_INFINITY, f64::max)) + * 0.5, + (z.iter().copied().fold(f64::INFINITY, f64::min) + + z.iter().copied().fold(f64::NEG_INFINITY, f64::max)) + * 0.5, + ], + 0, + ); + + for point in &mut points { + let mut aux = [ + point.coords()[0] - centre.coords()[0], + point.coords()[1] - centre.coords()[1], + point.coords()[2] - centre.coords()[2], + ]; + let len = (aux[0] * aux[0] + aux[1] * aux[1] + aux[2] * aux[2]).sqrt(); + aux[0] /= len; + aux[1] /= len; + aux[2] /= len; + point.coords_mut()[0] = aux[0] + centre.coords()[0]; + point.coords_mut()[1] = aux[1] + centre.coords()[1]; + point.coords_mut()[2] = aux[2] + centre.coords()[2]; + } + + points +} + +/// Scalar Laplace kernel used for the dense reference matrices. +#[allow(dead_code)] +fn laplace_kernel(dist: f64, npoints: usize) -> f64 { + let pi = std::f64::consts::PI; + let n = npoints as f64; + 1.0 / (4.0 * pi * n * dist) +} + +/// Dense symmetric reference matrix based on the Laplace kernel. +#[allow(dead_code)] +fn laplace_matrix(points: &[Point]) -> DynamicArray { + let n = points.len(); + let mut arr = rlst_dynamic_array2!(f64, [n, n]); + let mut view = arr.r_mut(); + for (i, point_x) in points.iter().enumerate() { + for (j, point_y) in points.iter().enumerate() { + let coords_x = point_x.coords(); + let coords_y = point_y.coords(); + let dist = ((coords_x[0] - coords_y[0]).powi(2) + + (coords_x[1] - coords_y[1]).powi(2) + + (coords_x[2] - coords_y[2]).powi(2)) + .sqrt(); + view[[i, j]] = if dist > 0.0 { + laplace_kernel(dist, n) + } else { + 1.0 + }; + } + } + arr +} + +/// Real-valued perturbation of the Laplace matrix that deliberately breaks +/// symmetry while staying close to the same scaling. +#[allow(dead_code)] +fn nonsymmetric_real_matrix(points: &[Point]) -> DynamicArray { + let base = laplace_matrix(points); + let n = points.len(); + let mut arr = rlst_dynamic_array2!(f64, [n, n]); + let mut view = arr.r_mut(); + for i in 0..n { + for j in 0..n { + let skew = 0.20 * (points[i].coords()[0] - points[j].coords()[0]); + view[[i, j]] = if i == j { + 1.0 + } else { + base[[i, j]] * (1.0 + skew) + }; + } + } + arr +} + +/// Complex Hermitian matrix used to exercise transpose and conjugation logic. +/// +/// The imaginary part is antisymmetric, so the full matrix satisfies +/// `A^H = A` while still being genuinely complex. +#[allow(dead_code)] +fn hermitian_complex_matrix(points: &[Point]) -> DynamicArray, 2> { + let base = laplace_matrix(points); + let n = points.len(); + let mut arr = rlst_dynamic_array2!(Complex, [n, n]); + let mut view = arr.r_mut(); + for i in 0..n { + for j in 0..n { + if i == j { + view[[i, j]] = Complex::new(1.0, 0.0); + } else { + let imag = 0.20 * (points[i].coords()[0] - points[j].coords()[0]); + view[[i, j]] = Complex::new(base[[i, j]], base[[i, j]] * imag); + } + } + } + arr +} + +/// Complex symmetric matrix used to compare the `Symmetric` and `NoSymm` +/// construction paths on the same operator. +/// +/// The imaginary part is symmetric, so `A^T = A`, but the matrix is not +/// Hermitian because conjugation changes the sign of that component. +#[allow(dead_code)] +fn symmetric_complex_matrix(points: &[Point]) -> DynamicArray, 2> { + let base = laplace_matrix(points); + let n = points.len(); + let mut arr = rlst_dynamic_array2!(Complex, [n, n]); + let mut view = arr.r_mut(); + for i in 0..n { + for j in 0..n { + if i == j { + view[[i, j]] = Complex::new(1.0, 0.0); + } else { + let imag = 0.15 * (points[i].coords()[0] + points[j].coords()[0]) + + 0.05 * points[i].coords()[1] * points[j].coords()[1]; + view[[i, j]] = Complex::new(base[[i, j]], base[[i, j]] * imag); + } + } + } + arr +} + +/// Complex matrix with no transpose or Hermitian symmetry. +#[allow(dead_code)] +fn nonsymmetric_complex_matrix(points: &[Point]) -> DynamicArray, 2> { + let base = laplace_matrix(points); + let n = points.len(); + let mut arr = rlst_dynamic_array2!(Complex, [n, n]); + let mut view = arr.r_mut(); + for i in 0..n { + for j in 0..n { + if i == j { + view[[i, j]] = Complex::new(1.0, 0.10); + } else { + let real_skew = 0.18 * (points[i].coords()[0] - points[j].coords()[1]); + let imag_skew = 0.14 * points[i].coords()[2] + 0.07 * points[j].coords()[0] + - 0.05 * points[i].coords()[1] * points[j].coords()[2]; + view[[i, j]] = + Complex::new(base[[i, j]] * (1.0 + real_skew), base[[i, j]] * imag_skew); + } + } + } + arr +} + +/// Deterministic real input vector used across the real-valued cases. +fn deterministic_real_vector(n: usize) -> Vec { + (0..n) + .map(|i| { + let t = (i + 1) as f64; + t.sin() + 0.25 * t.cos() + }) + .collect() +} + +/// Deterministic complex input vector used across the complex-valued cases. +fn deterministic_complex_vector(n: usize) -> Vec> { + (0..n) + .map(|i| { + let t = (i + 1) as f64; + Complex::new(t.sin() + 0.25 * t.cos(), 0.5 * t.cos() - 0.15 * t.sin()) + }) + .collect() +} + +/// Relative `l2` error with a small denominator floor for near-zero references. +fn rel_l2_error(actual: &[Item], expected: &[Item]) -> f64 { + let mut actual_arr = rlst_dynamic_array1!(Item, [actual.len()]); + let mut expected_arr = rlst_dynamic_array1!(Item, [expected.len()]); + actual_arr.fill_from_raw_data(actual); + expected_arr.fill_from_raw_data(expected); + + let mut diff = empty_array(); + diff.fill_from_resize(actual_arr.r() - expected_arr.r()); + + let num = diff.norm_2(); + let den = expected_arr.norm_2(); + let num_f64: f64 = num::NumCast::from(num).unwrap(); + let den_f64: f64 = num::NumCast::from(den).unwrap(); + num_f64 / den_f64.max(1.0e-14) +} + +// Glossary for the vector relative-l2 diagnostics printed in this test: +// - `op_vs_factor_left`: public operator `NoTrans` versus direct left factor apply. +// - `op_vs_factor_trans_identity`: public operator `Trans` versus direct right +// factor apply, the identity used to implement transpose-vector products. +// - `left_dense`: public operator `NoTrans` versus dense `A x`. +// - `trans_dense`: public operator `Trans` versus dense `A^T x`. +// - `left_trans_dense`: direct left factor `Trans` apply versus dense `A^T x`. +// - `right_dense`: direct right factor `NoTrans` apply versus dense `x A`. +// - `right_trans_dense`: direct right factor `Trans` apply versus dense `x A^T`. +// - `conj_no_trans`: public operator `ConjNoTrans` versus dense `conj(A) x`. +// - `conj_trans`: public operator `ConjTrans` versus dense `A^H x`. +// - `no_trans_vs_trans`: public operator `A x` versus public operator `A^T x`. +// - `dense_no_trans_vs_trans`: dense `A x` versus dense `A^T x`; this is the +// reference signal that a nonsymmetric matrix has not collapsed to symmetric. + +/// Dense reference apply used to validate both the operator wrapper and the +/// direct factor-level matvec path. +/// +/// `use_adjoint_reference` exists for cases where a caller wants the dense +/// `Trans` reference to behave like an adjoint instead of a plain transpose. +fn dense_apply( + matrix: &DynamicArray, + input: &[Item], + side: Side, + trans_mode: TransMode, + use_adjoint_reference: bool, +) -> Vec { + let mut input_arr = match side { + Side::Left => rlst_dynamic_array2!(Item, [input.len(), 1]), + Side::Right => rlst_dynamic_array2!(Item, [1, input.len()]), + }; + input_arr.fill_from_raw_data(input); + + let mut out = empty_array(); + if use_adjoint_reference && matches!(trans_mode, TransMode::Trans) { + let mut adjoint = empty_array(); + adjoint.fill_from_resize(matrix.r().transpose().conj()); + match side { + Side::Left => out + .r_mut() + .simple_mult_into_resize(adjoint.r(), input_arr.r()), + Side::Right => out + .r_mut() + .simple_mult_into_resize(input_arr.r(), adjoint.r()), + }; + } else { + match side { + Side::Left => out.r_mut().mult_into_resize( + trans_mode, + TransMode::NoTrans, + Item::from_real(Item::real(1.0)), + matrix.r(), + input_arr.r(), + Item::from_real(Item::real(0.0)), + ), + Side::Right => out.r_mut().mult_into_resize( + TransMode::NoTrans, + trans_mode, + Item::from_real(Item::real(1.0)), + input_arr.r(), + matrix.r(), + Item::from_real(Item::real(0.0)), + ), + }; + } + + out.r().iter().collect() +} + +/// Canonical basis vector used to materialize full operator action matrices. +fn basis_vector(dim: usize, index: usize) -> Vec { + let mut basis = vec![Item::from_real(Item::real(0.0)); dim]; + basis[index] = Item::from_real(Item::real(1.0)); + basis +} + +/// Flattens a dense matrix in column-major order so transposed actions can be +/// compared with basis-vector operator assembly. +fn flatten_column_major(matrix: &DynamicArray) -> Vec { + let shape = matrix.shape(); + let view = matrix.r(); + let mut values = Vec::with_capacity(shape[0] * shape[1]); + + for col in 0..shape[1] { + for row in 0..shape[0] { + values.push(view[[row, col]]); + } + } + + values +} + +/// Transposes a square column-major matrix while preserving the same flattened +/// column-major layout in the output. +fn transpose_column_major(matrix: &[Item], dim: usize) -> Vec { + let mut transposed = Vec::with_capacity(matrix.len()); + + for col in 0..dim { + for row in 0..dim { + transposed.push(matrix[row * dim + col]); + } + } + + transposed +} + +fn conjugate_column_major(matrix: &[Item]) -> Vec { + matrix.iter().map(|value| value.conj()).collect() +} + +fn adjoint_column_major(matrix: &[Item], dim: usize) -> Vec { + conjugate_column_major(&transpose_column_major(matrix, dim)) +} + +/// Materializes the full action matrix of the public RSRS operator for one +/// transpose mode by applying it to each basis vector. +fn assemble_operator_matrix(op: &Op, dim: usize, trans_mode: TransMode) -> Vec +where + Item: RlstScalar, + Op: AsApply, Range = ArrayVectorSpace>, +{ + let mut matrix = Vec::with_capacity(dim * dim); + + for col in 0..dim { + let basis = basis_vector(dim, col); + matrix.extend(apply_operator(op, &basis, trans_mode)); + } + + matrix +} + +/// Materializes the left- or right-application matrix exposed by the factor +/// container, again by probing the canonical basis. +fn assemble_factor_matrix( + factors: &Factors, + dim: usize, + side: Side, + base_options: &BaseFactorOptions, +) -> Vec +where + Item: RlstScalar, + Factors: RsrsFactorsImpl, +{ + let mut matrix = Vec::with_capacity(dim * dim); + + for col in 0..dim { + let basis = basis_vector(dim, col); + let mut output = vec![Item::from_real(Item::real(0.0)); dim]; + factors.matvec(&basis, &mut output, side, base_options); + matrix.extend(output); + } + + matrix +} + +/// Applies the high-level `rlst` operator interface and materializes the output +/// as a plain vector for easy comparison. +fn apply_operator(op: &Op, input: &[Item], trans_mode: TransMode) -> Vec +where + Item: RlstScalar, + Op: AsApply, Range = ArrayVectorSpace>, +{ + let mut x = zero_element(op.domain()); + x.imp_mut().fill_inplace_raw(input); + let y = op.apply(x.r(), trans_mode); + y.view().iter().collect() +} + +fn frobenius_norm_flat(matrix: &[Item]) -> f64 { + matrix + .iter() + .map(|value| { + let norm_sqr: f64 = num::NumCast::from((*value).square()).unwrap(); + norm_sqr + }) + .sum::() + .sqrt() +} + +fn rel_fro_error_flat(actual: &[Item], expected: &[Item]) -> f64 { + let diff = actual + .iter() + .zip(expected.iter()) + .map(|(actual, expected)| { + let norm_sqr: f64 = num::NumCast::from((*actual - *expected).square()).unwrap(); + norm_sqr + }) + .sum::() + .sqrt(); + + diff / frobenius_norm_flat(expected).max(1.0e-14) +} + +fn multiply_column_major(left: &[Item], right: &[Item], dim: usize) -> Vec { + let mut product = vec![Item::from_real(Item::real(0.0)); dim * dim]; + for col in 0..dim { + for row in 0..dim { + let mut value = Item::from_real(Item::real(0.0)); + for k in 0..dim { + value += left[row + k * dim] * right[k + col * dim]; + } + product[row + col * dim] = value; + } + } + + product +} + +fn rel_identity_fro_error_flat(matrix: &[Item], dim: usize) -> f64 { + let diff = matrix + .iter() + .enumerate() + .map(|(index, value)| { + let row = index % dim; + let col = index / dim; + let expected = if row == col { + Item::from_real(Item::real(1.0)) + } else { + Item::from_real(Item::real(0.0)) + }; + let norm_sqr: f64 = num::NumCast::from((*value - expected).square()).unwrap(); + norm_sqr + }) + .sum::() + .sqrt(); + + diff / (dim as f64).sqrt().max(1.0e-14) +} + +fn run_inverse_dense_diagnostic( + label: &str, + matrix: &DynamicArray, + rsrs_op: &mut Op, +) where + Item: RlstScalar, + Op: AsApply, Range = ArrayVectorSpace> + Inv, +{ + let dim = matrix.shape()[0]; + let dense_matrix = flatten_column_major(matrix); + + rsrs_op.inv(true); + let inverse_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::NoTrans); + rsrs_op.inv(false); + + // Full left inverse check: ||RSRS^{-1} A - I||_F / ||I||_F. + let left_identity = multiply_column_major(&inverse_matrix, &dense_matrix, dim); + // Full right inverse check: ||A RSRS^{-1} - I||_F / ||I||_F. + let right_identity = multiply_column_major(&dense_matrix, &inverse_matrix, dim); + // Full sandwich check: ||RSRS^{-1} A RSRS^{-1} - RSRS^{-1}||_F / ||RSRS^{-1}||_F. + let sandwich = multiply_column_major(&left_identity, &inverse_matrix, dim); + + println!( + "{label}: dense_inverse full_left_identity={:.3e}, full_right_identity={:.3e}, full_sandwich_vs_inverse={:.3e}, inverse_norm={:.3e}", + rel_identity_fro_error_flat(&left_identity, dim), + rel_identity_fro_error_flat(&right_identity, dim), + rel_fro_error_flat(&sandwich, &inverse_matrix), + frobenius_norm_flat(&inverse_matrix) + ); +} + +fn clone_matrix(matrix: &DynamicArray) -> DynamicArray { + let mut out = empty_array(); + out.r_mut().fill_from_resize(matrix.r()); + out +} + +fn block_fro_norm( + matrix: &DynamicArray, + rows: &[usize], + cols: &[usize], +) -> f64 { + rows.iter() + .flat_map(|&row| cols.iter().map(move |&col| matrix[[row, col]])) + .map(|value| { + let norm_sqr: f64 = num::NumCast::from(value.square()).unwrap(); + norm_sqr + }) + .sum::() + .sqrt() +} + +fn apply_dense_factor( + factor: &Factor, + target: &mut DynamicArray, + options: &MulOptions, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + match factor { + Factor::Id(id_factor) => id_factor.mul(target, options), + Factor::Lu(lu_factor) => lu_factor.mul(target, options), + Factor::Diag(diag_factor) => diag_factor.mul(target, options), + } +} + +#[derive(Clone, Copy)] +enum DenseBoxApplyMode { + LeftOnly, + RightOnly, + Sandwich, +} + +impl DenseBoxApplyMode { + fn label(self) -> &'static str { + match self { + Self::LeftOnly => "left_only", + Self::RightOnly => "right_only", + Self::Sandwich => "sandwich", + } + } +} + +#[derive(Clone, Copy)] +struct DenseBoxNorms { + first: f64, + second: f64, +} + +struct DenseBoxKind { + name: &'static str, + first_label: &'static str, + second_label: &'static str, +} + +fn dense_box_rel_norms(after: DenseBoxNorms, before: DenseBoxNorms) -> DenseBoxNorms { + DenseBoxNorms { + first: after.first / before.first.max(1.0e-14), + second: after.second / before.second.max(1.0e-14), + } +} + +fn dense_box_norms( + factor: &Factor, + target: &DynamicArray, +) -> Option<(DenseBoxKind, DenseBoxNorms)> { + match factor { + Factor::Id(id_factor) => Some(( + DenseBoxKind { + name: "id", + first_label: "rf", + second_label: "fr", + }, + DenseBoxNorms { + // ID left-zero target: residual rows against far-field columns. + first: block_fro_norm(target, &id_factor.ind_r, &id_factor.ind_f), + // ID right-zero target: far-field rows against residual columns. + second: block_fro_norm(target, &id_factor.ind_f, &id_factor.ind_r), + }, + )), + Factor::Lu(lu_factor) => Some(( + DenseBoxKind { + name: "lu", + first_label: "rt", + second_label: "tr", + }, + DenseBoxNorms { + // LU left-zero target: residual rows against target columns. + first: block_fro_norm(target, &lu_factor.ind_r, &lu_factor.ind_t), + // LU right-zero target: target rows against residual columns. + second: block_fro_norm(target, &lu_factor.ind_t, &lu_factor.ind_r), + }, + )), + Factor::Diag(_) => None, + } +} + +fn apply_dense_box_factor_mode( + factor: &Factor, + target: &mut DynamicArray, + mode: DenseBoxApplyMode, + left_options: &MulOptions, + right_options: &MulOptions, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + match mode { + DenseBoxApplyMode::LeftOnly => apply_dense_factor(factor, target, left_options), + DenseBoxApplyMode::RightOnly => apply_dense_factor(factor, target, right_options), + DenseBoxApplyMode::Sandwich => { + apply_dense_factor(factor, target, left_options); + apply_dense_factor(factor, target, right_options); + } + } +} + +#[allow(clippy::too_many_arguments)] +fn run_dense_box_error_factor( + label: &str, + mode: DenseBoxApplyMode, + level: usize, + batch: usize, + index: usize, + factor: &Factor, + target: &mut DynamicArray, + left_options: &MulOptions, + right_options: &MulOptions, +) -> Option +where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + let (kind, before) = dense_box_norms(factor, target)?; + + let mut conj_target = empty_array(); + conj_target.r_mut().fill_from_resize(target.r().conj()); + + apply_dense_box_factor_mode(factor, target, mode, left_options, right_options); + apply_dense_box_factor_mode(factor, &mut conj_target, mode, left_options, right_options); + + let (_, after) = dense_box_norms(factor, target).expect("ID/LU factor kind changed"); + let (_, conj_after) = dense_box_norms(factor, &conj_target).expect("ID/LU factor kind changed"); + let rel = dense_box_rel_norms(after, before); + let conj_rel = dense_box_rel_norms(conj_after, before); + + println!( + "{label}: dense_boxes mode={} family={} level={} batch={} index={} target=({}={:.3e}, {}={:.3e}) conj_target=({}={:.3e}, {}={:.3e}) before=({}={:.3e}, {}={:.3e}) after=({}={:.3e}, {}={:.3e}) conj_after=({}={:.3e}, {}={:.3e})", + mode.label(), + kind.name, + level, + batch, + index, + kind.first_label, + rel.first, + kind.second_label, + rel.second, + kind.first_label, + conj_rel.first, + kind.second_label, + conj_rel.second, + kind.first_label, + before.first, + kind.second_label, + before.second, + kind.first_label, + after.first, + kind.second_label, + after.second, + kind.first_label, + conj_after.first, + kind.second_label, + conj_after.second + ); + + Some(rel) +} + +#[allow(clippy::too_many_arguments)] +fn run_dense_box_error_batch( + label: &str, + mode: DenseBoxApplyMode, + level: usize, + batch: usize, + factors: &[Factor], + target: &mut DynamicArray, + left_options: &MulOptions, + right_options: &MulOptions, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + for (index, factor) in factors.iter().enumerate() { + run_dense_box_error_factor( + label, + mode, + level, + batch, + index, + factor, + target, + left_options, + right_options, + ); + } +} + +fn run_dense_box_errors_mode( + label: &str, + mode: DenseBoxApplyMode, + matrix: &DynamicArray, + factors: &RsrsFactors, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + println!( + "{label}: dense_boxes mode={} copied from get_boxes_errors; target matrix is mutated step by step", + mode.label() + ); + + let mut target = clone_matrix(matrix); + let base_options = BaseFactorOptions { + inv: true, + trans: TransMode::NoTrans, + trans_target: false, + }; + let left_options = MulOptions { + base_options: base_options.clone(), + side: Side::Left, + factor_type: FactorType::F, + }; + let right_options = MulOptions { + base_options, + side: Side::Right, + factor_type: FactorType::S, + }; + + match &factors.id_factors { + MultiLevelIdFactors::Batched(levels) => { + for (level, id_batches) in levels.iter().enumerate().take(factors.num_levels) { + for (batch, id_batch) in id_batches.iter().enumerate() { + run_dense_box_error_batch( + label, + mode, + level, + batch, + id_batch, + &mut target, + &left_options, + &right_options, + ); + run_dense_box_error_batch( + label, + mode, + level, + batch, + &factors.lu_factors[level][batch], + &mut target, + &left_options, + &right_options, + ); + } + } + } + MultiLevelIdFactors::Single(levels) => { + for (level, id_factors) in levels.iter().enumerate().take(factors.num_levels) { + run_dense_box_error_batch( + label, + mode, + level, + 0, + id_factors, + &mut target, + &left_options, + &right_options, + ); + for (batch, lu_batch) in factors.lu_factors[level].iter().enumerate() { + run_dense_box_error_batch( + label, + mode, + level, + batch, + lu_batch, + &mut target, + &left_options, + &right_options, + ); + } + } + } + } +} + +fn run_dense_box_errors_left_application( + label: &str, + matrix: &DynamicArray, + factors: &RsrsFactors, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + run_dense_box_errors_mode(label, DenseBoxApplyMode::LeftOnly, matrix, factors); +} + +fn run_dense_box_errors_right_application( + label: &str, + matrix: &DynamicArray, + factors: &RsrsFactors, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + run_dense_box_errors_mode(label, DenseBoxApplyMode::RightOnly, matrix, factors); +} + +fn run_dense_box_errors_sandwich_application( + label: &str, + matrix: &DynamicArray, + factors: &RsrsFactors, +) where + Item: RlstScalar + + RandScalar + + MatrixId + + MatrixIdNoSkel + + MatrixInverse + + MatrixPseudoInverse + + MatrixLu + + MatrixQr, + StandardNormal: rand_distr::Distribution, + Standard: rand_distr::Distribution, + ::Real: RandScalar, + LuDecomposition, 2>>: + MatrixLuDecomposition, + TriangularMatrix: TriangularOperations, +{ + run_dense_box_errors_mode(label, DenseBoxApplyMode::Sandwich, matrix, factors); +} + +struct FrobeniusMetrics { + // ||A||_F for the dense no-transpose reference. + reference_norm: f64, + // ||A^T||_F for the dense transpose reference. + transpose_reference_norm: f64, + // rsrs-exps `rel_fro`/`apply_fro`: ||RSRS - A||_F / ||A||_F. + op_no_trans_dense_err: f64, + // rsrs-exps `rel_fro_T`: ||RSRS^T_route - A^T||_F / ||A^T||_F. + op_trans_dense_err: f64, + // ||factor_right - A^T||_F / ||A^T||_F, checking the right-apply route. + factor_right_dense_err: f64, + // ||operator_no_trans - factor_left||_F / ||factor_left||_F. + op_no_trans_factor_err: f64, + // ||operator_trans - factor_right||_F / ||factor_right||_F. + op_trans_factor_err: f64, + // Plain transpose-route check: ||operator_trans - operator_no_trans^T||_F. + op_trans_vs_no_trans_transpose: f64, + // rsrs-exps `adj(T)`: ||operator_trans - operator_no_trans^H||_F. + // For real matrices this is the same as the plain transpose-route check. + op_trans_vs_no_trans_adjoint: f64, + // Plain factor transpose-route check: ||factor_right - factor_left^T||_F. + factor_right_vs_left_transpose: f64, + // Symmetry collapse check: ||operator_no_trans - operator_trans||_F. + no_trans_vs_trans: f64, + // Dense-reference symmetry check: ||A - A^T||_F. + dense_no_trans_vs_trans: f64, +} + +fn run_frobenius_matrix_check( + label: &str, + matrix: &DynamicArray, + rsrs_op: &Op, + factors: &Factors, + base_no_trans: &BaseFactorOptions, +) -> FrobeniusMetrics +where + Item: RlstScalar, + Op: AsApply, Range = ArrayVectorSpace>, + Factors: RsrsFactorsImpl, +{ + let dim = matrix.shape()[0]; + let op_no_trans_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::NoTrans); + let op_trans_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::Trans); + let factor_left_matrix = assemble_factor_matrix(factors, dim, Side::Left, base_no_trans); + let factor_right_matrix = assemble_factor_matrix(factors, dim, Side::Right, base_no_trans); + let dense_no_trans_matrix = flatten_column_major(matrix); + let dense_trans_matrix = transpose_column_major(&dense_no_trans_matrix, dim); + + let metrics = FrobeniusMetrics { + reference_norm: frobenius_norm_flat(&dense_no_trans_matrix), + transpose_reference_norm: frobenius_norm_flat(&dense_trans_matrix), + op_no_trans_dense_err: rel_fro_error_flat(&op_no_trans_matrix, &dense_no_trans_matrix), + op_trans_dense_err: rel_fro_error_flat(&op_trans_matrix, &dense_trans_matrix), + factor_right_dense_err: rel_fro_error_flat(&factor_right_matrix, &dense_trans_matrix), + op_no_trans_factor_err: rel_fro_error_flat(&op_no_trans_matrix, &factor_left_matrix), + op_trans_factor_err: rel_fro_error_flat(&op_trans_matrix, &factor_right_matrix), + op_trans_vs_no_trans_transpose: rel_fro_error_flat( + &op_trans_matrix, + &transpose_column_major(&op_no_trans_matrix, dim), + ), + op_trans_vs_no_trans_adjoint: rel_fro_error_flat( + &op_trans_matrix, + &adjoint_column_major(&op_no_trans_matrix, dim), + ), + factor_right_vs_left_transpose: rel_fro_error_flat( + &factor_right_matrix, + &transpose_column_major(&factor_left_matrix, dim), + ), + no_trans_vs_trans: rel_fro_error_flat(&op_no_trans_matrix, &op_trans_matrix), + dense_no_trans_vs_trans: rel_fro_error_flat(&dense_no_trans_matrix, &dense_trans_matrix), + }; + + println!( + "{label}: frob_ref={:.3e}, frob_ref_trans={:.3e}, frob_no_trans_dense={:.3e}, frob_trans_dense={:.3e}, frob_factor_right_dense={:.3e}, frob_no_trans_factor={:.3e}, frob_trans_factor={:.3e}, frob_trans_vs_no_trans_transpose={:.3e}, frob_trans_vs_no_trans_adjoint={:.3e}, frob_factor_right_vs_left_transpose={:.3e}, frob_no_trans_vs_trans={:.3e}, frob_dense_no_trans_vs_trans={:.3e}", + metrics.reference_norm, + metrics.transpose_reference_norm, + metrics.op_no_trans_dense_err, + metrics.op_trans_dense_err, + metrics.factor_right_dense_err, + metrics.op_no_trans_factor_err, + metrics.op_trans_factor_err, + metrics.op_trans_vs_no_trans_transpose, + metrics.op_trans_vs_no_trans_adjoint, + metrics.factor_right_vs_left_transpose, + metrics.no_trans_vs_trans, + metrics.dense_no_trans_vs_trans + ); + // Frobenius diagnostics mirrored from rsrs-exps. We deliberately do not + // add the rsrs-exps l2/spectral/solve checks here; the existing single + // vector relative-l2 diagnostics remain printed separately above. + println!( + "{label}: rsrs_exps_frob rel_fro={:.3e}, rel_fro_T={:.3e}, apply_fro={:.3e}, adj(T)={:.3e}, transpose_route={:.3e}", + metrics.op_no_trans_dense_err, + metrics.op_trans_dense_err, + metrics.op_no_trans_dense_err, + metrics.op_trans_vs_no_trans_adjoint, + metrics.op_trans_vs_no_trans_transpose + ); + + metrics +} + +struct ConjugateFrobeniusMetrics { + // ||ConjNoTrans(RSRS) - conj(A)||_F / ||conj(A)||_F. + conj_no_trans_dense_err: f64, + // ||ConjTrans(RSRS) - A^H||_F / ||A^H||_F. + conj_trans_dense_err: f64, + // rsrs-exps `adj(H)`: ||ConjTrans(RSRS) - NoTrans(RSRS)^H||_F. + conj_trans_vs_no_trans_adjoint: f64, + // Exact dense identity error for direct ConjTrans vs conj(Trans(conj(x))). + exact_manual_conjtrans: f64, + // RSRS identity error for direct ConjTrans vs conj(Trans(conj(x))). + rsrs_manual_conjtrans: f64, +} + +fn run_conjugate_frobenius_matrix_check( + label: &str, + matrix: &DynamicArray, + rsrs_op: &Op, +) -> ConjugateFrobeniusMetrics +where + Item: RlstScalar, + Op: AsApply, Range = ArrayVectorSpace>, +{ + let dim = matrix.shape()[0]; + let dense_no_trans_matrix = flatten_column_major(matrix); + let dense_trans_matrix = transpose_column_major(&dense_no_trans_matrix, dim); + let dense_conj_no_trans_matrix = conjugate_column_major(&dense_no_trans_matrix); + let dense_conj_trans_matrix = adjoint_column_major(&dense_no_trans_matrix, dim); + let op_no_trans_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::NoTrans); + let op_trans_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::Trans); + let op_conj_no_trans_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::ConjNoTrans); + let op_conj_trans_matrix = assemble_operator_matrix(rsrs_op, dim, TransMode::ConjTrans); + + let metrics = ConjugateFrobeniusMetrics { + conj_no_trans_dense_err: rel_fro_error_flat( + &op_conj_no_trans_matrix, + &dense_conj_no_trans_matrix, + ), + conj_trans_dense_err: rel_fro_error_flat(&op_conj_trans_matrix, &dense_conj_trans_matrix), + conj_trans_vs_no_trans_adjoint: rel_fro_error_flat( + &op_conj_trans_matrix, + &adjoint_column_major(&op_no_trans_matrix, dim), + ), + exact_manual_conjtrans: rel_fro_error_flat( + &dense_conj_trans_matrix, + &conjugate_column_major(&dense_trans_matrix), + ), + rsrs_manual_conjtrans: rel_fro_error_flat( + &op_conj_trans_matrix, + &conjugate_column_major(&op_trans_matrix), + ), + }; + + println!( + "{label}: frob_conj_no_trans_dense={:.3e}, frob_conj_trans_dense={:.3e}, frob_conj_trans_vs_no_trans_adjoint={:.3e}, frob_exact_manual_H={:.3e}, frob_rsrs_manual_H={:.3e}", + metrics.conj_no_trans_dense_err, + metrics.conj_trans_dense_err, + metrics.conj_trans_vs_no_trans_adjoint, + metrics.exact_manual_conjtrans, + metrics.rsrs_manual_conjtrans + ); + println!( + "{label}: rsrs_exps_complex_frob adj(H)={:.3e}, exact_manual_H={:.3e}, rsrs_manual_H={:.3e}", + metrics.conj_trans_vs_no_trans_adjoint, + metrics.exact_manual_conjtrans, + metrics.rsrs_manual_conjtrans + ); + + metrics +} + +/// Convenience helper for the explicit conjugation identities used in the +/// Hermitian test. +fn conjugated(values: &[Item]) -> Vec { + values.iter().map(|value| value.conj()).collect() +} + +/// Runs the common real-valued regression checks. +/// +/// Each case compares three views of the same operation: +/// - the public operator wrapper, +/// - direct factor matvec application, +/// - dense multiplication against the original matrix. +fn run_real_case( + points: &[Point], + comm: &SimpleCommunicator, + matrix: DynamicArray, + symmetry: Symmetry, + label: &str, + transpose_reference_is_adjoint: bool, + dense_tolerance: f64, +) { + let n = matrix.shape()[0]; + let tree = Octree::new(points, TEST_MAX_LEVEL, TEST_MAX_LEAF_POINTS, comm); + let args = RsrsArgs::new( + 8, + 16, + 0, + 120, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::Lu(1e-10), + PivotMethod::Lu(0.0), + 1e-10, + TEST_FIXED_RANK, + 1e-10, + 1e-10, + 4, + 1, + symmetry, + RankPicking::Min, + FactType::Joint, + false, + 1, + false, + true, + ); + let options = RsrsOptions::::new(Some(args)); + let operator = Operator::from(&matrix); + let mut rsrs = Rsrs::new(&tree, options, operator.domain().dimension()); + let rsrs_op = rsrs.get_rsrs_operator(operator); + let factors = rsrs_op.get_factors(); + let x = deterministic_real_vector(n); + let base_no_trans = BaseFactorOptions { + inv: false, + trans: TransMode::NoTrans, + trans_target: false, + }; + let base_trans = BaseFactorOptions { + inv: false, + trans: TransMode::Trans, + trans_target: false, + }; + + let op_left = apply_operator(&rsrs_op, &x, TransMode::NoTrans); + let op_trans = apply_operator(&rsrs_op, &x, TransMode::Trans); + let mut factor_left = vec![0.0; n]; + let mut factor_left_trans = vec![0.0; n]; + let mut factor_right = vec![0.0; n]; + let mut factor_right_trans = vec![0.0; n]; + factors.matvec(&x, &mut factor_left, Side::Left, &base_no_trans); + factors.matvec(&x, &mut factor_left_trans, Side::Left, &base_trans); + factors.matvec(&x, &mut factor_right, Side::Right, &base_no_trans); + factors.matvec(&x, &mut factor_right_trans, Side::Right, &base_trans); + + let dense_left = dense_apply( + &matrix, + &x, + Side::Left, + TransMode::NoTrans, + transpose_reference_is_adjoint, + ); + let dense_left_trans = dense_apply( + &matrix, + &x, + Side::Left, + TransMode::Trans, + transpose_reference_is_adjoint, + ); + let dense_right = dense_apply( + &matrix, + &x, + Side::Right, + TransMode::NoTrans, + transpose_reference_is_adjoint, + ); + let dense_right_trans = dense_apply( + &matrix, + &x, + Side::Right, + TransMode::Trans, + transpose_reference_is_adjoint, + ); + + // `apply_vec_mode` implements transpose-vector products through a right + // application identity, so the operator `Trans` result is expected to match + // the factor-level right-apply path. + let op_vs_factor_left = rel_l2_error(&op_left, &factor_left); + let op_vs_factor_trans = rel_l2_error(&op_trans, &factor_right); + let left_dense_err = rel_l2_error(&op_left, &dense_left); + let trans_dense_err = rel_l2_error(&op_trans, &dense_left_trans); + let left_trans_dense_err = rel_l2_error(&factor_left_trans, &dense_left_trans); + let right_dense_err = rel_l2_error(&factor_right, &dense_right); + let right_trans_dense_err = rel_l2_error(&factor_right_trans, &dense_right_trans); + let no_trans_vs_trans = rel_l2_error(&op_left, &op_trans); + + println!( + "{label}: op_vs_factor_left={op_vs_factor_left:.3e}, op_vs_factor_trans_identity={op_vs_factor_trans:.3e}, left_dense={left_dense_err:.3e}, trans_dense={trans_dense_err:.3e}, left_trans_dense={left_trans_dense_err:.3e}, right_dense={right_dense_err:.3e}, right_trans_dense={right_trans_dense_err:.3e}, no_trans_vs_trans={no_trans_vs_trans:.3e}" + ); + + let frob_metrics = + run_frobenius_matrix_check(label, &matrix, &rsrs_op, factors, &base_no_trans); + + assert!( + op_vs_factor_left <= 1.0e-12, + "{label}: operator NoTrans diverges from factor matvec (rel l2 = {op_vs_factor_left})" + ); + assert!( + op_vs_factor_trans <= 1.0e-12, + "{label}: operator Trans diverges from factor matvec (rel l2 = {op_vs_factor_trans})" + ); + assert!( + left_dense_err <= dense_tolerance, + "{label}: NoTrans RSRS-vs-dense error too large ({left_dense_err} > {dense_tolerance})" + ); + assert!( + trans_dense_err <= dense_tolerance, + "{label}: Trans RSRS-vs-dense error too large ({trans_dense_err} > {dense_tolerance})" + ); + assert!( + left_trans_dense_err <= dense_tolerance, + "{label}: left Trans factor matvec error too large ({left_trans_dense_err} > {dense_tolerance})" + ); + assert!( + right_dense_err <= dense_tolerance, + "{label}: right NoTrans factor matvec error too large ({right_dense_err} > {dense_tolerance})" + ); + assert!( + right_trans_dense_err <= dense_tolerance, + "{label}: right Trans factor matvec error too large ({right_trans_dense_err} > {dense_tolerance})" + ); + assert!( + frob_metrics.op_no_trans_factor_err <= 1.0e-12, + "{label}: Frobenius operator NoTrans diverges from factor matrix ({})", + frob_metrics.op_no_trans_factor_err + ); + assert!( + frob_metrics.op_trans_factor_err <= 1.0e-12, + "{label}: Frobenius operator Trans diverges from factor right-apply matrix ({})", + frob_metrics.op_trans_factor_err + ); + assert!( + frob_metrics.op_no_trans_dense_err <= dense_tolerance, + "{label}: Frobenius NoTrans RSRS-vs-dense error too large ({} > {dense_tolerance})", + frob_metrics.op_no_trans_dense_err + ); + assert!( + frob_metrics.op_trans_dense_err <= dense_tolerance, + "{label}: Frobenius Trans RSRS-vs-dense error too large ({} > {dense_tolerance})", + frob_metrics.op_trans_dense_err + ); + assert!( + frob_metrics.op_trans_vs_no_trans_transpose <= 1.0e-11, + "{label}: Frobenius operator Trans is not the transpose of NoTrans ({})", + frob_metrics.op_trans_vs_no_trans_transpose + ); + assert!( + frob_metrics.factor_right_vs_left_transpose <= 1.0e-11, + "{label}: Frobenius factor right-apply is not the transpose of left-apply ({})", + frob_metrics.factor_right_vs_left_transpose + ); + if matches!(label, "symmetric-real") { + // For a symmetric real matrix, `A x` and `A^T x` should agree up to the + // RSRS approximation error. + assert!( + no_trans_vs_trans <= 2.5e-3, + "{label}: NoTrans and Trans should match for symmetric matrices ({no_trans_vs_trans})" + ); + assert!( + frob_metrics.no_trans_vs_trans <= 2.5e-3, + "{label}: Frobenius NoTrans and Trans should match for symmetric matrices ({})", + frob_metrics.no_trans_vs_trans + ); + } else if matches!(label, "nonsymmetric-real") { + assert!( + frob_metrics.dense_no_trans_vs_trans >= 1.0e-4, + "{label}: dense Frobenius NoTrans and Trans references are too similar ({})", + frob_metrics.dense_no_trans_vs_trans + ); + assert!( + frob_metrics.no_trans_vs_trans >= 1.0e-4, + "{label}: RSRS Frobenius NoTrans and Trans collapsed unexpectedly ({})", + frob_metrics.no_trans_vs_trans + ); + } +} + +/// Runs the complex Hermitian regression checks. +/// +/// Besides the usual dense-vs-RSRS comparisons, this case verifies the +/// conjugation identities that the operator wrapper uses to implement +/// `ConjNoTrans` and `ConjTrans`. +#[allow(dead_code)] +fn run_complex_hermitian_case(points: &[Point], comm: &SimpleCommunicator) { + // This case stresses the operator wrapper rather than just approximation + // quality: Hermitian structure means `ConjTrans` should line up with the + // ordinary `NoTrans` action, while direct factor-level transpose diagnostics + // still go through the lower-level transposed apply path. + let n = points.len(); + let matrix = hermitian_complex_matrix(points); + let tree = Octree::new(points, TEST_MAX_LEVEL, TEST_MAX_LEAF_POINTS, comm); + let args = RsrsArgs::new( + 8, + 16, + 0, + 120, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::Lu(1e-10), + PivotMethod::Lu(0.0), + 1e-10, + TEST_FIXED_RANK, + 1e-10, + 1e-10, + 4, + 1, + Symmetry::Hermitian, + RankPicking::Min, + FactType::Joint, + false, + 1, + false, + true, + ); + let options = RsrsOptions::>::new(Some(args)); + let operator = Operator::from(&matrix); + let mut rsrs = Rsrs::new(&tree, options, operator.domain().dimension()); + let rsrs_op = rsrs.get_rsrs_operator(operator); + let factors = rsrs_op.get_factors(); + let x = deterministic_complex_vector(n); + let base_no_trans = BaseFactorOptions { + inv: false, + trans: TransMode::NoTrans, + trans_target: false, + }; + let base_trans = BaseFactorOptions { + inv: false, + trans: TransMode::Trans, + trans_target: false, + }; + + let op_left = apply_operator(&rsrs_op, &x, TransMode::NoTrans); + let op_trans = apply_operator(&rsrs_op, &x, TransMode::Trans); + let op_conj_no_trans = apply_operator(&rsrs_op, &x, TransMode::ConjNoTrans); + let op_conj_trans = apply_operator(&rsrs_op, &x, TransMode::ConjTrans); + let mut factor_left = vec![Complex::new(0.0, 0.0); n]; + let mut factor_left_trans = vec![Complex::new(0.0, 0.0); n]; + let mut factor_right = vec![Complex::new(0.0, 0.0); n]; + let mut factor_right_trans = vec![Complex::new(0.0, 0.0); n]; + factors.matvec(&x, &mut factor_left, Side::Left, &base_no_trans); + factors.matvec(&x, &mut factor_left_trans, Side::Left, &base_trans); + factors.matvec(&x, &mut factor_right, Side::Right, &base_no_trans); + factors.matvec(&x, &mut factor_right_trans, Side::Right, &base_trans); + + let x_conj = conjugated(&x); + let dense_left = dense_apply(&matrix, &x, Side::Left, TransMode::NoTrans, false); + let dense_left_trans = dense_apply(&matrix, &x, Side::Left, TransMode::Trans, false); + let dense_right = dense_apply(&matrix, &x, Side::Right, TransMode::NoTrans, false); + let dense_right_trans = dense_apply(&matrix, &x, Side::Right, TransMode::Trans, false); + let dense_left_conj_no_trans = conjugated(&dense_apply( + &matrix, + &x_conj, + Side::Left, + TransMode::NoTrans, + false, + )); + let dense_left_conj_trans = conjugated(&dense_apply( + &matrix, + &x_conj, + Side::Left, + TransMode::Trans, + false, + )); + + // As in the real cases, transpose-vector products are normalized through + // right-application before they reach the factor kernels. + let op_vs_factor_left = rel_l2_error(&op_left, &factor_left); + let op_vs_factor_trans = rel_l2_error(&op_trans, &factor_right); + let conj_no_trans_dense_err = rel_l2_error(&op_conj_no_trans, &dense_left_conj_no_trans); + let conj_trans_dense_err = rel_l2_error(&op_conj_trans, &dense_left_conj_trans); + let left_dense_err = rel_l2_error(&op_left, &dense_left); + let trans_dense_err = rel_l2_error(&op_trans, &dense_left_trans); + let left_trans_dense_err = rel_l2_error(&factor_left_trans, &dense_left_trans); + let right_dense_err = rel_l2_error(&factor_right, &dense_right); + let right_trans_dense_err = rel_l2_error(&factor_right_trans, &dense_right_trans); + let no_trans_vs_trans = rel_l2_error(&op_left, &op_trans); + let conj_trans_vs_no_trans = rel_l2_error(&op_conj_trans, &op_left); + + println!( + "hermitian-complex: op_vs_factor_left={op_vs_factor_left:.3e}, op_vs_factor_trans_identity={op_vs_factor_trans:.3e}, left_dense={left_dense_err:.3e}, trans_dense={trans_dense_err:.3e}, conj_no_trans={conj_no_trans_dense_err:.3e}, conj_trans={conj_trans_dense_err:.3e}, left_trans_dense={left_trans_dense_err:.3e}, right_dense={right_dense_err:.3e}, right_trans_dense={right_trans_dense_err:.3e}, no_trans_vs_trans={no_trans_vs_trans:.3e}, conj_trans_vs_no_trans={conj_trans_vs_no_trans:.3e}" + ); + + let frob_metrics = run_frobenius_matrix_check( + "hermitian-complex", + &matrix, + &rsrs_op, + factors, + &base_no_trans, + ); + let conj_frob_metrics = + run_conjugate_frobenius_matrix_check("hermitian-complex", &matrix, &rsrs_op); + + assert!( + op_vs_factor_left <= 1.0e-11, + "hermitian-complex: operator NoTrans diverges from factor matvec (rel l2 = {op_vs_factor_left})" + ); + assert!( + op_vs_factor_trans <= 1.0e-11, + "hermitian-complex: operator Trans diverges from factor matvec (rel l2 = {op_vs_factor_trans})" + ); + assert!( + left_dense_err <= 1.0e-2, + "hermitian-complex: NoTrans RSRS-vs-dense error too large ({left_dense_err})" + ); + assert!( + trans_dense_err <= 1.0e-2, + "hermitian-complex: Trans RSRS-vs-dense error too large ({trans_dense_err})" + ); + assert!( + right_dense_err <= 1.0e-2, + "hermitian-complex: right NoTrans factor matvec error too large ({right_dense_err})" + ); + assert!( + conj_no_trans_dense_err <= 1.0e-2, + "hermitian-complex: ConjNoTrans RSRS-vs-dense error too large ({conj_no_trans_dense_err})" + ); + assert!( + conj_trans_dense_err <= 1.0e-2, + "hermitian-complex: ConjTrans RSRS-vs-dense error too large ({conj_trans_dense_err})" + ); + assert!( + conj_trans_vs_no_trans <= 1.0e-2, + "hermitian-complex: ConjTrans should match NoTrans for Hermitian matrices ({conj_trans_vs_no_trans})" + ); + // The operator-level transpose path intentionally reuses right-application. + // Depending on the low-level storage view, either direct factor-level + // transpose diagnostic can be the sharper signal, so we require at least + // one of them to stay accurate. + assert!( + left_trans_dense_err <= 2.0e-2 || right_trans_dense_err <= 2.0e-2, + "hermitian-complex: both factor-level Trans diagnostics are unexpectedly poor (left={left_trans_dense_err}, right={right_trans_dense_err})" + ); + assert!( + frob_metrics.op_no_trans_factor_err <= 1.0e-11, + "hermitian-complex: Frobenius operator NoTrans diverges from factor matrix ({})", + frob_metrics.op_no_trans_factor_err + ); + assert!( + frob_metrics.op_trans_factor_err <= 1.0e-11, + "hermitian-complex: Frobenius operator Trans diverges from factor right-apply matrix ({})", + frob_metrics.op_trans_factor_err + ); + assert!( + frob_metrics.op_no_trans_dense_err <= 1.0e-2, + "hermitian-complex: Frobenius NoTrans RSRS-vs-dense error too large ({})", + frob_metrics.op_no_trans_dense_err + ); + assert!( + frob_metrics.op_trans_dense_err <= 1.0e-2, + "hermitian-complex: Frobenius Trans RSRS-vs-dense error too large ({})", + frob_metrics.op_trans_dense_err + ); + assert!( + conj_frob_metrics.conj_no_trans_dense_err <= 1.0e-2, + "hermitian-complex: Frobenius ConjNoTrans RSRS-vs-dense error too large ({})", + conj_frob_metrics.conj_no_trans_dense_err + ); + assert!( + conj_frob_metrics.conj_trans_dense_err <= 1.0e-2, + "hermitian-complex: Frobenius ConjTrans RSRS-vs-dense error too large ({})", + conj_frob_metrics.conj_trans_dense_err + ); +} + +struct ComplexSymmetricMetrics { + // Public operator `NoTrans` versus direct left factor apply. + op_vs_factor_left: f64, + // Public operator `Trans` versus direct right factor apply. + op_vs_factor_trans: f64, + // Public operator `NoTrans` versus dense `A x`. + left_dense_err: f64, + // Public operator `Trans` versus dense `A^T x`. + trans_dense_err: f64, + // Public operator `A x` versus public operator `A^T x`. + no_trans_vs_trans: f64, +} + +/// Runs a complex symmetric case and reports the main operator-level +/// diagnostics. +fn run_complex_symmetric_mode( + points: &[Point], + comm: &SimpleCommunicator, + matrix: &DynamicArray, 2>, + symmetry: Symmetry, + label: &str, +) -> ComplexSymmetricMetrics { + let n = matrix.shape()[0]; + let tree = Octree::new(points, TEST_MAX_LEVEL, TEST_MAX_LEAF_POINTS, comm); + let args = RsrsArgs::new( + 8, + 16, + 0, + 120, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::Lu(1e-10), + PivotMethod::Lu(0.0), + 1e-10, + TEST_FIXED_RANK, + 1e-10, + 1e-10, + 4, + 1, + symmetry, + RankPicking::Min, + FactType::Joint, + false, + 1, + false, + true, + ); + let options = RsrsOptions::>::new(Some(args)); + let operator = Operator::from(matrix); + let mut rsrs = Rsrs::new(&tree, options, operator.domain().dimension()); + let rsrs_op = rsrs.get_rsrs_operator(operator); + let factors = rsrs_op.get_factors(); + let x = deterministic_complex_vector(n); + let base_no_trans = BaseFactorOptions { + inv: false, + trans: TransMode::NoTrans, + trans_target: false, + }; + + let op_left = apply_operator(&rsrs_op, &x, TransMode::NoTrans); + let op_trans = apply_operator(&rsrs_op, &x, TransMode::Trans); + let mut factor_left = vec![Complex::new(0.0, 0.0); n]; + let mut factor_right = vec![Complex::new(0.0, 0.0); n]; + factors.matvec(&x, &mut factor_left, Side::Left, &base_no_trans); + factors.matvec(&x, &mut factor_right, Side::Right, &base_no_trans); + + let dense_left = dense_apply(matrix, &x, Side::Left, TransMode::NoTrans, false); + let dense_left_trans = dense_apply(matrix, &x, Side::Left, TransMode::Trans, false); + + let metrics = ComplexSymmetricMetrics { + op_vs_factor_left: rel_l2_error(&op_left, &factor_left), + op_vs_factor_trans: rel_l2_error(&op_trans, &factor_right), + left_dense_err: rel_l2_error(&op_left, &dense_left), + trans_dense_err: rel_l2_error(&op_trans, &dense_left_trans), + no_trans_vs_trans: rel_l2_error(&op_left, &op_trans), + }; + + println!( + "{label}: op_vs_factor_left={:.3e}, op_vs_factor_trans_identity={:.3e}, left_dense={:.3e}, trans_dense={:.3e}, no_trans_vs_trans={:.3e}", + metrics.op_vs_factor_left, + metrics.op_vs_factor_trans, + metrics.left_dense_err, + metrics.trans_dense_err, + metrics.no_trans_vs_trans + ); + + let frob_metrics = run_frobenius_matrix_check(label, matrix, &rsrs_op, factors, &base_no_trans); + let conj_frob_metrics = run_conjugate_frobenius_matrix_check(label, matrix, &rsrs_op); + assert!( + frob_metrics.op_no_trans_factor_err <= 1.0e-11, + "{label}: Frobenius operator NoTrans diverges from factor matrix ({})", + frob_metrics.op_no_trans_factor_err + ); + assert!( + frob_metrics.op_trans_factor_err <= 1.0e-11, + "{label}: Frobenius operator Trans diverges from factor right-apply matrix ({})", + frob_metrics.op_trans_factor_err + ); + assert!( + frob_metrics.op_no_trans_dense_err <= 2.0e-2, + "{label}: Frobenius NoTrans RSRS-vs-dense error too large ({})", + frob_metrics.op_no_trans_dense_err + ); + assert!( + frob_metrics.op_trans_dense_err <= 2.0e-2, + "{label}: Frobenius Trans RSRS-vs-dense error too large ({})", + frob_metrics.op_trans_dense_err + ); + assert!( + conj_frob_metrics.conj_no_trans_dense_err <= 2.0e-2, + "{label}: Frobenius ConjNoTrans RSRS-vs-dense error too large ({})", + conj_frob_metrics.conj_no_trans_dense_err + ); + assert!( + conj_frob_metrics.conj_trans_dense_err <= 2.0e-2, + "{label}: Frobenius ConjTrans RSRS-vs-dense error too large ({})", + conj_frob_metrics.conj_trans_dense_err + ); + + metrics +} + +/// Runs the complex symmetric fixture through the complex-symmetric path. +fn run_complex_symmetric_case( + points: &[Point], + comm: &SimpleCommunicator, + matrix: &DynamicArray, 2>, +) { + let symmetric_metrics = run_complex_symmetric_mode( + points, + comm, + matrix, + Symmetry::Symmetric, + "symmetric-complex", + ); + + assert!( + symmetric_metrics.op_vs_factor_left <= 1.0e-11, + "symmetric-complex: operator NoTrans diverges from factor matvec (rel l2 = {})", + symmetric_metrics.op_vs_factor_left + ); + assert!( + symmetric_metrics.op_vs_factor_trans <= 1.0e-11, + "symmetric-complex: operator Trans diverges from factor matvec (rel l2 = {})", + symmetric_metrics.op_vs_factor_trans + ); + assert!( + symmetric_metrics.left_dense_err <= 2.0e-2, + "symmetric-complex: NoTrans RSRS-vs-dense error too large ({})", + symmetric_metrics.left_dense_err + ); + assert!( + symmetric_metrics.trans_dense_err <= 2.0e-2, + "symmetric-complex: Trans RSRS-vs-dense error too large ({})", + symmetric_metrics.trans_dense_err + ); + assert!( + symmetric_metrics.no_trans_vs_trans <= 2.0e-2, + "symmetric-complex: NoTrans and Trans should match for complex symmetric matrices ({})", + symmetric_metrics.no_trans_vs_trans + ); +} + +/// Runs a genuinely complex nonsymmetric case through the `NoSymm` path. +fn run_complex_nonsymmetric_case( + points: &[Point], + comm: &SimpleCommunicator, + matrix: &DynamicArray, 2>, +) { + let n = matrix.shape()[0]; + let tree = Octree::new(points, TEST_MAX_LEVEL, TEST_MAX_LEAF_POINTS, comm); + let args = RsrsArgs::new( + 8, + 16, + 0, + 120, + Shift::False, + NullMethod::Projection, + RankRevealingQrType::RRQR, + BlockExtractionMethod::LuLstSq, + BlockExtractionMethod::LuLstSq, + PivotMethod::Lu(1e-10), + PivotMethod::Lu(0.0), + 1e-10, + TEST_FIXED_RANK, + 1e-10, + 1e-10, + 4, + 1, + Symmetry::NoSymm, + RankPicking::Min, + FactType::Joint, + false, + 1, + false, + true, + ); + let options = RsrsOptions::>::new(Some(args)); + let operator = Operator::from(matrix); + let mut rsrs = Rsrs::new(&tree, options, operator.domain().dimension()); + let mut rsrs_op = rsrs.get_rsrs_operator(operator); + let factors = rsrs_op.get_factors(); + let x = deterministic_complex_vector(n); + let base_no_trans = BaseFactorOptions { + inv: false, + trans: TransMode::NoTrans, + trans_target: false, + }; + let base_trans = BaseFactorOptions { + inv: false, + trans: TransMode::Trans, + trans_target: false, + }; + + let op_left = apply_operator(&rsrs_op, &x, TransMode::NoTrans); + let op_trans = apply_operator(&rsrs_op, &x, TransMode::Trans); + let op_conj_no_trans = apply_operator(&rsrs_op, &x, TransMode::ConjNoTrans); + let op_conj_trans = apply_operator(&rsrs_op, &x, TransMode::ConjTrans); + let mut factor_left = vec![Complex::new(0.0, 0.0); n]; + let mut factor_left_trans = vec![Complex::new(0.0, 0.0); n]; + let mut factor_right = vec![Complex::new(0.0, 0.0); n]; + let mut factor_right_trans = vec![Complex::new(0.0, 0.0); n]; + factors.matvec(&x, &mut factor_left, Side::Left, &base_no_trans); + factors.matvec(&x, &mut factor_left_trans, Side::Left, &base_trans); + factors.matvec(&x, &mut factor_right, Side::Right, &base_no_trans); + factors.matvec(&x, &mut factor_right_trans, Side::Right, &base_trans); + + let x_conj = conjugated(&x); + let dense_left = dense_apply(matrix, &x, Side::Left, TransMode::NoTrans, false); + let dense_left_trans = dense_apply(matrix, &x, Side::Left, TransMode::Trans, false); + let dense_right = dense_apply(matrix, &x, Side::Right, TransMode::NoTrans, false); + let dense_right_trans = dense_apply(matrix, &x, Side::Right, TransMode::Trans, false); + let dense_left_conj_no_trans = conjugated(&dense_apply( + matrix, + &x_conj, + Side::Left, + TransMode::NoTrans, + false, + )); + let dense_left_conj_trans = conjugated(&dense_apply( + matrix, + &x_conj, + Side::Left, + TransMode::Trans, + false, + )); + + let op_vs_factor_left = rel_l2_error(&op_left, &factor_left); + let op_vs_factor_trans = rel_l2_error(&op_trans, &factor_right); + let left_dense_err = rel_l2_error(&op_left, &dense_left); + let trans_dense_err = rel_l2_error(&op_trans, &dense_left_trans); + let left_trans_dense_err = rel_l2_error(&factor_left_trans, &dense_left_trans); + let right_dense_err = rel_l2_error(&factor_right, &dense_right); + let right_trans_dense_err = rel_l2_error(&factor_right_trans, &dense_right_trans); + let conj_no_trans_dense_err = rel_l2_error(&op_conj_no_trans, &dense_left_conj_no_trans); + let conj_trans_dense_err = rel_l2_error(&op_conj_trans, &dense_left_conj_trans); + let no_trans_vs_trans = rel_l2_error(&op_left, &op_trans); + let dense_no_trans_vs_trans = rel_l2_error(&dense_left, &dense_left_trans); + + println!( + "nonsymmetric-complex-general: op_vs_factor_left={op_vs_factor_left:.3e}, op_vs_factor_trans_identity={op_vs_factor_trans:.3e}, left_dense={left_dense_err:.3e}, trans_dense={trans_dense_err:.3e}, left_trans_dense={left_trans_dense_err:.3e}, right_dense={right_dense_err:.3e}, right_trans_dense={right_trans_dense_err:.3e}, conj_no_trans={conj_no_trans_dense_err:.3e}, conj_trans={conj_trans_dense_err:.3e}, no_trans_vs_trans={no_trans_vs_trans:.3e}, dense_no_trans_vs_trans={dense_no_trans_vs_trans:.3e}" + ); + + let frob_metrics = run_frobenius_matrix_check( + "nonsymmetric-complex-general", + matrix, + &rsrs_op, + factors, + &base_no_trans, + ); + let conj_frob_metrics = + run_conjugate_frobenius_matrix_check("nonsymmetric-complex-general", matrix, &rsrs_op); + if run_dense_diagnostics() { + run_dense_box_errors_left_application("nonsymmetric-complex-general", matrix, factors); + run_dense_box_errors_right_application("nonsymmetric-complex-general", matrix, factors); + run_dense_box_errors_sandwich_application("nonsymmetric-complex-general", matrix, factors); + run_inverse_dense_diagnostic("nonsymmetric-complex-general", matrix, &mut rsrs_op); + } + + assert!( + op_vs_factor_left <= 1.0e-11, + "nonsymmetric-complex-general: operator NoTrans diverges from factor matvec (rel l2 = {op_vs_factor_left})" + ); + assert!( + op_vs_factor_trans <= 1.0e-11, + "nonsymmetric-complex-general: operator Trans diverges from factor matvec (rel l2 = {op_vs_factor_trans})" + ); + assert!( + left_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: NoTrans RSRS-vs-dense error too large ({left_dense_err})" + ); + assert!( + trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: Trans RSRS-vs-dense error too large ({trans_dense_err})" + ); + assert!( + left_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: left Trans factor matvec error too large ({left_trans_dense_err})" + ); + assert!( + right_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: right NoTrans factor matvec error too large ({right_dense_err})" + ); + assert!( + right_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: right Trans factor matvec error too large ({right_trans_dense_err})" + ); + assert!( + conj_no_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: ConjNoTrans RSRS-vs-dense error too large ({conj_no_trans_dense_err})" + ); + assert!( + conj_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: ConjTrans RSRS-vs-dense error too large ({conj_trans_dense_err})" + ); + assert!( + dense_no_trans_vs_trans >= 1.0e-4, + "nonsymmetric-complex-general: dense NoTrans and Trans references are too similar ({dense_no_trans_vs_trans})" + ); + assert!( + no_trans_vs_trans >= 1.0e-4, + "nonsymmetric-complex-general: RSRS NoTrans and Trans collapsed unexpectedly ({no_trans_vs_trans})" + ); + assert!( + frob_metrics.op_no_trans_factor_err <= 1.0e-11, + "nonsymmetric-complex-general: Frobenius operator NoTrans diverges from factor matrix ({})", + frob_metrics.op_no_trans_factor_err + ); + assert!( + frob_metrics.op_trans_factor_err <= 1.0e-11, + "nonsymmetric-complex-general: Frobenius operator Trans diverges from factor right-apply matrix ({})", + frob_metrics.op_trans_factor_err + ); + assert!( + frob_metrics.op_no_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: Frobenius NoTrans RSRS-vs-dense error too large ({})", + frob_metrics.op_no_trans_dense_err + ); + assert!( + frob_metrics.op_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: Frobenius Trans RSRS-vs-dense error too large ({})", + frob_metrics.op_trans_dense_err + ); + assert!( + conj_frob_metrics.conj_no_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: Frobenius ConjNoTrans RSRS-vs-dense error too large ({})", + conj_frob_metrics.conj_no_trans_dense_err + ); + assert!( + conj_frob_metrics.conj_trans_dense_err <= 2.0e-2, + "nonsymmetric-complex-general: Frobenius ConjTrans RSRS-vs-dense error too large ({})", + conj_frob_metrics.conj_trans_dense_err + ); + assert!( + frob_metrics.dense_no_trans_vs_trans >= 1.0e-4, + "nonsymmetric-complex-general: dense Frobenius NoTrans and Trans references are too similar ({})", + frob_metrics.dense_no_trans_vs_trans + ); + assert!( + frob_metrics.no_trans_vs_trans >= 1.0e-4, + "nonsymmetric-complex-general: RSRS Frobenius NoTrans and Trans collapsed unexpectedly ({})", + frob_metrics.no_trans_vs_trans + ); +} + +#[test] +fn rsrs_operator_matvec_diagnostic() { + std::thread::Builder::new() + .name("rsrs_operator_matvec_diagnostic_worker".into()) + .stack_size(64 * 1024 * 1024) + .spawn(rsrs_operator_matvec_diagnostic_worker) + .unwrap() + .join() + .unwrap(); +} + +fn rsrs_operator_matvec_diagnostic_worker() { + std::env::set_var("OPENBLAS_NUM_THREADS", "1"); + + let universe = mpi::initialize().unwrap(); + let comm: SimpleCommunicator = universe.world(); + let fixture = load_biegrid_perturbed_fixture(); + let points = &fixture.points; + println!( + "loaded perturbed BIEGrid fixture: n={}, cells_per_axis={}, perturbation_scale={:.3e}, perturbation_seed={}", + points.len(), + fixture.cells_per_axis, + fixture.perturbation_scale, + fixture.perturbation_seed + ); + + // These cases together cover the main routing logic: + // - symmetric real: transpose should collapse to the same action, + // - nonsymmetric real: transpose must remain distinct, + // - symmetric complex: transpose should still collapse without Hermitian + // conjugation, + // - nonsymmetric complex: transpose must remain distinct. + run_real_case( + points, + &comm, + fixture.real_symmetric, + Symmetry::Symmetric, + "symmetric-real", + false, + 1.0e-2, + ); + run_real_case( + points, + &comm, + fixture.real_nonsymmetric, + Symmetry::NoSymm, + "nonsymmetric-real", + false, + 1.5e-2, + ); + run_complex_symmetric_case(points, &comm, &fixture.complex_symmetric); + run_complex_nonsymmetric_case(points, &comm, &fixture.complex_nonsymmetric); + // The analytic Hermitian diagnostic is kept available, but this fixture + // set intentionally covers the four perturbed BIEGrid cases above. + // run_complex_hermitian_case(points, &comm); +}