import numpy as np
import pytest

from wavefunction_branching.utils.branch_rounding import (
    probabilistic_round_child_budget,
    sampford_mask,
)


class TestSampfordMask:
    """Tests for sampford_mask function."""

    def test_basic_functionality(self):
        """Test that sampford_mask returns correct sample size."""
        # Use more balanced weights to avoid CertaintyError from samplics
        weights = np.array([1.0, 1.5, 2.0, 2.5, 3.0])
        sample_size = 2
        result = sampford_mask(sample_size, weights)

        assert result.sum() == sample_size
        assert result.dtype == int or result.dtype == np.int64 or result.dtype == np.int32
        assert len(result) == len(weights)
        assert np.all((result == 0) | (result == 1))

    def test_large_sample_size(self):
        """Test with large sample size relative to population."""
        # Note: samplics raises CertaintyError when sample_size == population_size
        # or when certain items become certainties.
        # Use near-uniform weights to avoid certainties.
        weights = np.array([1.0, 1.1, 1.2, 1.3, 1.4])
        sample_size = 3
        result = sampford_mask(sample_size, weights)

        assert result.sum() == sample_size
        assert len(result) == len(weights)

    def test_single_selection(self):
        """Test when sample_size equals 1."""
        weights = np.array([1.0, 2.0, 3.0, 4.0])
        sample_size = 1
        result = sampford_mask(sample_size, weights)

        assert result.sum() == sample_size
        assert np.sum(result == 1) == 1

    def test_zero_weights_error(self):
        """Test that all zero weights raise ValueError."""
        weights = np.array([0.0, 0.0, 0.0])
        sample_size = 1

        with pytest.raises(ValueError, match="At least one weight"):
            sampford_mask(sample_size, weights)

    def test_negative_weights_error(self):
        """Test that negative weights raise ValueError."""
        weights = np.array([1.0, -1.0, 2.0])
        sample_size = 1

        with pytest.raises(ValueError, match="All weights"):
            sampford_mask(sample_size, weights)

    def test_invalid_sample_size_error(self):
        """Test that invalid sample_size values raise ValueError."""
        weights = np.array([1.0, 2.0, 3.0])

        with pytest.raises(ValueError, match="sample_size must be between"):
            sampford_mask(0, weights)

        with pytest.raises(ValueError, match="sample_size must be between"):
            sampford_mask(4, weights)

    def test_proportional_to_weights(self):
        """Test that selection probabilities are roughly proportional to weights."""
        weights = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
        sample_size = 2
        n_trials = 10000

        # Count how many times each element is selected
        counts = np.zeros(len(weights))
        for _ in range(n_trials):
            result = sampford_mask(sample_size, weights)
            counts += result

        # Normalize to get empirical probabilities
        empirical_probs = counts / n_trials

        # Expected probabilities should be roughly proportional to weights.
        # Since we're sampling 2 out of 5, the inclusion probabilities are not
        # simply weights / sum(weights), but should still correlate with weights:
        # higher weights should have higher empirical probability.
        assert empirical_probs[4] > empirical_probs[0], (
            "Larger weight should have higher selection probability"
        )
        assert empirical_probs[3] > empirical_probs[1], (
            "Larger weight should have higher selection probability"
        )

    def test_uniform_weights(self):
        """Test with uniform weights."""
        weights = np.ones(5)
        sample_size = 2

        # With uniform weights, selection should be roughly uniform
        result = sampford_mask(sample_size, weights)
        assert result.sum() == sample_size


