Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/sphinx/examples/qec/python/tn_noise_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def main():
H, L, true_priors = parse_detector_error_model(dem)
true_probs = np.array(true_priors)
n_checks, n_errors = H.shape
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"DEM: {n_checks} checks, {n_errors} errors")
print(f"True priors: mean={true_probs.mean():.4e} "
Expand All @@ -81,11 +82,17 @@ def main():
obs_flips = obs_flips.ravel().astype(bool)

uniform = float(true_probs.mean())
# precontract_noise=True is the recommended reduced-topology path
# for larger detector-error models. Set precontract_noise=False
# only when explicitly checking the full tensor-network contraction.
opt = NMOptimizer(H,
L, [uniform] * n_errors,
det_events,
obs_flips,
dtype="float64")
dtype="float64",
device=device,
execute="opt_einsum",
precontract_noise=True)

# Optimize in logit space — numerically stabler than raw probs.
def _to_logits(p):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# ============================================================================ #
from __future__ import annotations

from collections import OrderedDict
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, ClassVar

import opt_einsum as oe
import torch
from quimb.tensor import TensorNetwork


Expand All @@ -33,6 +35,40 @@ def contractor(subscripts: str,
return oe.contract(subscripts, *tensors, optimize=optimize)


def oe_torch_contractor(subscripts: str,
tensors: list[torch.Tensor],
optimize: str = "auto",
**_: Any) -> Any:
"""Perform einsum contraction using opt_einsum with torch tensors.

Execution follows the input tensor device, so CUDA tensors stay on
CUDA while still preserving torch autograd.
"""
return oe.contract(subscripts, *tensors, optimize=optimize, backend="torch")


_OE_EXPR_CACHE_MAXSIZE = 32
_oe_expr_cache: OrderedDict[tuple, Any] = OrderedDict()


def oe_torch_compiled_contractor(subscripts: str,
tensors: list[torch.Tensor],
optimize: str = "auto",
**_: Any) -> Any:
"""Perform einsum contraction with a cached opt_einsum expression."""
shapes = tuple(t.shape for t in tensors)
key = (subscripts, shapes, str(optimize))
if key in _oe_expr_cache:
_oe_expr_cache.move_to_end(key)
else:
if len(_oe_expr_cache) >= _OE_EXPR_CACHE_MAXSIZE:
_oe_expr_cache.popitem(last=False)
_oe_expr_cache[key] = oe.contract_expression(subscripts,
*shapes,
optimize=optimize)
return _oe_expr_cache[key](*tensors, backend="torch")


def cutn_contractor(subscripts: str,
tensors: list[Any],
optimize: Any | None = None,
Expand Down Expand Up @@ -109,13 +145,20 @@ class ContractorConfig:
_allowed_configs: ClassVar[tuple[tuple[str, str, str], ...]] = (
("numpy", "numpy", "cpu"),
("torch", "torch", "cpu"),
("torch", "torch", "cuda"),
("oe_torch", "torch", "cpu"),
("oe_torch", "torch", "cuda"),
("oe_torch_compiled", "torch", "cpu"),
("oe_torch_compiled", "torch", "cuda"),
("cutensornet", "numpy", "cuda"),
("cutensornet", "torch", "cuda"),
)
_allowed_backends: ClassVar[list[str]] = ["numpy", "torch"]
_contractors: ClassVar[dict[str, Callable]] = {
"numpy": contractor,
"torch": contractor,
"oe_torch": oe_torch_contractor,
"oe_torch_compiled": oe_torch_compiled_contractor,
"cutensornet": cutn_contractor,
}

Expand Down
Loading
Loading