Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions make_sbatch_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def update_dict_copy(old_dict, new_entries):
maxiter_heuristic=6000,
necessary_local_truncation_improvement_factor=1.0,
necessary_global_truncation_improvement_factor=1.0,
tolEntropy=None,
tolEntropy_kind="cnot",
seed=seed,
)

Expand Down Expand Up @@ -160,6 +162,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
dict(
iterative_method="bell_original_threshold_discard_classical",
Expand All @@ -170,6 +174,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
dict(
iterative_method="bell_original_threshold_keep_classical",
Expand All @@ -180,6 +186,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
dict(
iterative_method="bell_original_threshold_keep_classical",
Expand All @@ -190,6 +198,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
dict(
iterative_method="bell_original_threshold_keep_classical",
Expand All @@ -200,6 +210,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
dict(
iterative_method="bell_original_threshold_keep_classical",
Expand All @@ -210,6 +222,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
dict(
iterative_method="bell_original_threshold_keep_classical",
Expand All @@ -220,6 +234,8 @@ def update_dict_copy(old_dict, new_entries):
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="vertical",
),
# Other methods
dict(
Expand Down Expand Up @@ -294,6 +310,46 @@ def update_dict_copy(old_dict, new_entries):
n_sites=n_sites,
t_evo=(0.25 * n_sites - 2.0),
),


## Cnot entropy threshold

dict(
iterative_method="vertical_svd_micro_bsvd",
graddesc_method="graddesc_global_reconstruction_non_interfering",
chi_max=chi_max,
chi_to_branch=chi_max, # int(chi_max*0.75),
n_sites=n_sites,
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="cnot",
),
dict(
iterative_method="vertical_svd_micro_bsvd",
graddesc_method="graddesc_global_reconstruction_split_non_interfering",
chi_max=chi_max,
chi_to_branch=chi_max, # int(chi_max*0.75),
n_sites=n_sites,
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="cnot",
),
dict(
iterative_method="vertical_svd_micro_bsvd",
graddesc_method="rho_LM_MR_trace_norm",
chi_max=chi_max,
chi_to_branch=chi_max, # int(chi_max*0.75),
n_sites=n_sites,
t_evo=(0.25 * n_sites - 2.0),
necessary_local_truncation_improvement_factor=0.0,
necessary_global_truncation_improvement_factor=0.0,
tolEntropy=0.01,
tolEntropy_kind="cnot",
),
# dict(
# iterative_method = 'vertical_svd_micro_bsvd',
# graddesc_method = 'rho_half_LM_MR_trace_norm',
Expand Down
115 changes: 113 additions & 2 deletions tests/test_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
# %%
import numpy as np
import pytest
from opt_einops import einsum
from opt_einops import einsum, rearrange
from scipy.stats import unitary_group # Random unitaries

from instantaneous_benchmarks.benchmark_decompositions import benchmark_blockdiag_method
from instantaneous_benchmarks.generate_test_inputs import (
block_diagonal_matrices_exp_decaying_spectrum,
)
from wavefunction_branching.decompositions.decompositions import branch_from_theta
from wavefunction_branching.decompositions.decompositions import (
branch_from_theta,
calculate_entropy_cnot,
calculate_entropy_vertical,
)
from wavefunction_branching.type_aliases import (
BlockDiagTensor,
LeftSplittingTensor,
MatrixStack,
PurificationMatrixStack,
RightSplittingTensor,
)

ITERATIVE_METHODS = (
"bell_discard_classical",
Expand Down Expand Up @@ -250,4 +261,104 @@ def check_decompositions(N=4, branch_chi=5, noise_introduced=1e-8, n_trials=5):
print("\n\nDone.")


def create_tensors_with_zero_vertical_entropy(
    dPhys: int,
    nBranches: int,
    dVirt_L: int,
    dVirt_R: int,
    dSlow: int,
) -> tuple[MatrixStack, LeftSplittingTensor, np.ndarray, RightSplittingTensor]:
    """Build a scrambled theta tensor with (by construction) zero vertical entanglement.

    A single "slow" core (shared by every branch) is combined with a *diagonal*
    branch-weight matrix — so there is no coupling between branches — and then
    hidden behind Haar-random unitaries on the left and right virtual legs.

    Returns:
        (scrambled tensor of shape (dPhys, dVirt_L, dVirt_R),
         left splitting tensor U, the unscrambled core, right splitting tensor Vh)
    """
    # Random complex core over (physical, slow-left, slow-right) legs.
    top_core = np.random.rand(dPhys, dSlow, dSlow) + 1j * np.random.rand(dPhys, dSlow, dSlow)
    # Diagonal branch weights: off-diagonal (inter-branch) terms are zero.
    branch_weights = np.diag(np.random.rand(nBranches) + 1.0j * np.random.rand(nBranches))

    # Haar-random unitaries, reshaped into splitting tensors, scramble the structure.
    left_split = rearrange(
        unitary_group.rvs(dVirt_L), "L (l bl) -> bl L l", l=dSlow, bl=nBranches
    )
    right_split = rearrange(
        unitary_group.rvs(dVirt_R), "(r br) R -> br r R", r=dSlow, br=nBranches
    )

    scrambled = einsum(
        left_split,
        top_core,
        branch_weights,
        right_split,
        "bl L l, p l r, bl br, br r R -> p L R",
    )
    return scrambled, left_split, top_core, right_split


def test_calculate_entropy_vertical():
    """Vertical entropy of a tensor built with zero vertical entanglement is ~0."""
    # Tensor constructed so the vertical entanglement entropy vanishes exactly.
    scrambled, U, top_core, Vh = create_tensors_with_zero_vertical_entropy(
        dPhys=4,
        nBranches=2,
        dVirt_L=16,
        dVirt_R=16,
        dSlow=8,
    )
    vertical_entropy = calculate_entropy_vertical(scrambled, U, top_core, Vh)
    assert np.isclose(vertical_entropy, 0.0, atol=1e-10), (
        f"vertical_entropy = {vertical_entropy}, expected close to 0.0"
    )


# NOTE(review): "calcuate" is a typo in the test name; kept to preserve the
# public identifier (rename in a dedicated change if desired).
def test_calcuate_entropy_cnot_bell():
    """Cnot entropy of a tensor with no inter-branch (Bell-like) entanglement is ~0."""
    # Same zero-vertical-entanglement construction as the vertical-entropy test.
    scrambled, U, top_core, Vh = create_tensors_with_zero_vertical_entropy(
        dPhys=4,
        nBranches=2,
        dVirt_L=16,
        dVirt_R=16,
        dSlow=8,
    )
    cnot_entropy = calculate_entropy_cnot(scrambled, U, top_core, Vh)
    assert np.isclose(cnot_entropy, 0.0, atol=1e-10), (
        f"cnot_entropy = {cnot_entropy}, expected close to 0.0"
    )


def create_tensors_with_ghz_structure(
    dPhys: int,
    nBranches: int,
    dVirt_L: int,
    dVirt_R: int,
    dSlow: int,
) -> tuple[MatrixStack, LeftSplittingTensor, np.ndarray, RightSplittingTensor]:
    """Build a scrambled theta tensor with GHZ-like (block-diagonal) branch structure.

    Unlike ``create_tensors_with_zero_vertical_entropy``, the core carries an
    explicit branch index and the bottom tensor is nonzero only on its full
    diagonal (i, i, i) — i.e. each branch couples only to itself on the left,
    right, and top legs, giving a GHZ-style correlated structure.  The result
    is then hidden behind Haar-random unitaries on both virtual legs.

    Returns:
        (scrambled tensor of shape (dPhys, dVirt_L, dVirt_R),
         left splitting tensor U, the unscrambled branch-resolved core, Vh)
    """
    # The reshapes below require the virtual dims to factor as branches x slow.
    assert dVirt_L == nBranches * dSlow
    assert dVirt_R == nBranches * dSlow

    # Branch-resolved core: one random (dSlow x dSlow) matrix stack per branch.
    tensor_top: BlockDiagTensor = np.random.rand(
        dPhys, nBranches, dSlow, dSlow
    ) + 1j * np.random.rand(dPhys, nBranches, dSlow, dSlow)
    # Complex-zero tensor; only the (i, i, i) "GHZ diagonal" entries are filled.
    tensor_bottom = np.zeros((nBranches, nBranches, nBranches)) * 0.0j
    for i in range(nBranches):
        tensor_bottom[i, i, i] = np.random.rand() + 1.0j * np.random.rand()

    # Create random unitaries to scramble the tensor
    U = unitary_group.rvs(dVirt_L)
    U = rearrange(U, "L (l bl) -> bl L l", l=dSlow, bl=nBranches)
    Vh = unitary_group.rvs(dVirt_R)
    Vh = rearrange(Vh, "(r br) R -> br r R", r=dSlow, br=nBranches)

    # Contract core, diagonal coupler, and unitaries into one (p, L, R) tensor.
    tensor_scrambled = einsum(
        U, tensor_top, tensor_bottom, Vh, "bl L l, p bt l r, bt bl br, br r R -> p L R"
    )
    return tensor_scrambled, U, tensor_top, Vh


# NOTE(review): "calcuate" is a typo in the test name; kept to preserve the
# public identifier (rename in a dedicated change if desired).
def test_calcuate_entropy_cnot_ghz():
    """Cnot entropy of a GHZ-structured tensor is ~0."""
    # GHZ-like construction: branch correlations that a cnot can disentangle.
    scrambled, U, top_core, Vh = create_tensors_with_ghz_structure(
        dPhys=4,
        nBranches=2,
        dVirt_L=16,
        dVirt_R=16,
        dSlow=8,
    )
    cnot_entropy = calculate_entropy_cnot(scrambled, U, top_core, Vh)
    assert np.isclose(cnot_entropy, 0.0, atol=1e-10), (
        f"cnot_entropy = {cnot_entropy}, expected close to 0.0"
    )


# Guard the ad-hoc invocation (a leftover from "# %%" notebook-cell usage) so
# that merely importing this module — e.g. during pytest collection — does not
# execute the test a second time or fail the import if the test fails.
if __name__ == "__main__":
    test_calcuate_entropy_cnot_ghz()
# %%
Loading
Loading