From b85bc55712d9a8fc98e0b47e4e7e2885d8581988 Mon Sep 17 00:00:00 2001 From: Jordan Taylor Date: Sun, 2 Nov 2025 17:09:24 +0000 Subject: [PATCH 1/5] added cnot entropy threshold --- .../decompositions/decompositions.py | 186 +++++++++++++++++- .../evolve_and_branch_finite.py | 8 + 2 files changed, 188 insertions(+), 6 deletions(-) diff --git a/wavefunction_branching/decompositions/decompositions.py b/wavefunction_branching/decompositions/decompositions.py index 2fa586a..38065fa 100644 --- a/wavefunction_branching/decompositions/decompositions.py +++ b/wavefunction_branching/decompositions/decompositions.py @@ -321,6 +321,154 @@ def no_graddesc_different_blocks( return LSR_to_purification(L, S, R, keep_classical=True) +############################################################################################################ +# tolEntropy calculaion functions +############################################################################################################ + +# LeftSplittingTensor : TypeAlias = Complex[NDArray, "nBranches dVirt_L dSlow"] +# BlockDiagTensor : TypeAlias = Complex[NDArray, "dPhys nBranches dSlow dSlow"] +# RightSplittingTensor : TypeAlias = Complex[NDArray, "nBranches dSlow dVirt_R"] +# MatrixStack : TypeAlias = Complex[NDArray, "dPhys dVirt_L dVirt_R"] + + +def calculate_entropy_vertical( + tensor: MatrixStack, L: LeftSplittingTensor, S: BlockDiagTensor, R: RightSplittingTensor +): + """ + Calculate the residual vertical entanglement entropy after an attempted decomposision into + block-diagonal form. Vertical entanglement entropy will be zero for exact blocks, where each + block is identical. Zero vertical entanglement corresponds to a Bell-pair between L and R, not + entangled with S. This is the entropy of rho_fast in https://arxiv.org/abs/2308.04291 + + Inputs: + tensor: MatrixStack: Complex[NDArray, "dPhys dVirt_L dVirt_R"] + The original tensor which has attempted to be decomposed into L S R + as tensor approx = einsum(L, S, R, "b L l, p b l r, b r R -> p L R") + dPhys indexes the matrix in the stack. Each matrix is a dVirt_L x dVirt_R dimensional. + + L: LeftSplittingTensor: Complex[NDArray, "nBranches dVirt_L dSlow"] + The (unitary) splitting matrix decomposing the tensor into a block-diagonal form from + the left. + + S: BlockDiagTensor: Complex[NDArray, "dPhys nBranches dSlow dSlow"] + The central block-diagonal stack of matrices, where dPhys indexes the matrix in the stack, + nBranches indexes which block in the block-diagonal structure, and each block is of + dimension dSlow x dSlow. + + R: RightSplittingTensor: Complex[NDArray, "nBranches dSlow dVirt_R"] + The (unitary) splitting matrix decomposing the tensor into a block-diagonal form from + the right. + + Output: + entropy: + The vertical entanglement entropy between the slow and fast degrees of freedom. + This is zero for Bell-like entanglement (but nonzero for GHZ or non-branch-like entanglement) + """ + # Assume L and R are unitary. + # L should be a unitary map from dVirt_L to nBranches x dSlow + # R should be a unitary map from dVirt_R to nBranches x dSlow + + L_mat = rearrange(L, "nBranches dVirt_L dSlow -> dVirt_L (nBranches dSlow)") + uL_mat = utils.unitize(L_mat) + uL = rearrange(uL_mat, "dVirt_L (nBranches dSlow) -> nBranches dVirt_L dSlow") + + R_mat = rearrange(R, "nBranches dSlow dVirt_R -> dVirt_R (nBranches dSlow)") + uR_mat = utils.unitize(R_mat) + uR = rearrange(uR_mat, "dVirt_R (nBranches dSlow) -> nBranches dSlow dVirt_R") + + # Compute the residual entanglement between slow and fast degrees of freedom + overlap = einsum(np.conj(uL), tensor, np.conj(uR), "bl L l, p L R, br r R -> bl br p l r") + rho_fast = einsum(overlap, np.conj(overlap), "bl br p l r, blc brc p l r -> bl br blc brc") + vertical_spectrum = np.linalg.eigvalsh( + rearrange(rho_fast, "bl br blc brc -> (bl br) (blc brc)") + ) + vertical_spectrum = vertical_spectrum / np.sum(vertical_spectrum) + # Filter out zero eigenvalues to avoid log(0) + vertical_spectrum = vertical_spectrum[vertical_spectrum > 0] + vertical_entropy = -np.sum(np.log2(vertical_spectrum) * vertical_spectrum) + return vertical_entropy + + +def calculate_entropy_cnot( + tensor: MatrixStack, L: LeftSplittingTensor, S: BlockDiagTensor, R: RightSplittingTensor +): + """ + Calculate the residual non-block entanglement entropy after-like entangement after an attempted decomposision into + block-diagonal form. This entanglement entropy will be zero for exact blocks. + Zero cnot entanglement corresponds to GHZ-like or Bell-pair between L and R, not + entangled with S. This is the entropy of rho_fast in https://arxiv.org/abs/2308.04291 + + Inputs: + tensor: MatrixStack: Complex[NDArray, "dPhys dVirt_L dVirt_R"] + The original tensor which has attempted to be decomposed into L S R + as tensor approx = einsum(L, S, R, "b L l, p b l r, b r R -> p L R") + dPhys indexes the matrix in the stack. Each matrix is a dVirt_L x dVirt_R dimensional. + + L: LeftSplittingTensor: Complex[NDArray, "nBranches dVirt_L dSlow"] + The (unitary) splitting matrix decomposing the tensor into a block-diagonal form from + the left. + + S: BlockDiagTensor: Complex[NDArray, "dPhys nBranches dSlow dSlow"] + The central block-diagonal stack of matrices, where dPhys indexes the matrix in the stack, + nBranches indexes which block in the block-diagonal structure, and each block is of + dimension dSlow x dSlow. + + R: RightSplittingTensor: Complex[NDArray, "nBranches dSlow dVirt_R"] + The (unitary) splitting matrix decomposing the tensor into a block-diagonal form from + the right. + + Output: + entropy: + The vertical entanglement entropy between the slow and fast degrees of freedom. + This is zero for Bell-like entanglement (but nonzero for GHZ or non-branch-like entanglement) + """ + # Assume L and R are unitary. + # L should be a unitary map from dVirt_L to nBranches x dSlow + # R should be a unitary map from dVirt_R to nBranches x dSlow + + L_mat = rearrange(L, "nBranches dVirt_L dSlow -> dVirt_L (nBranches dSlow)") + uL_mat = utils.unitize(L_mat) + uL = rearrange(uL_mat, "dVirt_L (nBranches dSlow) -> nBranches dVirt_L dSlow") + + R_mat = rearrange(R, "nBranches dSlow dVirt_R -> dVirt_R (nBranches dSlow)") + uR_mat = utils.unitize(R_mat) + uR = rearrange(uR_mat, "dVirt_R (nBranches dSlow) -> nBranches dSlow dVirt_R") + + # Compute the residual entanglement between slow and fast degrees of freedom + overlap = einsum(np.conj(uL), tensor, np.conj(uR), "bl L l, p L R, br r R -> bl br p l r") + + # A CNOT is composed of an XOR tensor and a COPY tensor. see https://arxiv.org/pdf/1708.00006 + xor_gate = np.zeros((2, 2, 2)) + 0.0j * np.zeros((2, 2, 2)) + xor_gate[0, 0, 0] += 1.0 + xor_gate[1, 1, 0] += 1.0 + xor_gate[0, 1, 1] += 1.0 + xor_gate[1, 0, 1] += 1.0 + + xor_bottom_right = einsum(overlap, xor, "bl br p l r, bl br br_new -> bl br_new p l r") + xor_bottom_left = einsum(overlap, xor, "bl br p l r, bl br bl_new -> bl_new br p l r") + + rho_bottom_right = einsum( + xor_bottom_right, np.conj(xor_bottom_right), "bl br p l r, bl br p l r_prime -> r r_prime" + ) + rho_bottom_left = einsum( + xor_bottom_left, np.conj(xor_bottom_left), "bl br p l r, bl br p l_prime r -> l l_prime" + ) + + spectrum_bottom_right = np.linalg.eigvalsh(rho_bottom_right) + spectrum_bottom_left = np.linalg.eigvalsh(rho_bottom_left) + + spectrum_bottom_right = spectrum_bottom_right / np.sum(spectrum_bottom_right) + spectrum_bottom_left = spectrum_bottom_left / np.sum(spectrum_bottom_left) + # Filter out zero eigenvalues to avoid log(0) + spectrum_bottom_right = spectrum_bottom_right[spectrum_bottom_right > 0] + spectrum_bottom_left = spectrum_bottom_left[spectrum_bottom_left > 0] + + entropy_bottom_right = -np.sum(np.log2(spectrum_bottom_right) * spectrum_bottom_right) + entropy_bottom_left = -np.sum(np.log2(spectrum_bottom_left) * spectrum_bottom_left) + + return (entropy_bottom_right + entropy_bottom_left) / 2.0 + + ############################################################################################################ # Combination function ############################################################################################################ @@ -346,6 +494,8 @@ def branch_from_theta( ], n_steps_iterative=500, n_steps_graddesc=1000, + tolEntropy=None, + tolEntropy_kind: Literal["cnot", "vertical"] | None = "cnot", ) -> tuple[PurificationMatrixStack, dict]: if iterative_method is None or iterative_method == "None": assert graddesc_method is None or graddesc_method == "None", ( @@ -401,12 +551,32 @@ def branch_from_theta( L, S, R, info = fn_iterative(tensor, n_steps=n_steps_iterative) t2 = time.time() - # # Normalize the purification - # theta = LSR_to_purification(L, S, R, keep_classical) - # norm = einsum(theta, np.conj(theta), 'b p l r, b p l r -> ') - # S /= np.sqrt(norm/norm_orig) - - theta_purified = fn_graddesc(tensor, L, S, R, n_steps=n_steps_graddesc) + rejected = info.get("rejected"), False + + # Determine if the iterative method failed to find a good decomposition + if tolEntropy is not None and tolEntropy_kind is not None: + if tolEntropy_kind == "cnot": + entropy = calculate_entropy_cnot(tensor, L, S, R) + elif tolEntropy_kind == "vertical": + entropy = calculate_entropy_vertical(tensor, L, S, R) + else: + assert False, f"unknown tolEntropy_kind {tolEntropy_kind}" + + info["entropy"] = entropy + info["tolEntropy_kind"] = tolEntropy_kind + + if entropy > tolEntropy: + info["rejected"] = True + print( + " Further optimization was rejected as initial optimization failed to find a good decomposition." + ) + print(f" entropy = {entropy} (tolEntropy = {tolEntropy})") + print(f" tolEntropy_kind = {tolEntropy_kind}") + + # Only perform gradient descent steps if info["rejected"] == False + theta_purified = fn_graddesc( + tensor, L, S, R, n_steps=0 if info["rejected"] else n_steps_graddesc + ) t3 = time.time() return theta_purified, {"iterative_time": t2 - t1, "graddesc_time": t3 - t2, **info} @@ -435,6 +605,8 @@ def branch( coarsegrain_size=2, n_steps_iterative=500, n_steps_graddesc=1000, + tolEntropy=None, + tolEntropy_kind: Literal["cnot", "vertical"] | None = "cnot", ) -> tuple[PurificationMatrixStack, dict]: if coarsegrain_from == "half": coarsegrain_from = int(psi.L / 2 - coarsegrain_size / 2) @@ -452,4 +624,6 @@ def branch( graddesc_method, n_steps_iterative=n_steps_iterative, n_steps_graddesc=n_steps_graddesc, + tolEntropy=tolEntropy, + tolEntropy_kind=tolEntropy_kind, ) diff --git a/wavefunction_branching/evolve_and_branch_finite.py b/wavefunction_branching/evolve_and_branch_finite.py index 0272f2a..7cfe452 100644 --- a/wavefunction_branching/evolve_and_branch_finite.py +++ b/wavefunction_branching/evolve_and_branch_finite.py @@ -225,6 +225,8 @@ class BranchingMPSConfig: save_full_state: bool = False necessary_local_truncation_improvement_factor: float = 1.0 necessary_global_truncation_improvement_factor: float = 1.0 + tolEntropy: float | None = None + tolEntropy_kind: Literal["cnot", "vertical"] | None = "cnot" def check_canonical_form(psi): @@ -1571,6 +1573,8 @@ def main( necessary_local_truncation_improvement_factor=1.1, necessary_global_truncation_improvement_factor=1.1, seed=None, + tolEntropy=None, + tolEntropy_kind: Literal["cnot", "vertical"] | None = "cnot", ): print("Running main") print(f"Iterative method: {iterative_method}") @@ -1655,6 +1659,8 @@ def main( save_full_state=save_full_state, necessary_local_truncation_improvement_factor=necessary_local_truncation_improvement_factor, necessary_global_truncation_improvement_factor=necessary_global_truncation_improvement_factor, + tolEntropy=tolEntropy, + tolEntropy_kind=tolEntropy_kind, ) branch_function = partial( @@ -1662,6 +1668,8 @@ def main( iterative_method=iterative_method, graddesc_method=graddesc_method, n_steps_graddesc=maxiter_heuristic, + tolEntropy=tolEntropy, + tolEntropy_kind=tolEntropy_kind, ) print("\n\n\n\nBranching evolution:") From d8b63edbb83c564151019fdf70fb7ab817177852 Mon Sep 17 00:00:00 2001 From: jordansauce Date: Sun, 9 Nov 2025 21:13:31 +0800 Subject: [PATCH 2/5] small fix --- make_sbatch_files.py | 2 ++ .../decompositions/decompositions.py | 16 ++++------------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/make_sbatch_files.py b/make_sbatch_files.py index 817006b..ba52e93 100644 --- a/make_sbatch_files.py +++ b/make_sbatch_files.py @@ -43,6 +43,8 @@ def update_dict_copy(old_dict, new_entries): maxiter_heuristic=6000, necessary_local_truncation_improvement_factor=1.0, necessary_global_truncation_improvement_factor=1.0, + tolEntropy=0.01, + tolEntropy_kind="cnot", seed=seed, ) diff --git a/wavefunction_branching/decompositions/decompositions.py b/wavefunction_branching/decompositions/decompositions.py index 38065fa..460fbf4 100644 --- a/wavefunction_branching/decompositions/decompositions.py +++ b/wavefunction_branching/decompositions/decompositions.py @@ -537,22 +537,14 @@ def branch_from_theta( "graddesc_global_reconstruction_split_non_interfering": graddesc_global_reconstruction_split_non_interfering, } - keep_classical = True - if graddesc_method == "rho_LM_MR_trace_norm_discard_classical_identical_blocks": - keep_classical = False - if graddesc_method is None and "discard_classical" in iterative_method: - keep_classical = False fn_graddesc = fn_dict_graddesc[graddesc_method] - norm_orig = einsum(theta_scrambled, np.conj(theta_scrambled), "p l r, p l r -> ") tensor = utils.make_square(theta_scrambled, 2) t1 = time.time() L, S, R, info = fn_iterative(tensor, n_steps=n_steps_iterative) t2 = time.time() - rejected = info.get("rejected"), False - # Determine if the iterative method failed to find a good decomposition if tolEntropy is not None and tolEntropy_kind is not None: if tolEntropy_kind == "cnot": @@ -560,7 +552,7 @@ def branch_from_theta( elif tolEntropy_kind == "vertical": entropy = calculate_entropy_vertical(tensor, L, S, R) else: - assert False, f"unknown tolEntropy_kind {tolEntropy_kind}" + raise AssertionError(f"unknown tolEntropy_kind {tolEntropy_kind}") info["entropy"] = entropy info["tolEntropy_kind"] = tolEntropy_kind @@ -573,10 +565,10 @@ def branch_from_theta( print(f" entropy = {entropy} (tolEntropy = {tolEntropy})") print(f" tolEntropy_kind = {tolEntropy_kind}") + rejected = info.get("rejected"), False + # Only perform gradient descent steps if info["rejected"] == False - theta_purified = fn_graddesc( - tensor, L, S, R, n_steps=0 if info["rejected"] else n_steps_graddesc - ) + theta_purified = fn_graddesc(tensor, L, S, R, n_steps=(0 if rejected else n_steps_graddesc)) t3 = time.time() return theta_purified, {"iterative_time": t2 - t1, "graddesc_time": t3 - t2, **info} From cc7bd74ed1a4c002f0ab742e39421b04ecbf531a Mon Sep 17 00:00:00 2001 From: jordansauce Date: Sun, 9 Nov 2025 23:20:04 +0800 Subject: [PATCH 3/5] Major fixes to vertical / cnot entropy calculation, GHZ tests for them --- tests/test_decompositions.py | 69 ++++++++++++++++++- .../decompositions/decompositions.py | 32 ++++++--- 2 files changed, 90 insertions(+), 11 deletions(-) diff --git a/tests/test_decompositions.py b/tests/test_decompositions.py index 857c9c3..36a3bbc 100644 --- a/tests/test_decompositions.py +++ b/tests/test_decompositions.py @@ -1,14 +1,25 @@ # %% import numpy as np import pytest -from opt_einops import einsum +from opt_einops import einsum, rearrange from scipy.stats import unitary_group # Random unitaries from instantaneous_benchmarks.benchmark_decompositions import benchmark_blockdiag_method from instantaneous_benchmarks.generate_test_inputs import ( block_diagonal_matrices_exp_decaying_spectrum, ) -from wavefunction_branching.decompositions.decompositions import branch_from_theta +from wavefunction_branching.decompositions.decompositions import ( + branch_from_theta, + calculate_entropy_cnot, + calculate_entropy_vertical, +) +from wavefunction_branching.type_aliases import ( + BlockDiagTensor, + LeftSplittingTensor, + MatrixStack, + PurificationMatrixStack, + RightSplittingTensor, +) ITERATIVE_METHODS = ( "bell_discard_classical", @@ -250,4 +261,58 @@ def check_decompositions(N=4, branch_chi=5, noise_introduced=1e-8, n_trials=5): print("\n\nDone.") +def create_tensors_with_zero_vertical_entropy( + dPhys: int, + nBranches: int, + dVirt_L: int, + dVirt_R: int, + dSlow: int, +) -> tuple[MatrixStack, LeftSplittingTensor, np.ndarray, RightSplittingTensor]: + tensor_top = np.random.rand(dPhys, dSlow, dSlow) + 1j * np.random.rand(dPhys, dSlow, dSlow) + tensor_bottom = np.diag(np.random.rand(nBranches) + 1.0j * np.random.rand(nBranches)) + + # Create random unitaries to scramble the tensor + U = unitary_group.rvs(dVirt_L) + U = rearrange(U, "L (l bl) -> bl L l", l=dSlow, bl=nBranches) + Vh = unitary_group.rvs(dVirt_R) + Vh = rearrange(Vh, "(r br) R -> br r R", r=dSlow, br=nBranches) + + tensor_scrambled = einsum( + U, tensor_top, tensor_bottom, Vh, "bl L l, p l r, bl br, br r R -> p L R" + ) + return tensor_scrambled, U, tensor_top, Vh + + +def test_calculate_entropy_vertical(): + # Create test tensor, which should have zero vertical entanglement entropy + tensor_scrambled, U, tensor_top, Vh = create_tensors_with_zero_vertical_entropy( + dPhys=4, + nBranches=2, + dVirt_L=16, + dVirt_R=16, + dSlow=8, + ) + # Check that the vertical entanglement entropy is zero + vertical_entropy = calculate_entropy_vertical(tensor_scrambled, U, tensor_top, Vh) + assert np.isclose(vertical_entropy, 0.0, atol=1e-10), ( + f"vertical_entropy = {vertical_entropy}, expected close to 0.0" + ) + + +def test_calcuate_entropy_cnot_ghz(): + # Create test tensor, which should have zero vertical entanglement entropy + tensor_scrambled, U, tensor_top, Vh = create_tensors_with_zero_vertical_entropy( + dPhys=4, + nBranches=2, + dVirt_L=16, + dVirt_R=16, + dSlow=8, + ) + # Check that the cnot entanglement entropy is zero + cnot_entropy = calculate_entropy_cnot(tensor_scrambled, U, tensor_top, Vh) + assert np.isclose(cnot_entropy, 0.0, atol=1e-10), ( + f"cnot_entropy = {cnot_entropy}, expected close to 0.0" + ) + + # %% diff --git a/wavefunction_branching/decompositions/decompositions.py b/wavefunction_branching/decompositions/decompositions.py index 460fbf4..eb1aa2b 100644 --- a/wavefunction_branching/decompositions/decompositions.py +++ b/wavefunction_branching/decompositions/decompositions.py @@ -370,11 +370,15 @@ def calculate_entropy_vertical( L_mat = rearrange(L, "nBranches dVirt_L dSlow -> dVirt_L (nBranches dSlow)") uL_mat = utils.unitize(L_mat) - uL = rearrange(uL_mat, "dVirt_L (nBranches dSlow) -> nBranches dVirt_L dSlow") + uL = rearrange( + uL_mat, "dVirt_L (nBranches dSlow) -> nBranches dVirt_L dSlow", nBranches=L.shape[0] + ) R_mat = rearrange(R, "nBranches dSlow dVirt_R -> dVirt_R (nBranches dSlow)") uR_mat = utils.unitize(R_mat) - uR = rearrange(uR_mat, "dVirt_R (nBranches dSlow) -> nBranches dSlow dVirt_R") + uR = rearrange( + uR_mat, "dVirt_R (nBranches dSlow) -> nBranches dSlow dVirt_R", nBranches=R.shape[0] + ) # Compute the residual entanglement between slow and fast degrees of freedom overlap = einsum(np.conj(uL), tensor, np.conj(uR), "bl L l, p L R, br r R -> bl br p l r") @@ -428,13 +432,17 @@ def calculate_entropy_cnot( L_mat = rearrange(L, "nBranches dVirt_L dSlow -> dVirt_L (nBranches dSlow)") uL_mat = utils.unitize(L_mat) - uL = rearrange(uL_mat, "dVirt_L (nBranches dSlow) -> nBranches dVirt_L dSlow") + uL = rearrange( + uL_mat, "dVirt_L (nBranches dSlow) -> nBranches dVirt_L dSlow", nBranches=L.shape[0] + ) R_mat = rearrange(R, "nBranches dSlow dVirt_R -> dVirt_R (nBranches dSlow)") uR_mat = utils.unitize(R_mat) - uR = rearrange(uR_mat, "dVirt_R (nBranches dSlow) -> nBranches dSlow dVirt_R") + uR = rearrange( + uR_mat, "dVirt_R (nBranches dSlow) -> nBranches dSlow dVirt_R", nBranches=R.shape[0] + ) - # Compute the residual entanglement between slow and fast degrees of freedom + # Compute the guess for he middle tensor overlap = einsum(np.conj(uL), tensor, np.conj(uR), "bl L l, p L R, br r R -> bl br p l r") # A CNOT is composed of an XOR tensor and a COPY tensor. see https://arxiv.org/pdf/1708.00006 @@ -444,14 +452,20 @@ def calculate_entropy_cnot( xor_gate[0, 1, 1] += 1.0 xor_gate[1, 0, 1] += 1.0 - xor_bottom_right = einsum(overlap, xor, "bl br p l r, bl br br_new -> bl br_new p l r") - xor_bottom_left = einsum(overlap, xor, "bl br p l r, bl br bl_new -> bl_new br p l r") + assert xor_gate[0, 1, 0] == xor_gate[1, 0, 0] + assert xor_gate[0, 0, 1] == xor_gate[1, 0, 0] + + assert xor_gate[0, 1, 1] == xor_gate[1, 1, 0] + assert xor_gate[0, 1, 1] == xor_gate[1, 0, 1] + + xor_bottom_right = einsum(overlap, xor_gate, "bl br p l r, bl br br_new -> bl br_new p l r") + xor_bottom_left = einsum(overlap, xor_gate, "bl br p l r, bl br bl_new -> bl_new br p l r") rho_bottom_right = einsum( - xor_bottom_right, np.conj(xor_bottom_right), "bl br p l r, bl br p l r_prime -> r r_prime" + xor_bottom_right, np.conj(xor_bottom_right), "bl br p l r, bl br_prime p l r -> br br_prime" ) rho_bottom_left = einsum( - xor_bottom_left, np.conj(xor_bottom_left), "bl br p l r, bl br p l_prime r -> l l_prime" + xor_bottom_left, np.conj(xor_bottom_left), "bl br p l r, bl_prime br p l r -> bl bl_prime" ) spectrum_bottom_right = np.linalg.eigvalsh(rho_bottom_right) From bb001ec5a42a69e9eeffcf53695524a62b339675 Mon Sep 17 00:00:00 2001 From: jordansauce Date: Mon, 10 Nov 2025 02:08:51 +0800 Subject: [PATCH 4/5] Add test_calcuate_entropy_cnot_ghz --- tests/test_decompositions.py | 48 +++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/test_decompositions.py b/tests/test_decompositions.py index 36a3bbc..8c6d2dc 100644 --- a/tests/test_decompositions.py +++ b/tests/test_decompositions.py @@ -299,7 +299,7 @@ def test_calculate_entropy_vertical(): ) -def test_calcuate_entropy_cnot_ghz(): +def test_calcuate_entropy_cnot_bell(): # Create test tensor, which should have zero vertical entanglement entropy tensor_scrambled, U, tensor_top, Vh = create_tensors_with_zero_vertical_entropy( dPhys=4, @@ -315,4 +315,50 @@ def test_calcuate_entropy_cnot_ghz(): ) +def create_tensors_with_ghz_structure( + dPhys: int, + nBranches: int, + dVirt_L: int, + dVirt_R: int, + dSlow: int, +) -> tuple[MatrixStack, LeftSplittingTensor, np.ndarray, RightSplittingTensor]: + assert dVirt_L == nBranches * dSlow + assert dVirt_R == nBranches * dSlow + + tensor_top: BlockDiagTensor = np.random.rand( + dPhys, nBranches, dSlow, dSlow + ) + 1j * np.random.rand(dPhys, nBranches, dSlow, dSlow) + tensor_bottom = np.zeros((nBranches, nBranches, nBranches)) * 0.0j + for i in range(nBranches): + tensor_bottom[i, i, i] = np.random.rand() + 1.0j * np.random.rand() + + # Create random unitaries to scramble the tensor + U = unitary_group.rvs(dVirt_L) + U = rearrange(U, "L (l bl) -> bl L l", l=dSlow, bl=nBranches) + Vh = unitary_group.rvs(dVirt_R) + Vh = rearrange(Vh, "(r br) R -> br r R", r=dSlow, br=nBranches) + + tensor_scrambled = einsum( + U, tensor_top, tensor_bottom, Vh, "bl L l, p bt l r, bt bl br, br r R -> p L R" + ) + return tensor_scrambled, U, tensor_top, Vh + + +def test_calcuate_entropy_cnot_ghz(): + # Create test tensor, which should have zero vertical entanglement entropy + tensor_scrambled, U, tensor_top, Vh = create_tensors_with_ghz_structure( + dPhys=4, + nBranches=2, + dVirt_L=16, + dVirt_R=16, + dSlow=8, + ) + # Check that the cnot entanglement entropy is zero + cnot_entropy = calculate_entropy_cnot(tensor_scrambled, U, tensor_top, Vh) + assert np.isclose(cnot_entropy, 0.0, atol=1e-10), ( + f"cnot_entropy = {cnot_entropy}, expected close to 0.0" + ) + + +test_calcuate_entropy_cnot_ghz() # %% From e533b8f99c6b005c3c109a6fe52987012786d2e7 Mon Sep 17 00:00:00 2001 From: jordansauce Date: Mon, 24 Nov 2025 08:21:49 +0800 Subject: [PATCH 5/5] minor fixes --- make_sbatch_files.py | 56 ++++++++++++++++++- .../evolve_and_branch_finite.py | 6 +- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/make_sbatch_files.py b/make_sbatch_files.py index ba52e93..b92be52 100644 --- a/make_sbatch_files.py +++ b/make_sbatch_files.py @@ -43,7 +43,7 @@ def update_dict_copy(old_dict, new_entries): maxiter_heuristic=6000, necessary_local_truncation_improvement_factor=1.0, necessary_global_truncation_improvement_factor=1.0, - tolEntropy=0.01, + tolEntropy=None, tolEntropy_kind="cnot", seed=seed, ) @@ -162,6 +162,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), dict( iterative_method="bell_original_threshold_discard_classical", @@ -172,6 +174,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), dict( iterative_method="bell_original_threshold_keep_classical", @@ -182,6 +186,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), dict( iterative_method="bell_original_threshold_keep_classical", @@ -192,6 +198,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), dict( iterative_method="bell_original_threshold_keep_classical", @@ -202,6 +210,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), dict( iterative_method="bell_original_threshold_keep_classical", @@ -212,6 +222,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), dict( iterative_method="bell_original_threshold_keep_classical", @@ -222,6 +234,8 @@ def update_dict_copy(old_dict, new_entries): t_evo=(0.25 * n_sites - 2.0), necessary_local_truncation_improvement_factor=0.0, necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="vertical", ), # Other methods dict( @@ -296,6 +310,46 @@ def update_dict_copy(old_dict, new_entries): n_sites=n_sites, t_evo=(0.25 * n_sites - 2.0), ), + + + ## Cnot entropy threshold + + dict( + iterative_method="vertical_svd_micro_bsvd", + graddesc_method="graddesc_global_reconstruction_non_interfering", + chi_max=chi_max, + chi_to_branch=chi_max, # int(chi_max*0.75), + n_sites=n_sites, + t_evo=(0.25 * n_sites - 2.0), + necessary_local_truncation_improvement_factor=0.0, + necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="cnot", + ), + dict( + iterative_method="vertical_svd_micro_bsvd", + graddesc_method="graddesc_global_reconstruction_split_non_interfering", + chi_max=chi_max, + chi_to_branch=chi_max, # int(chi_max*0.75), + n_sites=n_sites, + t_evo=(0.25 * n_sites - 2.0), + necessary_local_truncation_improvement_factor=0.0, + necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="cnot", + ), + dict( + iterative_method="vertical_svd_micro_bsvd", + graddesc_method="rho_LM_MR_trace_norm", + chi_max=chi_max, + chi_to_branch=chi_max, # int(chi_max*0.75), + n_sites=n_sites, + t_evo=(0.25 * n_sites - 2.0), + necessary_local_truncation_improvement_factor=0.0, + necessary_global_truncation_improvement_factor=0.0, + tolEntropy=0.01, + tolEntropy_kind="cnot", + ), # dict( # iterative_method = 'vertical_svd_micro_bsvd', # graddesc_method = 'rho_half_LM_MR_trace_norm', diff --git a/wavefunction_branching/evolve_and_branch_finite.py b/wavefunction_branching/evolve_and_branch_finite.py index 7cfe452..1f706be 100644 --- a/wavefunction_branching/evolve_and_branch_finite.py +++ b/wavefunction_branching/evolve_and_branch_finite.py @@ -542,7 +542,7 @@ def branch_and_sample( candidate_indices = np.arange(num_candidates) # --- Branch Sampling: Allocate grandchild budget using random_round --- - allocated_budgets = probabilistic_round_child_budget(max_children, branch_probs) + allocated_budgets = probabilistic_round_child_budget(self.max_children, branch_probs) keep_mask = allocated_budgets > 0 survivor_indices = candidate_indices[keep_mask] @@ -688,10 +688,10 @@ def branch_and_sample( else: print(f"{self.ID}Accepting the branch decomposition: ") print( - f" local_trace_distance = {costFun_LM_MR_trace_distance:.2e} < truncation_only_comparison = {trace_distance_truncation_only_comparison:.2e}" + f" local_trace_distance = {costFun_LM_MR_trace_distance:.2e} vs truncation_only_comparison = {trace_distance_truncation_only_comparison:.2e}" ) print( - f" global_trace_distance = {global_reconstruction_error_trace_distance:.2e} < truncation_only_comparison = {global_reconstruction_error_truncation_only_comparison:.2e}" + f" global_trace_distance = {global_reconstruction_error_trace_distance:.2e} vs truncation_only_comparison = {global_reconstruction_error_truncation_only_comparison:.2e}" ) self.last_attempted_branching_trunc_bond_dims_sites[coarsegrain_from] = None self.last_attempted_branching_trunc_trace_distance_sites[coarsegrain_from] = None