class TestProbabilisticRoundChildBudget:
    """Tests for probabilistic_round_child_budget function."""

    def test_basic_functionality(self):
        """Test basic functionality - sum equals max_children."""
        max_children = 10
        probs = np.array([0.1, 0.2, 0.3, 0.4])

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        assert isinstance(result, np.ndarray)
        assert np.all(result >= 0)
        assert len(result) == len(probs)

    def test_exact_integer_case(self):
        """Test when expected values are exact integers (no rounding needed)."""
        max_children = 10
        probs = np.array([0.2, 0.3, 0.5])  # Expected: [2.0, 3.0, 5.0]

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        np.testing.assert_array_equal(result, [2, 3, 5])

    def test_fractional_case(self):
        """Test when expected values have fractional parts."""
        max_children = 10
        probs = np.array([0.15, 0.25, 0.6])  # Expected: [1.5, 2.5, 6.0]

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        # Floor parts: [1, 2, 6] = 9, so 1 child needs to be allocated
        # Should be allocated probabilistically based on fractional parts [0.5, 0.5, 0.0]
        assert result[0] >= 1 and result[0] <= 2
        assert result[1] >= 2 and result[1] <= 3
        assert result[2] == 6

    def test_floor_deterministic(self):
        """Test that the floor of each allocation is deterministic and matches expected."""
        max_children = 17
        probs = np.array([0.15, 0.23, 0.31, 0.31])  # Expected: [2.55, 3.91, 5.27, 5.27]
        n_trials = 100

        # Calculate expected floors (the deterministic part)
        expected_children = max_children * probs
        expected_floors = np.floor(expected_children).astype(int)
        # Expected floors: [2, 3, 5, 5] = 15, so 2 children need probabilistic allocation

        # Collect all results to verify floor determinism
        all_results = []
        for _ in range(n_trials):
            result = probabilistic_round_child_budget(max_children, probs)
            all_results.append(result)
            # Each result should be at least its expected floor
            assert np.all(result >= expected_floors), (
                f"Result {result} should all be >= expected floors {expected_floors}"
            )
            # Each result should be no more than its expected floor + 1
            assert np.all(result <= expected_floors + 1), (
                f"Result {result} should all be <= expected floors + 1 {expected_floors + 1}"
            )

        all_results = np.array(all_results)

        # Verify the expected floor values explicitly
        # These are the deterministic guarantees: each result will be at least this value
        np.testing.assert_array_equal(expected_floors, np.array([2, 3, 5, 5]))

    def test_uniform_probabilities(self):
        """Test with uniform probabilities."""
        max_children = 12
        probs = np.array([1.0, 1.0, 1.0, 1.0])  # Will be normalized to [0.25, 0.25, 0.25, 0.25]

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        # 12 * 0.25 = 3.0 exactly for every branch, so there are no fractional
        # parts to split randomly: the allocation is fully deterministic.
        np.testing.assert_array_equal(result, [3, 3, 3, 3])

    def test_single_branch(self):
        """Test with single branch (trivial case)."""
        max_children = 5
        probs = np.array([1.0])

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        assert result[0] == max_children

    def test_large_imbalance(self):
        """Test with highly imbalanced probabilities."""
        max_children = 10
        probs = np.array([0.01, 0.01, 0.98])  # One branch dominates

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        # The large probability branch should get most children:
        # at least floor(10 * 0.98) = 9 (possibly 10 via the fractional split)
        assert result[2] >= 9

    def test_proportional_allocation(self):
        """Test that allocations are roughly proportional to probabilities over many trials."""
        max_children = 100
        probs = np.array([0.1, 0.2, 0.3, 0.4])
        n_trials = 10000

        # Average allocations over many trials
        avg_allocations = np.zeros(len(probs))
        for _ in range(n_trials):
            result = probabilistic_round_child_budget(max_children, probs)
            avg_allocations += result

        avg_allocations /= n_trials

        # Normalize to compare with expected probabilities
        expected = max_children * probs
        # Average should be close to expected (within reasonable tolerance)
        np.testing.assert_allclose(avg_allocations, expected, rtol=0.01)

    def test_zero_probabilities(self):
        """Test handling of zero probabilities (should be normalized out)."""
        max_children = 10
        probs = np.array([0.0, 0.3, 0.7, 0.0])

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        assert result[0] == 0
        assert result[3] == 0
        assert result[1] + result[2] == max_children

    def test_all_zero_probabilities_error(self):
        """Test that all zero probabilities raise ValueError."""
        max_children = 10
        probs = np.array([0.0, 0.0, 0.0])

        # Should raise ValueError before attempting normalization
        with pytest.raises(ValueError, match="All probabilities are zero"):
            probabilistic_round_child_budget(max_children, probs)

    def test_large_max_children(self):
        """Test with large max_children value."""
        max_children = 1000
        probs = np.array([0.15, 0.25, 0.35, 0.25])

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        # Should be roughly proportional
        expected = max_children * probs
        np.testing.assert_allclose(result, expected, rtol=0.01)

    def test_small_max_children(self):
        """Test with small max_children value."""
        max_children = 3
        probs = np.array([0.2, 0.3, 0.5])

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        # With small values, fractional parts matter more
        assert np.all(result >= 0)
        assert np.all(result <= max_children)

    def test_list_input(self):
        """Test that function accepts list input (not just numpy array)."""
        max_children = 10
        probs = [0.1, 0.2, 0.3, 0.4]

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == max_children
        assert isinstance(result, np.ndarray)

    def test_float_max_children(self):
        """Test that function accepts float max_children (should be cast to int)."""
        max_children = 10.7  # Should be cast to 10
        probs = np.array([0.2, 0.3, 0.5])

        result = probabilistic_round_child_budget(max_children, probs)

        assert result.sum() == 10  # Should use int(10.7) = 10
