Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jobs:
with:
repository: ignacia-fp/rsrs-exps
path: rsrs-exps
submodules: recursive
- name: Install public rsrs-exps system dependencies
run: |
sudo apt-get update
Expand Down Expand Up @@ -116,7 +117,7 @@ jobs:
- name: Bootstrap public rsrs-exps harness
run: |
cd rsrs-exps
DEPS_DIR="$PWD/.deps" WORKSPACE="$PWD" bash scripts/setup_deps.sh true
ENABLE_EXAFMM=0 DEPS_DIR="$PWD/.deps" WORKSPACE="$PWD" bash scripts/setup_deps.sh true
- name: Run perturbed BIEGrid regression suite
run: |
cd rsrs-exps
Expand Down
76 changes: 67 additions & 9 deletions src/rsrs/rsrs_cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,42 @@ fn root_fixed_rank_samples(
.max(min_num_samples)
}

fn required_post_nullification_samples(
active_box_size: usize,
active_near_with_self_size: usize,
rank: usize,
p_param: usize,
) -> usize {
active_near_with_self_size
.saturating_add(active_box_size.min(rank))
.saturating_add(p_param)
}

fn safe_local_per_level_sample_count(
per_box_samples: Option<usize>,
active_box_size: usize,
active_near_with_self_size: usize,
active_samples: usize,
rank: usize,
p_param: usize,
) -> usize {
let local_samples = per_box_samples.unwrap_or(active_samples);
let required = required_post_nullification_samples(
active_box_size,
active_near_with_self_size,
rank,
p_param,
);
if local_samples < required {
active_samples
} else {
local_samples
}
}

fn local_sample_count<Item: RlstScalar>(
per_box_samples: Option<usize>,
active_box_size: usize,
active_near_with_self_size: usize,
active_samples: usize,
fixed_rank: bool,
Expand All @@ -162,13 +196,14 @@ fn local_sample_count<Item: RlstScalar>(
match fixed_rank_sampling_mode {
FixedRankSamplingMode::Constant => active_samples,
FixedRankSamplingMode::PerLevel => {
let local_samples = per_box_samples.unwrap_or(active_samples);
let min_safe_samples = active_near_with_self_size.saturating_add(2);
let local_samples = if local_samples < min_safe_samples {
active_samples
} else {
local_samples
};
let local_samples = safe_local_per_level_sample_count(
per_box_samples,
active_box_size,
active_near_with_self_size,
active_samples,
rank,
options.sketching.oversampling,
);
assert!(
local_samples <= active_samples,
"PerLevel local sample request ({local_samples}) exceeds active samples ({active_samples}) for rank={}, p={}",
Expand All @@ -188,7 +223,8 @@ fn assert_post_nullification_headroom<Item: RlstScalar>(
) {
let rank = num::ToPrimitive::to_usize(&options.id_options.tol_id).unwrap();
let p = options.sketching.oversampling;
let required = active_near_with_self_size + active_box_size.min(rank) + p;
let required =
required_post_nullification_samples(active_box_size, active_near_with_self_size, rank, p);
assert!(
subs_sample_dim >= required,
"Insufficient fixed-rank samples for post-nullification sketch: samples={}, required={}, active_box_size={}, active_near_with_self_size={}, rank={}, p={}",
Expand Down Expand Up @@ -1524,6 +1560,7 @@ where
});
let min_num_samples = local_sample_count(
per_box_samples,
self.ind_s[box_ind].len(),
level_near_field_inds[*box_num].len(),
self.active_samples,
self.anticipated_fixed_rank_samples.is_some(),
Expand Down Expand Up @@ -1645,6 +1682,7 @@ where
});
let min_num_samples = local_sample_count(
per_box_samples,
self.ind_s[box_ind].len(),
level_near_field_inds[*box_num].len(),
self.active_samples,
self.anticipated_fixed_rank_samples.is_some(),
Expand Down Expand Up @@ -1791,6 +1829,7 @@ where
});
let min_box_samples = local_sample_count(
per_box_samples,
self.ind_s[box_ind].len(),
near_field_inds.len(),
self.active_samples,
self.anticipated_fixed_rank_samples.is_some(),
Expand Down Expand Up @@ -1946,6 +1985,7 @@ where
});
let min_num_samples = local_sample_count(
per_box_samples,
self.ind_s[box_ind].len(),
level_near_field_inds[*box_num].len(),
self.active_samples,
self.anticipated_fixed_rank_samples.is_some(),
Expand Down Expand Up @@ -2447,7 +2487,10 @@ fn pick_ranks<Item: RlstScalar>(

#[cfg(test)]
mod tests {
use super::{fixed_rank_skeleton_upper_size, FixedRankSampleBudget};
use super::{
fixed_rank_skeleton_upper_size, required_post_nullification_samples,
safe_local_per_level_sample_count, FixedRankSampleBudget,
};
use std::collections::HashMap;

#[test]
Expand Down Expand Up @@ -2481,4 +2524,19 @@ mod tests {
assert_eq!(fixed_rank_skeleton_upper_size(3, 3, 1, 118, 20), 20);
assert_eq!(fixed_rank_skeleton_upper_size(2, 3, 1, 118, 20), 118);
}

#[test]
fn per_level_local_samples_fall_back_when_box_budget_is_too_small() {
let required = required_post_nullification_samples(21, 21, 20, 8);
assert_eq!(required, 49);

let chosen = safe_local_per_level_sample_count(Some(30), 21, 21, 1970, 20, 8);
assert_eq!(chosen, 1970);
}

#[test]
fn per_level_local_samples_keep_box_budget_when_it_is_safe() {
let chosen = safe_local_per_level_sample_count(Some(60), 21, 21, 1970, 20, 8);
assert_eq!(chosen, 60);
}
}
Loading