grandchild budget using random_round --- + allocated_budgets = probabilistic_round_child_budget(max_children, branch_probs) + + keep_mask = allocated_budgets > 0 + survivor_indices = candidate_indices[keep_mask] + final_max_children = allocated_budgets[keep_mask] print( - f"{self.ID}Starting sampling/filtering with {num_candidates} candidates and budget {self.max_children}." + f"{self.ID} Indices surviving after stage 2 filtering: {survivor_indices}" ) - # --- Stage 1 Sampling: Select based on parent's max_children budget --- - if num_candidates > self.max_children: - print( - f"{self.ID}Stage 1: Sampling {self.max_children} from {num_candidates} via np.random.choice." - ) - # Ensure probabilities are normalized for sampling robustness - probs_for_choice = branch_probs / branch_probs.sum() - stage1_kept_indices = np.random.choice( - candidate_indices, - p=probs_for_choice, - replace=False, - size=self.max_children, - ) - print(f"{self.ID} Indices kept after stage 1: {stage1_kept_indices}") - else: - # Keep all candidates if budget allows - print(f"{self.ID}Stage 1: Keeping all {num_candidates} candidates (budget sufficient).") - stage1_kept_indices = candidate_indices - - num_selected_stage1 = len(stage1_kept_indices) - - # --- Stage 2 Sampling: Allocate grandchild budget using random_round --- - final_survivor_original_indices = [] - final_max_children = [] # Grandchild budgets for the final survivors - - if num_selected_stage1 > 0: - print( - f"{self.ID}Stage 2: Allocating grandchild budget ({self.max_children}) among {num_selected_stage1} candidates." - ) - selected_branch_probs = branch_probs[stage1_kept_indices] - selected_total_prob = selected_branch_probs.sum() - allocated_budgets = np.zeros(num_selected_stage1, dtype=int) - - if np.isclose(selected_total_prob, 0.0): - print( - f"{self.ID} Warning: Total probability of stage 1 selected branches is zero." 
- ) - else: - # Rescale probabilities of selected branches to sum to 1 for budget allocation - selected_branch_probs_rescaled = selected_branch_probs / selected_total_prob - weights_for_rounding = self.max_children * selected_branch_probs_rescaled - - # Assign the number of max children to each child branch, proportional to their probs (randomly) - # TODO: MAKE THIS MORE EFFICIENT - i = 0 - while np.sum(allocated_budgets) != self.max_children and i < 10000: - allocated_budgets = np.array( - [max(0, random_round(w)) for w in weights_for_rounding] - ) - if i >= 10000: - print( - f"No good assignment of max_children found - allocated_budgets = {allocated_budgets}" - ) - break - print(f"{self.ID} Allocated grandchild budgets: {allocated_budgets}") - - # Filter based on allocated budget - stage2_keep_mask = allocated_budgets > 0 - final_survivor_original_indices = stage1_kept_indices[stage2_keep_mask] - final_max_children = allocated_budgets[stage2_keep_mask] - print( - f"{self.ID} Indices surviving after stage 2 filtering: {final_survivor_original_indices}" - ) - else: - print(f"{self.ID}Stage 2: Skipped (no branches survived stage 1).") - # Ensure lists are empty - final_survivor_original_indices = [] - final_max_children = [] - # --- Post-Sampling Processing --- - num_kept_branches = len(final_survivor_original_indices) + num_kept_branches = len(survivor_indices) print(f"{self.ID}Total branches surviving all filtering: {num_kept_branches}") # Check if any sampling/filtering actually occurred compared to the initial set @@ -627,10 +565,10 @@ def branch_and_sample( # Retrieve the tensors for the surviving branches # Assume optional truncation happened earlier, resulting in `thetas_truncated` # If no truncation, use theta_purified directly. Let's assume we use theta_purified for now. 
import numpy as np


def sampford_mask(sample_size, weights):
    """
    PPS-without-replacement sample of fixed size using Rao-Sampford (Sampford) sampling.

    Sampford sampling is a Probability Proportional to Size (PPS) sampling method
    that selects a fixed number of units without replacement, where each unit's
    inclusion probability is approximately proportional to its weight (measure of
    size). Unlike simple random sampling, units with larger weights are more
    likely to be selected, making it useful when certain items should be
    prioritized based on their size or importance.

    This implementation uses the Rao-Sampford method, which keeps the first-order
    inclusion probabilities proportional to the weights while maintaining a fixed
    sample size.

    Inputs:
        sample_size : int
            Number of units to select. Must satisfy 1 <= sample_size <= B.
        weights : array-like of shape (B,)
            Nonnegative weights for each unit. Larger weights correspond to
            higher inclusion probabilities. Also called "measure of size" in
            sampling terminology.

    Output:
        mask : ndarray of shape (B,), dtype=int
            Binary indicator vector where 1 indicates selection. Exactly
            `sample_size` entries are 1, and the rest are 0. Selection
            probabilities are approximately proportional to `weights`.

    Raises:
        ValueError: if any weight is negative, all weights are zero, or
            sample_size is outside [1, B].
        RuntimeError: if the underlying sampler fails to return exactly
            `sample_size` selections.
    """
    weights = np.asarray(weights, dtype=float)
    population_size = weights.shape[0]
    if np.any(weights < 0):
        raise ValueError("All weights must be >= 0.")
    if sample_size < 1 or sample_size > population_size:
        raise ValueError("sample_size must be between 1 and B (population size) inclusive.")
    if np.all(weights == 0):
        raise ValueError("At least one weight must be > 0.")

    # Imported lazily so the module can be imported (and
    # probabilistic_round_child_budget's exact-integer path used) without
    # samplics being installed.
    from samplics.sampling import SampleSelection
    from samplics.utils import SelectMethod

    # Convert sample_size to a plain Python int (samplics rejects numpy int64)
    sample_size = int(sample_size)

    # Population unit labels 0..B-1
    unit_ids = np.arange(population_size)

    # Set up Sampford / Rao-Sampford PPS without replacement
    sampler = SampleSelection(
        method=SelectMethod.pps_rs,  # Rao-Sampford PPS
        wr=False,  # sample without replacement
        strat=False,  # not stratified
    )

    # Run the selection:
    #   sample_flags: length-B array of 0/1 telling whether each unit was selected
    #   hits:         the selected unit IDs
    #   probs:        first-order inclusion probs for each unit
    sample_flags, hits, probs = sampler.select(
        samp_unit=unit_ids,
        samp_size=sample_size,
        mos=weights,  # "measure of size" = weights
    )

    # Convert to a binary int mask
    mask = np.asarray(sample_flags, dtype=int)

    # Sanity check: exactly sample_size ones. Raise rather than assert so the
    # check survives `python -O`.
    if mask.sum() != sample_size:
        raise RuntimeError("Sampler did not return the requested sample size.")

    return mask


def probabilistic_round_child_budget(max_children: int, probs):
    """
    Randomly distribute `max_children` objects among bins, matching `probs` in expectation.

    Given a number of objects and per-bin probabilities, returns an integer
    allocation per bin that sums exactly to max_children. Each bin
    deterministically receives the floor of its expected share
    (max_children * prob); the leftover objects created by the fractional
    parts are assigned stochastically (via Sampford sampling weighted by
    those fractional parts), so randomness only affects the +0/+1 remainder
    of each bin.

    Inputs:
        max_children : int
            Number of objects to distribute (floats are truncated via int()).
        probs : list (or numpy array) of floats
            Relative probabilities per bin; normalized internally.

    Output:
        child_allocations : numpy array of ints
            Sums to max_children, distributed according to probs.

    Raises:
        ValueError: if all probabilities are (close to) zero.
        RuntimeError: if the final allocation does not sum to max_children.
    """
    max_children = int(max_children)

    # Accept lists as well as arrays
    probs = np.asarray(probs, dtype=float)

    # Check for all-zero probabilities before normalization
    prob_sum = probs.sum()
    if prob_sum == 0 or np.isclose(prob_sum, 0.0):
        raise ValueError("All probabilities are zero. At least one probability must be > 0.")

    # Normalize the probabilities
    probs = probs / prob_sum

    # (Non-integer) expected children per branch
    expected_children = max_children * probs

    # Deterministic part: round down to integers
    floored_children = np.floor(expected_children).astype(int)

    # Fractional remainder per branch, used as sampling weights below
    fractional_children = expected_children - floored_children

    # Leftover objects that the floors did not allocate
    num_children_to_split = max_children - np.sum(floored_children)

    if num_children_to_split == 0:
        child_allocations = floored_children
    else:
        # Give each leftover child to a distinct branch, with probability
        # proportional to that branch's fractional part
        child_allocations = floored_children + sampford_mask(
            num_children_to_split, fractional_children
        )

    if np.sum(child_allocations) != max_children:
        raise RuntimeError("Didn't get the right child allocations")

    return child_allocations