diff --git a/docs/sphinx/examples/qec/python/tn_noise_learning.py b/docs/sphinx/examples/qec/python/tn_noise_learning.py index f1631b93..af13ad91 100644 --- a/docs/sphinx/examples/qec/python/tn_noise_learning.py +++ b/docs/sphinx/examples/qec/python/tn_noise_learning.py @@ -68,6 +68,7 @@ def main(): H, L, true_priors = parse_detector_error_model(dem) true_probs = np.array(true_priors) n_checks, n_errors = H.shape + device = "cuda" if torch.cuda.is_available() else "cpu" print(f"DEM: {n_checks} checks, {n_errors} errors") print(f"True priors: mean={true_probs.mean():.4e} " @@ -81,11 +82,17 @@ def main(): obs_flips = obs_flips.ravel().astype(bool) uniform = float(true_probs.mean()) + # precontract_noise=True is the recommended reduced-topology path + # for larger detector-error models. Set precontract_noise=False + # only when explicitly checking the full tensor-network contraction. opt = NMOptimizer(H, L, [uniform] * n_errors, det_events, obs_flips, - dtype="float64") + dtype="float64", + device=device, + execute="opt_einsum", + precontract_noise=True) # Optimize in logit space — numerically stabler than raw probs. def _to_logits(p): diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py index 5ded4fb3..3d7951d0 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py @@ -7,11 +7,13 @@ # ============================================================================ # from __future__ import annotations +from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass, field from typing import Any, ClassVar import opt_einsum as oe +import torch from quimb.tensor import TensorNetwork @@ -33,6 +35,40 @@ def contractor(subscripts: str, return oe.contract(subscripts, *tensors, optimize=optimize) +def oe_torch_contractor(subscripts: str, + tensors: list[torch.Tensor], + optimize: str = "auto", + **_: Any) -> Any: + """Perform einsum contraction using opt_einsum with torch tensors. + + Execution follows the input tensor device, so CUDA tensors stay on + CUDA while still preserving torch autograd. + """ + return oe.contract(subscripts, *tensors, optimize=optimize, backend="torch") + + +_OE_EXPR_CACHE_MAXSIZE = 32 +_oe_expr_cache: OrderedDict[tuple, Any] = OrderedDict() + + +def oe_torch_compiled_contractor(subscripts: str, + tensors: list[torch.Tensor], + optimize: str = "auto", + **_: Any) -> Any: + """Perform einsum contraction with a cached opt_einsum expression.""" + shapes = tuple(t.shape for t in tensors) + key = (subscripts, shapes, str(optimize)) + if key in _oe_expr_cache: + _oe_expr_cache.move_to_end(key) + else: + if len(_oe_expr_cache) >= _OE_EXPR_CACHE_MAXSIZE: + _oe_expr_cache.popitem(last=False) + _oe_expr_cache[key] = oe.contract_expression(subscripts, + *shapes, + optimize=optimize) + return _oe_expr_cache[key](*tensors, backend="torch") + + def cutn_contractor(subscripts: str, tensors: list[Any], optimize: Any | None = None, @@ -109,6 +145,11 @@ class ContractorConfig: _allowed_configs: ClassVar[tuple[tuple[str, str, str], ...]] = ( ("numpy", "numpy", "cpu"), ("torch", "torch", "cpu"), + ("torch", "torch", "cuda"), + ("oe_torch", "torch", "cpu"), + ("oe_torch", "torch", "cuda"), + ("oe_torch_compiled", "torch", "cpu"), + ("oe_torch_compiled", "torch", "cuda"), ("cutensornet", "numpy", "cuda"), ("cutensornet", "torch", "cuda"), ) @@ -116,6 +157,8 @@ class ContractorConfig: _contractors: ClassVar[dict[str, Callable]] = { "numpy": contractor, "torch": contractor, + "oe_torch": oe_torch_contractor, + "oe_torch_compiled": oe_torch_compiled_contractor, "cutensornet": cutn_contractor, } diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py index 4e908908..5577ae39 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py @@ -17,6 +17,10 @@ """ from __future__ import annotations +import contextlib +import copy +import io +import math import warnings from typing import Any, Literal @@ -24,6 +28,7 @@ import numpy.typing as npt import opt_einsum as oe import torch +from torch.utils.checkpoint import checkpoint as _checkpoint from quimb.tensor import TensorNetwork from ..tensor_network_decoder import TensorNetworkDecoder @@ -41,6 +46,9 @@ "float32": 1e-6, } _SUPPORTED_DTYPES: tuple[str, ...] = ("float32", "float64") +_PATH_CACHE_MAXSIZE = 16 +_path_cache: dict[tuple[str, tuple[tuple[int, ...], ...]], tuple[Any, Any]] = {} +_TORCH_CONTRACTION_MAX_INTERMEDIATE_ELEMENTS = 2_147_483_647 def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]: @@ -106,13 +114,45 @@ def _clamp_log_input(x: torch.Tensor) -> torch.Tensor: return x.clamp_min(torch.finfo(x.dtype).tiny) +def _finite_nonnegative(x: torch.Tensor) -> torch.Tensor: + """Drop non-finite values and negative roundoff from probability weights.""" + return torch.nan_to_num( + x, + nan=0.0, + posinf=torch.finfo(x.dtype).max, + neginf=0.0, + ).clamp_min(0.0) + + +def _normalize_prediction(out: torch.Tensor) -> torch.Tensor: + """Normalize raw decoder weights into finite per-shot probabilities.""" + positive_inf = out == float("inf") + has_positive_inf = positive_inf.any(dim=1, keepdim=True) + finite_weights = torch.where(torch.isfinite(out), out, + torch.zeros_like(out)).clamp_min(0.0) + weights = torch.where(has_positive_inf, torch.zeros_like(finite_weights), + finite_weights) + tiny = torch.finfo(weights.dtype).tiny + scale = weights.max(dim=1, keepdim=True).values + scaled = weights / scale.clamp_min(tiny) + denom = scaled.sum(dim=1, keepdim=True) + probs = scaled / denom.clamp_min(tiny) + inf_probs = positive_inf.to(out.dtype) / positive_inf.sum( + dim=1, keepdim=True).clamp_min(1).to(out.dtype) + uniform = torch.full_like(weights, 1.0 / weights.shape[1]) + probs = torch.where(scale > tiny, probs, uniform) + probs = torch.where(has_positive_inf, inf_probs, probs) + probs = torch.nan_to_num(probs, nan=0.0, posinf=1.0, neginf=0.0) + return probs.clamp(min=tiny, max=1.0) + + def remap_eq_to_ascii(eq: str) -> str: """Rewrite an einsum equation so every label is in ``[a-zA-Z]``. - Needed because :mod:`opt_einsum` falls back to non-ASCII unicode - labels once the total index count exceeds 52, but - :func:`torch.einsum` rejects them. Raises if a single step has more - than 52 distinct labels. + Needed because :mod:`opt_einsum` can use non-ASCII unicode labels + once the total index count exceeds 52, but :func:`torch.einsum` + rejects them. Raises if a single step has more than 52 distinct + labels. """ if eq.isascii(): return eq @@ -137,6 +177,7 @@ def remap_eq_to_ascii(eq: str) -> str: out_lhs = "".join(out_lhs_chars) if rhs is None: return out_lhs + out_rhs_chars: list[str] = [] for c in rhs: if c not in mapping: @@ -147,6 +188,223 @@ def remap_eq_to_ascii(eq: str) -> str: return f"{out_lhs}->{''.join(out_rhs_chars)}" +def _maybe_remap_eq_to_ascii(eq: str) -> str: + """Return an ASCII equation when torch can represent it directly. + + More than 52 distinct labels cannot be encoded for + :func:`torch.einsum`; in that case keep the original opt_einsum + labels and let :func:`_einsum_torch` use its pairwise fallback. + """ + try: + return remap_eq_to_ascii(eq) + except ValueError: + return eq + + +def _einsum_label_count(eq: str) -> int: + lhs = eq.split("->", 1)[0] + return len({c for c in lhs if c != ","}) + + +def _reshape(tensor: torch.Tensor, shape: list[int]) -> torch.Tensor: + return tensor.reshape(tuple(shape)) + + +def _prod(values: list[int]) -> int: + return int(math.prod(values)) if values else 1 + + +def _sum_unique_omitted_axes(tensor: torch.Tensor, labels: list[str], + other_labels: set[str], + rhs_labels: set[str]) -> torch.Tensor: + drop_axes = [ + axis for axis, label in enumerate(labels) + if label not in rhs_labels and label not in other_labels + ] + for axis in reversed(drop_axes): + tensor = tensor.sum(dim=axis) + labels.pop(axis) + return tensor + + +def _permute_to(tensor: torch.Tensor, labels: list[str], + target: list[str]) -> torch.Tensor: + if labels == target: + return tensor + if not target: + return tensor.reshape(()) + return tensor.permute([labels.index(label) for label in target]) + + +def _einsum_pairwise_torch(eq: str, operands: tuple[torch.Tensor, + ...]) -> torch.Tensor: + """Evaluate one opt_einsum pairwise step without torch label limits.""" + if "->" not in eq: + raise ValueError(f"Expected explicit einsum output in {eq!r}.") + lhs, rhs = eq.split("->", 1) + terms = lhs.split(",") + rhs_labels = list(rhs) + rhs_set = set(rhs_labels) + + if len(terms) == 1 and len(operands) == 1: + labels = list(terms[0]) + result = _sum_unique_omitted_axes(operands[0], labels, set(), rhs_set) + return _permute_to(result, labels, rhs_labels) + + if len(terms) != 2 or len(operands) != 2: + raise ValueError( + "The high-label torch fallback only supports unary or pairwise " + f"einsum steps; got equation {eq!r}.") + + a, b = operands + labels_a = list(terms[0]) + labels_b = list(terms[1]) + set_a = set(labels_a) + set_b = set(labels_b) + + a = _sum_unique_omitted_axes(a, labels_a, set_b, rhs_set) + b = _sum_unique_omitted_axes(b, labels_b, set_a, rhs_set) + + batch_labels = [ + label for label in rhs_labels if label in labels_a and label in labels_b + ] + contract_labels = [ + label for label in labels_a + if label in labels_b and label not in rhs_set + ] + a_free_labels = [ + label for label in labels_a + if label not in batch_labels and label not in contract_labels + ] + b_free_labels = [ + label for label in labels_b + if label not in batch_labels and label not in contract_labels + ] + + sizes: dict[str, int] = {} + for labels, tensor in ((labels_a, a), (labels_b, b)): + for axis, label in enumerate(labels): + size = int(tensor.shape[axis]) + if label in sizes and sizes[label] != size: + raise ValueError(f"Mismatched dimension for label {label!r}: " + f"{sizes[label]} vs {size}.") + sizes[label] = size + + a_order = batch_labels + a_free_labels + contract_labels + b_order = batch_labels + contract_labels + b_free_labels + a = _permute_to(a, labels_a, a_order) + b = _permute_to(b, labels_b, b_order) + + batch_shape = [sizes[label] for label in batch_labels] + a_shape = [sizes[label] for label in a_free_labels] + contract_shape = [sizes[label] for label in contract_labels] + b_shape = [sizes[label] for label in b_free_labels] + + batch_size = _prod(batch_shape) + a_size = _prod(a_shape) + contract_size = _prod(contract_shape) + b_size = _prod(b_shape) + + a_mat = _reshape(a, [batch_size, a_size, contract_size]) + b_mat = _reshape(b, [batch_size, contract_size, b_size]) + out = torch.bmm(a_mat, b_mat) + + current_labels = batch_labels + a_free_labels + b_free_labels + out_shape = batch_shape + a_shape + b_shape + out = out.reshape(tuple(out_shape) if out_shape else ()) + return _permute_to(out, current_labels, rhs_labels) + + +def _einsum_torch(eq: str, *operands: torch.Tensor) -> torch.Tensor: + """Torch einsum with a pairwise fallback for >52 opt_einsum labels.""" + if _einsum_label_count(eq) <= len(_ASCII_POOL): + return torch.einsum(remap_eq_to_ascii(eq), *operands) + return _einsum_pairwise_torch(eq, operands) + + +def _path_largest_intermediate(info: Any) -> float: + value = getattr(info, "largest_intermediate", None) + if value is None: + return float("inf") + try: + return float(value) + except (TypeError, ValueError, OverflowError): + return float("inf") + + +def _path_opt_cost(info: Any) -> float: + value = getattr(info, "opt_cost", None) + if value is None: + return float("inf") + try: + return float(value) + except (TypeError, ValueError, OverflowError): + return float("inf") + + +def _select_default_torch_path( + eq: str, + shapes: tuple[tuple[int, ...], ...], + *, + tn: TensorNetwork | None = None, + output_inds: tuple[str, ...] | None = None, +) -> tuple[Any, Any]: + """Choose a deterministic torch-friendly default contraction path.""" + key = (eq, shapes) + if key in _path_cache: + return _path_cache[key] + + candidates: list[tuple[str, Any, Any]] = [] + + optimizers: list[tuple[str, Any]] = [("greedy", "greedy"), ("auto", "auto")] + try: + import cotengra as ctg + except ImportError: + pass + else: + for attempt in range(3): + optimizers.append((f"cotengra-{attempt}", + ctg.HyperOptimizer(max_repeats=8, + parallel=False))) + + for tag, optimize in optimizers: + try: + if tn is not None: + if output_inds is None: + raise ValueError("output_inds is required with tn.") + info = tn.contraction_info(output_inds=output_inds, + optimize=optimize) + path = info.path + else: + path, info = oe.contract_path(eq, + *shapes, + shapes=True, + optimize=optimize) + except Exception as exc: + warnings.warn( + f"NMOptimizer default path candidate {tag!r} failed: " + f"{exc!r}", + RuntimeWarning, + stacklevel=3, + ) + continue + candidates.append((tag, path, info)) + + if not candidates: + raise RuntimeError("No NMOptimizer default contraction path " + "candidate succeeded.") + + _tag, selected_path, selected_info = min( + candidates, + key=lambda c: (_path_largest_intermediate(c[2]), _path_opt_cost(c[2])), + ) + selected = (selected_path, selected_info) + if len(_path_cache) >= _PATH_CACHE_MAXSIZE: + _path_cache.pop(next(iter(_path_cache))) + _path_cache[key] = selected + return selected + + class NMOptimizer(TensorNetworkDecoder): """Differentiable noise-model optimiser for the TN decoder. @@ -157,7 +415,7 @@ class NMOptimizer(TensorNetworkDecoder): The forward pass is materialised once at construction and reused across optimisation steps. Optionally call :meth:`optimize_path` (e.g. with ``cotengra.HyperOptimizer()``) to pin a better contraction - path; the JIT is rebuilt automatically. + path; the cached forward is rebuilt automatically. .. warning:: @@ -183,22 +441,30 @@ class NMOptimizer(TensorNetworkDecoder): check_inds, error_inds, logical_inds, logical_tags: Optional index and tag names; defaults track the parent decoder. dtype: Tensor data type (e.g. ``"float32"``). - device: ``"cuda"`` (default) or ``"cpu"``. + device: ``"cuda"`` (default) or ``"cpu"``. ``NMOptimizer`` + always uses torch-backed contractions on this device; it + does not dispatch through cuTensorNet. compile: If ``True``, wrap the forward in :func:`torch.compile`. - Most useful with ``execute="codegen"``. - execute: Forward backend. ``"codegen"`` (default) partial-evaluates - the path into a flat Python function; ``"unrolled"`` keeps an - interpretive einsum list; ``"opt_einsum"`` dispatches via - :func:`opt_einsum.contract_expression`. + execute: Forward backend. ``"opt_einsum"`` (default) dispatches + via :func:`opt_einsum.contract_expression`; ``"unrolled"`` + walks the pairwise path with torch tensor operations; + ``"codegen"`` partial-evaluates the path into a flat Python + function. All modes are torch/autograd paths and support + CPU and CUDA tensors. compile_mode: Forwarded to :func:`torch.compile`; ignored when ``compile=False``. dynamic_syndromes: If ``True`` (default), syndromes are runtime arguments to the compiled forward, so :meth:`update_dataset` - does not retrigger codegen / ``torch.compile`` (provided - shapes are unchanged). ``False`` bakes the syndromes into - the closure as constants -- fewer runtime einsums, but every - :meth:`update_dataset` call rebuilds the graph. Only affects - ``execute="codegen"``. + does not rebuild codegen when shapes are unchanged. + ``False`` bakes syndromes into the generated closure and + only affects ``execute="codegen"``. + precontract_noise: If ``True``, defer the per-error + noise-into-code contractions into differentiable torch ops + and contract the reduced tensor network. If ``"auto"`` + (default), use the full tensor network when the default path + is small enough and switch to the reduced topology for large + DEMs or when the full path's largest intermediate is too + large. Example (logit-space, no clamping needed):: @@ -227,54 +493,63 @@ def __init__( device: str = "cuda", *, compile: bool = False, - execute: Literal["codegen", "unrolled", "opt_einsum"] = "codegen", + execute: Literal["opt_einsum", "unrolled", "codegen"] = "opt_einsum", compile_mode: str | None = None, dynamic_syndromes: bool = True, + precontract_noise: bool | Literal["auto"] = "auto", ) -> None: - if execute not in ("unrolled", "opt_einsum", "codegen"): + if execute not in ("opt_einsum", "unrolled", "codegen"): raise ValueError(f"Invalid execute mode: {execute!r}") if dtype not in _SUPPORTED_DTYPES: raise ValueError(f"Invalid dtype {dtype!r}; expected one of " f"{list(_SUPPORTED_DTYPES)}.") + if precontract_noise not in (False, True, "auto"): + raise ValueError( + "precontract_noise must be one of False, True, or 'auto'; " + f"got {precontract_noise!r}.") # Sanitise once so the base TN tensors and ``self._noise_probs`` # see identical values (see :func:`_validate_and_clamp_priors`). noise_model = _validate_and_clamp_priors(noise_model, dtype) - super().__init__( - H, - logical_obs, - noise_model, - check_inds=check_inds, - error_inds=error_inds, - logical_inds=logical_inds, - logical_tags=logical_tags, - contract_noise_model=False, - dtype=dtype, - device=device, - ) - - # Force the torch backend so tensor data lives in the autograd - # graph (the base class would otherwise pick cutensornet/numpy - # on GPU). Contractions still go through codegen / cotengra - # below, not cuTensorNet. - if self.contractor_config.contractor_name == "cutensornet" \ - and self.contractor_config.backend != "torch": + requested_device = device + requested_cuda = "cuda" in requested_device + cuda_available = torch.cuda.is_available() + if requested_cuda and not cuda_available: warnings.warn( - "NMOptimizer requires the torch backend for autograd; " - f"switching contractor backend from " - f"{self.contractor_config.backend!r} to 'torch'. " - "Contractions are executed via codegen/opt_einsum, not " - "cuTensorNet.", - stacklevel=3, + "CUDA was requested for NMOptimizer, but torch CUDA is not " + "available; using CPU for differentiable tensor-network " + "contractions.", + RuntimeWarning, + stacklevel=2, ) - self._set_contractor( - "cutensornet", - self.contractor_config.device, - "torch", - dtype, + + with contextlib.redirect_stdout(io.StringIO()): + super().__init__( + H, + logical_obs, + noise_model, + check_inds=check_inds, + error_inds=error_inds, + logical_inds=logical_inds, + logical_tags=logical_tags, + contract_noise_model=False, + dtype=dtype, + # NMOptimizer is always torch/autograd-backed. Build the + # parent topology on CPU first so construction never routes + # through the base decoder's cuTensorNet default, then move + # tensors to the requested torch device below. + device="cpu", ) + target_device = (requested_device + if requested_cuda and cuda_available else "cpu") + if (self.contractor_config.contractor_name != "oe_torch_compiled" or + self.contractor_config.backend != "torch" or + self.contractor_config.device != target_device): + self._set_contractor("oe_torch_compiled", target_device, "torch", + dtype) + # Swap the base's placeholder single-syndrome TN for a batched one. self._syndrome_tags = [f"SYN_{i}" for i in range(len(self.check_inds))] self.syndrome_tn = tensor_network_from_syndrome_batch( @@ -313,6 +588,10 @@ def __init__( self._execute_mode = execute self._torch_compile_mode = compile_mode self._dynamic_syndromes = dynamic_syndromes + self._precontract_noise_auto = precontract_noise == "auto" + self._precontract_noise = precontract_noise is True + self._reduced_optimize: Any = None + self._reduced_info: Any = None self._compiled_predict: Any | None = None self._syndrome_tuple: tuple[torch.Tensor, ...] = () self._snapshot_arrays_and_eq() @@ -439,35 +718,62 @@ def _as_torch(x): ] self._syndrome_tuple = tuple(self._syndrome_arrays) # Used by :meth:`_update_data` to detect layout changes that - # invalidate the cached path / codegen / oe expr. + # invalidate the cached path / opt_einsum expression. self._syndrome_shapes: tuple[tuple[int, ...], ...] = tuple( tuple(s.shape) for s in self._syndrome_arrays) - if self._execute_mode == "opt_einsum": - shapes = tuple(t.shape for t in tensors) + shapes = tuple(t.shape for t in tensors) + if self._precontract_noise_auto: + optimize, info = _select_default_torch_path( + self._eq_batch, + shapes, + tn=self.full_tn, + output_inds=("batch_index", self.logical_obs_inds[0]), + ) + self._precontract_noise = ( + _path_largest_intermediate(info) + > _TORCH_CONTRACTION_MAX_INTERMEDIATE_ELEMENTS) + if not self._precontract_noise: + self.path_batch = optimize + + if self._precontract_noise: + self._oe_expr = None + self._build_reduced_tn_state() + elif self._execute_mode == "opt_einsum": + optimize = self.path_batch + if optimize in (None, "auto"): + optimize, _info = _select_default_torch_path( + self._eq_batch, + shapes, + tn=self.full_tn, + output_inds=("batch_index", self.logical_obs_inds[0]), + ) + self.path_batch = optimize self._oe_expr = oe.contract_expression( self._eq_batch, *shapes, - optimize=self.path_batch - if self.path_batch not in (None, "auto") else "auto", + optimize=optimize, ) self._path_steps = None else: self._oe_expr = None - # Flatten the path into ``[(eq, idxs, sorted_desc), ...]``; - # ``sorted_desc`` is precomputed for the unrolled-mode pop - # walk, and labels are remapped to ASCII because - # opt_einsum falls back to unicode past 52 indices and - # torch.einsum rejects those. - shapes = tuple(t.shape for t in tensors) + optimize = self.path_batch + if optimize in (None, "auto"): + optimize, _info = _select_default_torch_path( + self._eq_batch, + shapes, + tn=self.full_tn, + output_inds=("batch_index", self.logical_obs_inds[0]), + ) + self.path_batch = optimize _, info = oe.contract_path( self._eq_batch, *shapes, shapes=True, - optimize=self.path_batch - if self.path_batch not in (None, "auto") else "auto", + optimize=optimize, ) - self._path_steps = [(remap_eq_to_ascii(step[2]), tuple(step[0]), + self._path_steps = [(_maybe_remap_eq_to_ascii(step[2]), + tuple(step[0]), tuple(sorted(step[0], reverse=True))) for step in info.contraction_list] @@ -475,13 +781,16 @@ def _as_torch(x): self._compile_loss() def _compile_predict(self) -> None: - """Build ``self._predict_fn`` for the configured execute mode.""" - builders = { - "opt_einsum": self._build_predict_opt_einsum, - "unrolled": self._build_predict_unrolled, - "codegen": self._build_predict_codegen, - } - self._predict_fn = builders[self._execute_mode]() + """Build ``self._predict_fn``.""" + if self._precontract_noise: + self._predict_fn = self._build_predict_reduced() + else: + builders = { + "opt_einsum": self._build_predict_opt_einsum, + "unrolled": self._build_predict_unrolled, + "codegen": self._build_predict_codegen, + } + self._predict_fn = builders[self._execute_mode]() self._compiled_predict = self._maybe_torch_compile(self._predict_fn, kind="predict") @@ -507,7 +816,7 @@ def _predict(noise_probs: torch.Tensor, # Torch backend is auto-selected from the tensor type; # avoids the per-call ``backend=`` dispatch. out = oe_expr(*arrays) - return out / out.sum(dim=1, keepdim=True) + return _normalize_prediction(out) return _predict @@ -534,14 +843,14 @@ def _predict(noise_probs: torch.Tensor, picked = [ops[i] for i in idxs] for i in sorted_idxs: ops.pop(i) - ops.append(torch.einsum(eq_str, *picked)) + ops.append(_einsum_torch(eq_str, *picked)) out = ops[0] - return out / out.sum(dim=1, keepdim=True) + return _normalize_prediction(out) return _predict def _build_predict_codegen(self): - """Codegen predict: partial-eval'd flat Python with named locals.""" + """Codegen predict: partial-evaluated Python with named locals.""" static_arrays = self._static_arrays syndrome_positions = tuple(p for p, _t in self._syndrome_positions) noise_pos_ordered = self._noise_pos_ordered @@ -573,6 +882,329 @@ def _predict_static( return _predict_static + def _build_reduced_tn_state(self) -> None: + """Build the reduced TN topology and differentiable noise recipes. + + This mirrors the parent decoder's contracted-noise topology, but + keeps each per-error noise contraction as a torch operation so + gradients still flow to ``noise_probs``. + """ + from collections import defaultdict + + error_inds_set = set(self.error_inds) + + survivor_lookup: dict[tuple[tuple[str, ...], frozenset[str]], int] = {} + doomed_lookup: dict[tuple[tuple[str, ...], frozenset[str]], int] = {} + for opt_pos, tensor in enumerate(self._tensors_ref): + key = (tuple(tensor.inds), frozenset(tensor.tags)) + if any(ind in error_inds_set for ind in tensor.inds): + doomed_lookup[key] = opt_pos + else: + survivor_lookup[key] = opt_pos + + reduced_tn = self.full_tn.copy() + recipes: list[dict[str, Any]] = [] + merged_id_to_recipe_idx: dict[int, int] = {} + + for error_idx, error_ind in enumerate(self.error_inds): + doomed = [t for t in reduced_tn.tensors if error_ind in t.inds] + check_tensors = [t for t in doomed if "NOISE" not in t.tags] + check_opt_positions = [ + doomed_lookup[(tuple(t.inds), frozenset(t.tags))] + for t in check_tensors + ] + + ids_before = {id(t) for t in reduced_tn.tensors} + reduced_tn.contract_ind(error_ind) + new_tensors = [ + t for t in reduced_tn.tensors if id(t) not in ids_before + ] + assert len(new_tensors) == 1 + new_tensor = new_tensors[0] + merged_id_to_recipe_idx[id(new_tensor)] = error_idx + + quimb_out_inds = tuple(new_tensor.inds) + mapping = {error_ind: "e"} + next_code = ord("a") + for ind in quimb_out_inds: + while chr(next_code) == "e": + next_code += 1 + mapping[ind] = chr(next_code) + next_code += 1 + + noise_str = mapping[error_ind] + check_strs = [ + "".join(mapping[ind] for ind in t.inds) for t in check_tensors + ] + out_str = "".join(mapping[ind] for ind in quimb_out_inds) + ordered_check_opt_positions: list[int] = [None + ] * len( # type: ignore + check_tensors) + for tensor, opt_pos in zip(check_tensors, check_opt_positions): + non_error_ind = next( + ind for ind in tensor.inds if ind != error_ind) + ordered_check_opt_positions[quimb_out_inds.index( + non_error_ind)] = opt_pos + + recipes.append({ + "eq": ",".join([noise_str] + check_strs) + "->" + out_str, + "ordered_check_opt_positions": ordered_check_opt_positions, + "k": len(check_tensors), + }) + + reduced_eq = reduced_tn.get_equation( + output_inds=("batch_index", self.logical_obs_inds[0])) + reduced_shapes = tuple(t.shape for t in reduced_tn.tensors) + + reduced_static: dict[int, torch.Tensor] = {} + reduced_syndrome: list[tuple[int, int]] = [] + reduced_recipes: dict[int, int] = {} + syn_pos_to_idx = { + p: i for i, (p, _) in enumerate(self._syndrome_positions) + } + for pos, tensor in enumerate(reduced_tn.tensors): + if id(tensor) in merged_id_to_recipe_idx: + reduced_recipes[pos] = merged_id_to_recipe_idx[id(tensor)] + continue + + key = (tuple(tensor.inds), frozenset(tensor.tags)) + opt_pos = survivor_lookup[key] + if opt_pos in self._static_arrays: + reduced_static[pos] = self._static_arrays[opt_pos] + elif opt_pos in syn_pos_to_idx: + reduced_syndrome.append((pos, syn_pos_to_idx[opt_pos])) + else: + raise AssertionError( + f"Reduced tensor at position {pos} maps to full tensor " + f"position {opt_pos}, which is not static or syndrome.") + + user_optimize = self._reduced_optimize + if user_optimize is not None and ( + type(user_optimize).__module__.startswith("cuquantum") and + type(user_optimize).__name__ == "OptimizerOptions"): + raise ValueError( + "precontract_noise=True does not support cuTensorNet " + "OptimizerOptions; pass an opt_einsum/cotengra optimizer.") + + def _path_largest_intermediate(info: Any) -> float: + value = getattr(info, "largest_intermediate", None) + if value is None: + return float("inf") + try: + return float(value) + except (TypeError, ValueError, OverflowError): + return float("inf") + + def _path_opt_cost(info: Any) -> float: + value = getattr(info, "opt_cost", None) + if value is None: + return float("inf") + try: + return float(value) + except (TypeError, ValueError, OverflowError): + return float("inf") + + candidates: list[tuple[str, Any, Any]] = [] + + def _try_path(tag: str, optimize: Any) -> None: + try: + path, info = oe.contract_path(reduced_eq, + *reduced_shapes, + shapes=True, + optimize=optimize) + except Exception as exc: + warnings.warn( + f"Reduced TN path candidate {tag!r} failed: {exc!r}", + RuntimeWarning, + stacklevel=3, + ) + return + candidates.append((tag, path, info)) + + if user_optimize is None: + reduced_path, reduced_info = _select_default_torch_path( + reduced_eq, reduced_shapes) + selected_tag = "default" + else: + _try_path("user", user_optimize) + _try_path("auto", "auto") + _try_path("greedy", "greedy") + try: + import cotengra as ctg + except ImportError: + pass + else: + for attempt in range(3): + _try_path( + f"cotengra-{attempt}", + ctg.HyperOptimizer(max_repeats=8, parallel=False), + ) + + if not candidates: + raise RuntimeError("No reduced TN contraction path candidates " + "succeeded.") + + selected_tag, reduced_path, reduced_info = min( + candidates, + key=lambda c: + (_path_largest_intermediate(c[2]), _path_opt_cost(c[2])), + ) + if selected_tag != "user": + user_info = next( + (info for tag, _path, info in candidates if tag == "user"), + None) + if user_info is None: + message = ( + "Reduced TN path finder selected " + f"{selected_tag!r} because the user-supplied optimizer " + "did not produce a valid path.") + else: + message = ( + "Reduced TN path finder selected " + f"{selected_tag!r} instead of the user-supplied " + "optimizer because it had a smaller largest " + "intermediate " + f"({_path_largest_intermediate(reduced_info):.3e} vs " + f"{_path_largest_intermediate(user_info):.3e}).") + warnings.warn(message, UserWarning, stacklevel=3) + reduced_oe_expr = oe.contract_expression(reduced_eq, + *reduced_shapes, + optimize=reduced_path) + + recipe_to_reduced_pos = {ri: pos for pos, ri in reduced_recipes.items()} + groups_by_k: dict[int, list[int]] = defaultdict(list) + for recipe_idx, recipe in enumerate(recipes): + groups_by_k[recipe["k"]].append(recipe_idx) + + batched_groups: list[dict[str, Any]] = [] + device = self.torch_device + for k, error_indices in sorted(groups_by_k.items()): + out_letters: list[str] = [] + next_code = ord("a") + for _ in range(k): + while chr(next_code) in ("e", "n"): + next_code += 1 + out_letters.append(chr(next_code)) + next_code += 1 + + out_str = "".join(out_letters) + if k == 0: + eq = "ne->ne" + else: + check_strs = [f"n{letter}e" for letter in out_letters] + eq = "ne," + ",".join(check_strs) + "->n" + out_str + + stacked_checks = [] + for axis in range(k): + axis_arrays = [ + self._static_arrays[recipes[ri] + ["ordered_check_opt_positions"][axis]] + for ri in error_indices + ] + stacked_checks.append(torch.stack(axis_arrays, dim=0)) + + batched_groups.append({ + "k": + k, + "eq": + eq, + "error_indices_t": + torch.tensor(error_indices, dtype=torch.long, + device=device), + "stacked_checks": + stacked_checks, + "reduced_positions": [ + recipe_to_reduced_pos[ri] for ri in error_indices + ], + }) + + self._reduced_tn = reduced_tn + self._batched_einsum_groups = batched_groups + self._reduced_static_positions = reduced_static + self._reduced_syndrome_positions = reduced_syndrome + self._reduced_eq = reduced_eq + self._reduced_oe_expr = reduced_oe_expr + self._reduced_n_tensors = len(reduced_tn.tensors) + self._reduced_path_steps = [(_maybe_remap_eq_to_ascii(step[2]), + tuple(step[0]), + tuple(sorted(step[0], reverse=True))) + for step in reduced_info.contraction_list] + self._reduced_dynamic_positions = tuple( + pos for pos in range(len(reduced_tn.tensors)) + if pos not in reduced_static) + self._reduced_info = reduced_info + self._reduced_path_tag = selected_tag + self.path_batch = reduced_path + self.slicing_batch = tuple() + + def _build_predict_reduced(self): + """Predict using the reduced TN plus batched noise precontraction.""" + static_positions = self._reduced_static_positions + syndrome_positions = self._reduced_syndrome_positions + batched_groups = self._batched_einsum_groups + oe_expr = self._reduced_oe_expr + path_steps = self._reduced_path_steps + dynamic_positions = self._reduced_dynamic_positions + n = self._reduced_n_tensors + + if self._execute_mode == "codegen": + codegen_contract = self._build_codegen_contract( + n, + static_positions, + dynamic_positions, + path_steps, + ) + else: + codegen_contract = None + + def _materialize_arrays( + noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...], + ) -> list[torch.Tensor]: + noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), + dim=-1) + arrays: list[torch.Tensor] = [None] * n # type: ignore + for pos, arr in static_positions.items(): + arrays[pos] = arr + for pos, syndrome_idx in syndrome_positions: + arrays[pos] = syndrome_tuple[syndrome_idx] + for group in batched_groups: + noise_batch = noise_stacked[group["error_indices_t"]] + if group["k"] == 0: + out_batch = noise_batch + else: + out_batch = _einsum_torch(group["eq"], noise_batch, + *group["stacked_checks"]) + for i, pos in enumerate(group["reduced_positions"]): + arrays[pos] = out_batch[i] + return arrays + + def _contract_unrolled(arrays: list[torch.Tensor]) -> torch.Tensor: + ops = list(arrays) + for eq_str, idxs, sorted_idxs in path_steps: + picked = [ops[i] for i in idxs] + for i in sorted_idxs: + ops.pop(i) + ops.append(_einsum_torch(eq_str, *picked)) + return ops[0] + + def _predict(noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: + arrays = _materialize_arrays(noise_probs, syndrome_tuple) + if self._execute_mode == "opt_einsum" and torch.is_grad_enabled( + ) and noise_probs.requires_grad: + out = _checkpoint(oe_expr, *arrays, use_reentrant=False) + elif self._execute_mode == "opt_einsum": + out = oe_expr(*arrays) + elif self._execute_mode == "unrolled": + out = _contract_unrolled(arrays) + else: + dyns = tuple(arrays[pos] for pos in dynamic_positions) + out = codegen_contract(dyns) + return _normalize_prediction(out) + + return _predict + def _maybe_torch_compile(self, fn, *, kind: str): """Wrap ``fn`` with :func:`torch.compile` if requested. @@ -599,7 +1231,7 @@ def _compile_loss(self) -> None: Two variants are produced: one accepting logits (sigmoid applied inside) and one accepting probabilities directly. """ - if self._execute_mode == "codegen": + if self._execute_mode == "codegen" and not self._precontract_noise: logits_fn, probs_fn = self._build_loss_codegen() else: logits_fn, probs_fn = self._build_loss_wrapped() @@ -668,7 +1300,8 @@ def _build_loss_wrapped(self): obs_f = self.obs_idx_false predict_fn = self._predict_fn - if self._dynamic_syndromes: + if (self._execute_mode != "codegen" or self._dynamic_syndromes or + self._precontract_noise): def _loss_from_probs(noise_probs, syndromes): p = predict_fn(noise_probs, syndromes) @@ -711,23 +1344,10 @@ def _torch_compile_kwargs(self) -> dict[str, Any]: def _codegen_partial_eval(n, static_arrays, syndrome_positions, noise_pos_ordered, path_steps, syndrome_tensors, dynamic_syndromes: bool): - """Partial-evaluate ``path_steps``; return the codegen building blocks. - - Steps whose inputs are all static are evaluated eagerly under - ``torch.no_grad`` and become closure constants; the remaining - steps become source lines. - - Returns ``(runtime_lines, closure_vars, used_static, final_state, - n_folded)``: emitted source lines, name -> tensor map for the - function namespace, the subset of names actually referenced, the - single surviving state slot ``(name, is_dynamic, value_or_None)``, - and the count of folded steps. - """ + """Partial-evaluate ``path_steps`` for the codegen builders.""" static_positions = sorted(static_arrays.keys()) noise_pos_set = set(noise_pos_ordered) syn_pos_set = set(syndrome_positions) - # O(1) reverse lookup tables (avoid repeated list.index() inside - # the per-step loop below — was O(N^2) for large path lengths). noise_pos_to_k = {pos: k for k, pos in enumerate(noise_pos_ordered)} syn_pos_to_sidx = { pos: sidx for sidx, pos in enumerate(syndrome_positions) @@ -736,7 +1356,6 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, pos: sidx for sidx, pos in enumerate(static_positions) } - # state[pos] = (var_name, is_dynamic, concrete_value_or_None) state: list[tuple[str, bool, torch.Tensor | None]] = [] for pos in range(n): if pos in noise_pos_set: @@ -755,11 +1374,6 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, closure_vars: dict[str, torch.Tensor] = {} runtime_lines: list[str] = [] - # Names that the emitted source actually references and that must - # be available in the closure namespace. We track this - # structurally as we go instead of re-parsing the source string, - # which is both faster and immune to lexical false positives / - # negatives (e.g. if an einsum equation contained an underscore). used_static: set[str] = set() n_folded = 0 @@ -774,23 +1388,20 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, if not any_dynamic: arrs = [p[2] for p in picked] with torch.no_grad(): - result = torch.einsum(eq_str, *arrs).contiguous() + result = _einsum_torch(eq_str, *arrs).contiguous() static_name = f"_P{step_idx}" closure_vars[static_name] = result state.append((static_name, False, result)) n_folded += 1 else: arg_names = [p[0] for p in picked] - # Track which closure names this line will read. ``_n*`` - # names are header-built locals (from noise_probs) so - # they are *not* closure values; everything else is. for name in arg_names: if name.startswith(("_C", "_P")): used_static.add(name) elif name.startswith("_S") and not dynamic_syndromes: used_static.add(name) runtime_lines.append( - f" {out_name} = torch.einsum({eq_str!r}, " + f" {out_name} = _einsum_torch({eq_str!r}, " f"{', '.join(arg_names)})") state.append((out_name, True, None)) @@ -801,7 +1412,7 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, if name.startswith("_C"): sidx = int(name[2:]) closure_vars[name] = static_arrays[static_positions[sidx]] - elif name.startswith("_S"): # static-syndromes mode only + elif name.startswith("_S"): sidx = int(name[2:]) closure_vars[name] = syndrome_tensors[sidx] @@ -810,21 +1421,13 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, @staticmethod def _emit_noise_header(noise_pos_ordered, transform: str = "identity") -> list[str]: - """Emit source lines materialising ``_n0 .. _n{K-1}``. - - ``transform="identity"`` treats the input as probabilities; - ``"sigmoid"`` treats it as logits and applies ``torch.sigmoid`` - first. A single ``(K, 2)`` stack is built and then sliced, which - keeps the autograd graph compact. - """ + """Emit source lines materialising ``_n0 .. _n{K-1}``.""" lines: list[str] = [] if transform == "sigmoid": lines.append(" _p = torch.sigmoid(noise_probs)") else: lines.append(" _p = noise_probs") lines.append(" _q = 1.0 - _p") - # One stack of shape (K, 2) instead of K stacks of shape (2,). - # ``dim=1`` makes ``_NS[k]`` a contiguous view of length 2. lines.append(" _NS = torch.stack((_q, _p), dim=1)") for k in range(len(noise_pos_ordered)): lines.append(f" _n{k} = _NS[{k}]") @@ -833,8 +1436,7 @@ def _emit_noise_header(noise_pos_ordered, @staticmethod def _emit_syndrome_header(syndrome_positions, dynamic_syndromes: bool) -> list[str]: - """Emit source lines binding ``_S0 .. _S{S-1}`` to runtime - ``syndromes`` arguments; empty in static mode.""" + """Emit source lines binding runtime syndrome arguments.""" if not dynamic_syndromes: return [] return [ @@ -842,6 +1444,81 @@ def _emit_syndrome_header(syndrome_positions, for sidx in range(len(syndrome_positions)) ] + @classmethod + def _build_codegen_contract(cls, n, static_arrays, dynamic_positions, + path_steps): + """Generate ``_contract(dyns)`` for a static/dynamic operand list.""" + static_positions = sorted(static_arrays.keys()) + dynamic_positions = tuple(dynamic_positions) + dynamic_pos_to_idx = { + pos: idx for idx, pos in enumerate(dynamic_positions) + } + static_pos_to_idx = { + pos: idx for idx, pos in enumerate(static_positions) + } + + state: list[tuple[str, bool, torch.Tensor | None]] = [] + for pos in range(n): + if pos in dynamic_pos_to_idx: + state.append((f"_D{dynamic_pos_to_idx[pos]}", True, None)) + else: + sidx = static_pos_to_idx[pos] + state.append( + (f"_C{sidx}", False, static_arrays[static_positions[sidx]])) + + closure_vars: dict[str, torch.Tensor] = {} + runtime_lines: list[str] = [] + used_static: set[str] = set() + n_folded = 0 + + for step_idx, step in enumerate(path_steps): + eq_str = step[0] + idxs = step[1] + picked = [state[i] for i in idxs] + for i in sorted(idxs, reverse=True): + state.pop(i) + any_dynamic = any(p[1] for p in picked) + out_name = f"_r{step_idx}" + if not any_dynamic: + arrs = [p[2] for p in picked] + with torch.no_grad(): + result = _einsum_torch(eq_str, *arrs).contiguous() + static_name = f"_P{step_idx}" + closure_vars[static_name] = result + state.append((static_name, False, result)) + n_folded += 1 + else: + arg_names = [p[0] for p in picked] + for name in arg_names: + if name.startswith(("_C", "_P")): + used_static.add(name) + runtime_lines.append( + f" {out_name} = _einsum_torch({eq_str!r}, " + f"{', '.join(arg_names)})") + state.append((out_name, True, None)) + + assert len(state) == 1 + for name in used_static: + if name in closure_vars: + continue + if name.startswith("_C"): + sidx = int(name[2:]) + closure_vars[name] = static_arrays[static_positions[sidx]] + + body = ["def _contract(dyns):"] + body.extend(f" _D{idx} = dyns[{idx}]" + for idx in range(len(dynamic_positions))) + body.extend(runtime_lines) + final_name, is_final_dyn, final_value = state[0] + if is_final_dyn: + body.append(f" return {final_name}") + else: + closure_vars["_FINAL"] = final_value + body.append(" return _FINAL") + + return cls._compile_codegen_source(body, closure_vars, n_folded, + len(runtime_lines), "contract") + @classmethod def _build_codegen_predict(cls, n, @@ -851,13 +1528,7 @@ def _build_codegen_predict(cls, path_steps, syndrome_tensors, dynamic_syndromes: bool = True): - """Generate ``_predict(noise_probs[, syndromes]) -> (shots, 2)``. - - With ``dynamic_syndromes=False`` syndromes are folded into the - closure, which maximises partial evaluation but forces a rebuild - on every :meth:`update_dataset` call. With ``True`` (default) - syndromes stay runtime arguments and dataset swaps are free. - """ + """Generate ``_predict(noise_probs[, syndromes]) -> (shots, 2)``.""" runtime_lines, closure_vars, _used, final_state, n_folded = ( cls._codegen_partial_eval( n, @@ -878,11 +1549,8 @@ def _build_codegen_predict(cls, body.append("def _predict(noise_probs):") if fully_static: - # Degenerate case: contraction didn't depend on noise (or any - # other dynamic input). Pre-normalise the constant and - # return it unchanged on every call. with torch.no_grad(): - normed = final_value / final_value.sum(dim=1, keepdim=True) + normed = _normalize_prediction(final_value) closure_vars["_FINAL"] = normed body.append(" return _FINAL") runtime_lines = [] @@ -894,7 +1562,7 @@ def _build_codegen_predict(cls, dynamic_syndromes)) body.extend(runtime_lines) body.append(f" _out = {final_name}") - body.append(" return _out / _out.sum(dim=1, keepdim=True)") + body.append(" return _normalize_prediction(_out)") return cls._compile_codegen_source(body, closure_vars, n_folded, len(runtime_lines), "predict") @@ -911,17 +1579,7 @@ def _build_codegen_loss(cls, obs_idx_false: torch.Tensor, dynamic_syndromes: bool = True, from_logits: bool = True): - """Generate a fused ``(input, syndromes) -> scalar`` loss callable. - - Pipes the contraction output straight into the cross-entropy - reduction so the whole pipeline (optional sigmoid, contraction, - normalisation, cross-entropy) is a single autograd graph. - - Args: - from_logits: If ``True`` (default), apply ``torch.sigmoid`` to - the input; if ``False``, the input must already be in - ``[0, 1]``. - """ + """Generate a fused ``(input, syndromes) -> scalar`` loss callable.""" runtime_lines, closure_vars, _used, final_state, n_folded = ( cls._codegen_partial_eval( n, @@ -945,15 +1603,8 @@ def _build_codegen_loss(cls, body.append("def _loss(noise_probs):") if fully_static: - # The contraction is a constant; the loss it produces is - # also a constant. We still need the gradient wrt - # noise_probs to be zero (a real-valued zero tensor with a - # graph edge), so emit a no-op multiplication by - # noise_probs.sum() * 0 to keep autograd happy. with torch.no_grad(): - normed = final_value / final_value.sum(dim=1, keepdim=True) - # Compute the loss eagerly; we can't fold it because - # autograd needs a path back to noise_probs. + normed = _normalize_prediction(final_value) ce = ( -torch.log(_clamp_log_input(normed[obs_idx_true, 1])).sum() - @@ -968,24 +1619,11 @@ def _build_codegen_loss(cls, cls._emit_syndrome_header(syndrome_positions, dynamic_syndromes)) body.extend(runtime_lines) - # Fused cross-entropy that skips the explicit - # ``_preds = _out / _out.sum(dim=1, keepdim=True)`` step. - # OBS_T and OBS_F partition the batch (every row is in exactly - # one), so: - # -log(p_T[:,1]).sum() - log(p_F[:,0]).sum() - # = log(Z).sum() - log(_out[OBS_T,1]).sum() - # - log(_out[OBS_F,0]).sum() - # where Z = _out[:,0] + _out[:,1]. Saves one (shots, 2) - # division + materialisation and the corresponding backward - # nodes -- ~2-5% per-step on CPU at d=3/r=3. body.append(f" _out = {final_name}") - body.append(" _z0 = _out[:, 0]") - body.append(" _z1 = _out[:, 1]") - body.append(" _eps = torch.finfo(_z0.dtype).tiny") + body.append(" _p = _normalize_prediction(_out)") body.append( - " return (torch.log((_z0 + _z1).clamp_min(_eps)).sum() " - "- torch.log(_z1[_OBS_T].clamp_min(_eps)).sum() " - "- torch.log(_z0[_OBS_F].clamp_min(_eps)).sum())") + " return (-torch.log(_clamp_log_input(_p[_OBS_T, 1])).sum() " + "- torch.log(_clamp_log_input(_p[_OBS_F, 0])).sum())") return cls._compile_codegen_source(body, closure_vars, n_folded, len(runtime_lines), "loss") @@ -996,9 +1634,19 @@ def _compile_codegen_source(body: list[str], n_folded: int, n_runtime: int, kind: str): """Compile the assembled function source and return the callable.""" source = "\n".join(body) - ns: dict[str, Any] = {"torch": torch} + ns: dict[str, Any] = { + "torch": torch, + "_einsum_torch": _einsum_torch, + "_clamp_log_input": _clamp_log_input, + "_finite_nonnegative": _finite_nonnegative, + "_normalize_prediction": _normalize_prediction, + } ns.update(closure_vars) - fn_name = "_loss" if kind == "loss" else "_predict" + fn_name = { + "contract": "_contract", + "loss": "_loss", + "predict": "_predict", + }[kind] exec(compile(source, f"", "exec"), ns) fn = ns[fn_name] fn._n_folded = n_folded # type: ignore[attr-defined] @@ -1023,12 +1671,12 @@ def cross_entropy_loss(self) -> torch.Tensor: def current_syndrome_args(self) -> tuple[torch.Tensor, ...]: """Return the syndrome argument expected by :meth:`loss_fn`. - Returns ``()`` when syndromes are baked into the closure - (``execute="codegen"`` and ``dynamic_syndromes=False``), else - the current live tuple. Re-fetch each step so an intervening - :meth:`update_dataset` is reflected. + Returns ``()`` when syndromes are baked into a static codegen + closure, else the current live tuple. Re-fetch each step so an + intervening :meth:`update_dataset` is reflected. """ - if self._execute_mode == "codegen" and not self._dynamic_syndromes: + if (self._execute_mode == "codegen" and not self._dynamic_syndromes and + not self._precontract_noise): return () return self._syndrome_tuple @@ -1097,12 +1745,10 @@ def _update_data(self, new_shapes.append(tuple(arr.shape)) new_shapes_tuple = tuple(new_shapes) - # Shape change: cached path / codegen / oe expression / compile - # guards are all stale. Drop the path and rebuild from scratch. - # Shapes unchanged: dynamic codegen and the unrolled / - # opt_einsum paths read syndromes per call — refreshing the - # cached tuple is enough. Static codegen baked the old tensors - # into the closure and still needs a full rebuild. + # Shape change: cached path / opt_einsum expression / compile + # guards are stale. Drop the path and rebuild from scratch. + # Shapes unchanged: the forward reads syndromes per call, so + # refreshing the cached tuple is enough. shape_changed = new_shapes_tuple != self._syndrome_shapes if shape_changed: self.path_batch = None @@ -1114,7 +1760,8 @@ def _update_data(self, return self._syndrome_tuple = tuple(self._syndrome_arrays) - if self._execute_mode == "codegen" and not self._dynamic_syndromes: + if (self._execute_mode == "codegen" and not self._dynamic_syndromes and + not self._precontract_noise): try: self._snapshot_arrays_and_eq() finally: @@ -1136,7 +1783,7 @@ def update_dataset(self, new_observable_flips: Shape ``(shots,)``. enforce_shape: Assert that per-tensor shapes match. A changing batch size triggers a full rebuild of the - cached contraction path and codegen. + cached contraction path and forward. """ syndrome_arrays = prepare_syndrome_data_batch(new_syndrome_data) torch_dtype = getattr(torch, self._dtype) @@ -1149,11 +1796,11 @@ def update_dataset(self, self._update_data(syndrome_arrays, new_observable_flips, enforce_shape) def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: - """Cache a contraction path via quimb and rebuild the JIT. + """Cache a contraction path via quimb and rebuild the forward. Always routes through :meth:`TensorNetwork.contraction_info` so - the resulting path is compatible with :mod:`opt_einsum` and - manual unrolling -- unlike :meth:`TensorNetworkDecoder.optimize_path`, + the resulting path is compatible with :mod:`opt_einsum`, unlike + :meth:`TensorNetworkDecoder.optimize_path`, which defaults to a cuTensorNet-only path. ``batch_size`` is part of the parent ``TensorNetworkDecoder`` @@ -1163,9 +1810,28 @@ def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: ignored. Kept for Liskov substitution with the parent. """ del batch_size + if self._precontract_noise: + self._reduced_optimize = optimize + self._snapshot_arrays_and_eq() + return self._reduced_info + + if optimize is None or optimize == "auto": + shapes = tuple(t.shape for t in self.full_tn.tensors) + path, info = _select_default_torch_path( + self.full_tn.get_equation( + output_inds=("batch_index", self.logical_obs_inds[0])), + shapes, + tn=self.full_tn, + output_inds=("batch_index", self.logical_obs_inds[0]), + ) + self.path_batch = path + self.slicing_batch = tuple() + self._snapshot_arrays_and_eq() + return info + info = self.full_tn.contraction_info( output_inds=("batch_index", self.logical_obs_inds[0]), - optimize=optimize if optimize is not None else "auto", + optimize=optimize, ) self.path_batch = info.path self.slicing_batch = tuple() @@ -1173,33 +1839,112 @@ def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: return info -def make_compiled_step(optimizer: NMOptimizer, logits: torch.Tensor, - torch_optimizer: torch.optim.Optimizer): - """Build a no-arg callable that runs one Adam step and returns the loss. +def make_compiled_step(optimizer: NMOptimizer, + logits: torch.Tensor, + torch_optimizer: torch.optim.Optimizer, + *, + max_backtracks: int = 0, + backtrack_factor: float = 0.5, + loss_tolerance: float = 0.0): + """Build a no-arg logit-space optimizer step. - The step zeros grads, calls the optimizer's compiled - ``loss_fn(from_logits=True)`` (sigmoid + contraction + cross-entropy - fused), backwards, and steps ``torch_optimizer``. Use this when - training in logit space. + By default this is a plain optimizer step. With ``max_backtracks > 0``, + rejected steps are retried from the same state with reduced learning rates. Args: - optimizer: The :class:`NMOptimizer` providing the fused - inner loss; pass ``compile=True`` at the - :class:`NMOptimizer` constructor for the - ``torch.compile``-d variant. + optimizer: The :class:`NMOptimizer` providing the loss. logits: Trainable 1-D tensor of length ``len(optimizer.error_inds)`` with ``requires_grad=True``. torch_optimizer: A ``torch.optim`` instance owning ``logits``. + max_backtracks: Number of reduced-LR retries after the initial + optimizer step. ``0`` preserves ordinary optimizer behavior. + backtrack_factor: Multiplicative LR reduction used for each retry. + loss_tolerance: Absolute tolerated post-step loss increase. """ + if max_backtracks < 0: + raise ValueError("max_backtracks must be non-negative.") + if max_backtracks > 0 and not 0.0 < backtrack_factor < 1.0: + raise ValueError("backtrack_factor must be in (0, 1).") - # Re-fetch per call so a rebuild from update_dataset or the - # observable_flips setter is picked up; capturing would go stale. - def _step(): - torch_optimizer.zero_grad(set_to_none=True) - loss = optimizer.loss_fn(from_logits=True)( + def _loss(): + return optimizer.loss_fn(from_logits=True)( logits, optimizer.current_syndrome_args()) + + def _is_finite_tensor(value: torch.Tensor) -> bool: + return bool(torch.isfinite(value).all().detach().cpu()) + + def _logits_are_finite() -> bool: + return _is_finite_tensor(logits) + + def _grads_are_finite() -> bool: + if logits.grad is None: + return True + return _is_finite_tensor(logits.grad) + + # Resolve the loss each call so dataset updates are reflected. + def _plain_step(): + torch_optimizer.zero_grad(set_to_none=True) + loss = _loss() + if not _is_finite_tensor(loss): + return loss loss.backward() + if not _grads_are_finite(): + raise RuntimeError("Non-finite NMOptimizer logit gradients.") torch_optimizer.step() + if not _logits_are_finite(): + raise RuntimeError("Non-finite NMOptimizer logits after step.") return loss - return _step + if max_backtracks == 0: + return _plain_step + + def _set_group_lrs(lrs: list[float]) -> None: + for group, lr in zip(torch_optimizer.param_groups, lrs): + group["lr"] = lr + + def _restore_state(saved_logits: torch.Tensor, + saved_state: dict[str, Any]) -> None: + with torch.no_grad(): + logits.copy_(saved_logits) + torch_optimizer.load_state_dict(copy.deepcopy(saved_state)) + + def _guarded_step(): + base_lrs = [ + float(group["lr"]) for group in torch_optimizer.param_groups + ] + saved_logits = logits.detach().clone() + saved_state = copy.deepcopy(torch_optimizer.state_dict()) + best_lrs = base_lrs + current_loss: torch.Tensor | None = None + + for attempt in range(max_backtracks + 1): + trial_lrs = [lr * (backtrack_factor**attempt) for lr in base_lrs] + best_lrs = trial_lrs + _restore_state(saved_logits, saved_state) + _set_group_lrs(trial_lrs) + + torch_optimizer.zero_grad(set_to_none=True) + loss = _loss() + if current_loss is None: + current_loss = loss.detach() + if not bool(torch.isfinite(loss).detach().cpu()): + return loss + + loss.backward() + if not _grads_are_finite(): + continue + torch_optimizer.step() + if not _logits_are_finite(): + continue + + with torch.no_grad(): + next_loss = _loss().detach() + if (bool(torch.isfinite(next_loss).detach().cpu()) and bool( + (next_loss <= current_loss + loss_tolerance).cpu())): + return current_loss + + _restore_state(saved_logits, saved_state) + _set_group_lrs(best_lrs) + return current_loss if current_loss is not None else _loss() + + return _guarded_step diff --git a/libs/qec/python/tests/test_nm_optimizer.py b/libs/qec/python/tests/test_nm_optimizer.py index 97dceaba..1c2856f6 100644 --- a/libs/qec/python/tests/test_nm_optimizer.py +++ b/libs/qec/python/tests/test_nm_optimizer.py @@ -22,8 +22,11 @@ "torch", reason="torch not installed; skipping TN noise-learning tests") if sys.version_info >= (3, 11): + import cudaq_qec.plugins.decoders.tensor_network_utils.nm_optimizer as nmopt from cudaq_qec.plugins.decoders.tensor_network_utils.nm_optimizer import ( NMOptimizer, + _einsum_torch, + _normalize_prediction, make_compiled_step, remap_eq_to_ascii, ) @@ -44,7 +47,7 @@ def _device_params(): return out -_EXECUTE_MODES = ("codegen", "unrolled", "opt_einsum") +_EXECUTE_MODES = ("opt_einsum", "unrolled", "codegen") # -- fixtures / helpers ------------------------------------------------------- @@ -101,8 +104,8 @@ def _make_opt(H, logical, priors, syn, flips, **kwargs): def _naive_cross_entropy(opt: "NMOptimizer") -> torch.Tensor: """Reference cross-entropy: predict, then ``-log p`` per shot. - Mirrors the pre-fusion implementation; used to verify the codegen - loss in :func:`test_fused_loss_matches_naive`. + Used to verify :meth:`NMOptimizer.cross_entropy_loss` across + execution modes. """ preds = opt.decoder_prediction() obs_t = opt.obs_idx_true @@ -115,7 +118,7 @@ def _naive_cross_entropy(opt: "NMOptimizer") -> torch.Tensor: @pytest.mark.parametrize("device", _device_params()) -def test_construction_basic(device): +def test_construction_basic(device, capsys): H, logical, priors = _simple_repetition_code() syn, flips = _sample_synthetic_dataset(H, logical, @@ -123,6 +126,12 @@ def test_construction_basic(device): num_shots=8, rng=np.random.default_rng(0)) opt = _make_opt(H, logical, priors, syn, flips, device=device) + captured = capsys.readouterr() + assert "cutensornet" not in captured.out + assert "CUDA is not available" not in captured.out + assert opt.contractor_config.contractor_name == "oe_torch_compiled" + assert opt.contractor_config.backend == "torch" + assert opt.contractor_config.device == device assert opt._batch_size == 8 assert opt._noise_probs.requires_grad assert len(opt.noise_params) == 1 @@ -151,6 +160,24 @@ def test_invalid_execute_mode_rejected(device): execute="bogus") +@pytest.mark.parametrize("device", _device_params()) +def test_invalid_precontract_noise_rejected(device): + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=4, + rng=np.random.default_rng(24)) + with pytest.raises(ValueError, match="precontract_noise must be"): + _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + precontract_noise="bad") + + @pytest.mark.parametrize("device", _device_params()) def test_invalid_dtype_rejected(device): """Unsupported dtypes must be rejected at the constructor boundary, @@ -197,8 +224,7 @@ def test_gradient_flows(device, execute): Uses :func:`_nondegenerate_code` plus mismatched init priors so the loss surface has a non-trivial gradient. Parametrized over every - ``execute`` mode so unrolled and opt_einsum paths can't silently - regress on the autograd graph. + execution mode so benchmark paths cannot silently regress. """ rng = np.random.default_rng(3) H, logical = _nondegenerate_code() @@ -230,12 +256,7 @@ def test_gradient_flows(device, execute): @pytest.mark.parametrize("device", _device_params()) @pytest.mark.parametrize("execute", _EXECUTE_MODES) def test_fused_loss_matches_naive(device, execute): - """``cross_entropy_loss`` == predict + manual CE in every execute mode. - - Codegen mode fuses the CE reduction into the contraction graph; - unrolled/opt_einsum wrap ``predict_fn``. All three must agree with - the naive reference up to fp64 round-off. - """ + """``cross_entropy_loss`` == predict + manual CE in every execute mode.""" rng = np.random.default_rng(11) H, logical = _nondegenerate_code() init_priors = [0.2, 0.3, 0.4] @@ -260,7 +281,7 @@ def test_fused_loss_matches_naive(device, execute): @pytest.mark.parametrize("device", _device_params()) def test_fused_loss_matches_naive_static_codegen(device): - """Static codegen (``dynamic_syndromes=False``) numerical correctness.""" + """Static codegen numerical correctness.""" rng = np.random.default_rng(13) H, logical = _nondegenerate_code() init_priors = [0.2, 0.3, 0.4] @@ -298,9 +319,6 @@ def test_fused_loss_matches_naive_static_codegen(device): def test_loss_fn_from_logits_and_probs(device, execute): """``loss_fn(from_logits=True)`` matches ``loss_fn(from_logits=False) o sigmoid``, and both agree with ``cross_entropy_loss`` on the optimiser's own probs. - - Parametrized over execute modes so the logit-vs-probs equivalence is - validated on every supported backend. """ rng = np.random.default_rng(12) H, logical = _nondegenerate_code() @@ -329,6 +347,73 @@ def test_loss_fn_from_logits_and_probs(device, execute): assert torch.allclose(v_probs, v_self, atol=1e-8, rtol=1e-8) +@pytest.mark.parametrize("device", _device_params()) +@pytest.mark.parametrize("execute", _EXECUTE_MODES) +def test_precontract_noise_matches_full(device, execute): + rng = np.random.default_rng(25) + H, logical = _nondegenerate_code() + priors = [0.2, 0.3, 0.4] + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.15, 0.25], + num_shots=24, + rng=rng) + full = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + dtype="float64") + reduced = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + dtype="float64", + execute=execute, + precontract_noise=True) + with torch.no_grad(): + p_full = full.decoder_prediction() + p_reduced = reduced.decoder_prediction() + loss_full = full.cross_entropy_loss() + loss_reduced = reduced.cross_entropy_loss() + assert torch.allclose(p_full, p_reduced, atol=1e-7, rtol=1e-7) + assert torch.allclose(loss_full, loss_reduced, atol=1e-7, rtol=1e-7) + + reduced.noise_params[0].grad = None + loss = reduced.cross_entropy_loss() + loss.backward() + grad = reduced.noise_params[0].grad + assert grad is not None + assert torch.isfinite(grad).all() + assert torch.any(grad != 0.0) + + +@pytest.mark.parametrize("device", _device_params()) +def test_precontract_noise_auto_uses_reduced_for_large_intermediate( + device, monkeypatch): + monkeypatch.setattr(nmopt, "_TORCH_CONTRACTION_MAX_INTERMEDIATE_ELEMENTS", + -1) + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=4, + rng=np.random.default_rng(26)) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + precontract_noise="auto") + assert opt._precontract_noise is True + with torch.no_grad(): + pred = opt.decoder_prediction() + assert pred.shape == (4, 2) + + # -- numerical guards -------------------------------------------------------- @@ -457,8 +542,8 @@ def test_non_1d_noise_model_rejected(): @pytest.mark.parametrize("device", _device_params()) -def test_current_syndrome_args_dynamic_returns_live_tuple(device): - """Dynamic mode: returns the live syndrome tuple.""" +def test_current_syndrome_args_returns_live_tuple(device): + """Returns the live syndrome tuple used by ``loss_fn``.""" rng = np.random.default_rng(101) H, logical, priors = _simple_repetition_code() syn, flips = _sample_synthetic_dataset(H, @@ -466,14 +551,7 @@ def test_current_syndrome_args_dynamic_returns_live_tuple(device): priors, num_shots=12, rng=rng) - opt = _make_opt(H, - logical, - priors, - syn, - flips, - device=device, - execute="codegen", - dynamic_syndromes=True) + opt = _make_opt(H, logical, priors, syn, flips, device=device) args = opt.current_syndrome_args() assert args is opt._syndrome_tuple assert len(args) > 0 @@ -482,8 +560,8 @@ def test_current_syndrome_args_dynamic_returns_live_tuple(device): @pytest.mark.parametrize("device", _device_params()) -def test_current_syndrome_args_static_returns_empty(device): - """Static codegen mode: returns ``()`` (syndromes are closure-baked).""" +def test_current_syndrome_args_static_codegen_returns_empty(device): + """Static codegen returns ``()`` because syndromes are closure-baked.""" rng = np.random.default_rng(102) H, logical, priors = _simple_repetition_code() syn, flips = _sample_synthetic_dataset(H, @@ -508,8 +586,8 @@ def test_current_syndrome_args_static_returns_empty(device): @pytest.mark.parametrize("device", _device_params()) -def test_update_dataset_dynamic_keeps_predict_fn(device): - """Dynamic mode: predict function identity unchanged across swaps.""" +def test_update_dataset_same_shape_keeps_predict_fn(device): + """Same-shape swaps refresh data without rebuilding the predict function.""" rng = np.random.default_rng(5) H, logical, priors = _simple_repetition_code() syn1, flips1 = _sample_synthetic_dataset(H, @@ -517,13 +595,7 @@ def test_update_dataset_dynamic_keeps_predict_fn(device): priors, num_shots=10, rng=rng) - opt = _make_opt(H, - logical, - priors, - syn1, - flips1, - device=device, - dynamic_syndromes=True) + opt = _make_opt(H, logical, priors, syn1, flips1, device=device) fn_before = opt._predict_fn syn2, flips2 = _sample_synthetic_dataset(H, logical, @@ -535,8 +607,8 @@ def test_update_dataset_dynamic_keeps_predict_fn(device): @pytest.mark.parametrize("device", _device_params()) -def test_update_dataset_static_rebuilds_predict_fn(device): - """Static mode: predict function is re-codegened on swap.""" +def test_update_dataset_static_codegen_rebuilds_predict_fn(device): + """Static codegen rebuilds when same-shape syndrome values change.""" rng = np.random.default_rng(6) H, logical, priors = _simple_repetition_code() syn1, flips1 = _sample_synthetic_dataset(H, @@ -550,6 +622,7 @@ def test_update_dataset_static_rebuilds_predict_fn(device): syn1, flips1, device=device, + execute="codegen", dynamic_syndromes=False) fn_before = opt._predict_fn syn2, flips2 = _sample_synthetic_dataset(H, @@ -582,6 +655,7 @@ def test_update_dataset_shape_change_rebuilds_and_decodes( flips1, device=device, dtype="float64", + execute="codegen", dynamic_syndromes=dynamic_syndromes) syn2, flips2 = _sample_synthetic_dataset(H, logical, @@ -602,6 +676,7 @@ def test_update_dataset_shape_change_rebuilds_and_decodes( flips2, device=device, dtype="float64", + execute="codegen", dynamic_syndromes=dynamic_syndromes) with torch.no_grad(): ref_loss = ref.cross_entropy_loss() @@ -704,13 +779,12 @@ def test_optimize_path_with_cotengra(device): def test_remap_eq_to_ascii_simple(): eq = "ab,bc->ac" out = remap_eq_to_ascii(eq) - # ASCII input is returned unchanged via the ``isascii()`` fast path. assert out == "ab,bc->ac" def test_remap_eq_to_ascii_unicode_labels(): """Synthetic equation with non-ASCII labels is remapped to a-zA-Z.""" - eq = "\u0391\u0392,\u0392\u0393->\u0391\u0393" # greek letters + eq = "\u0391\u0392,\u0392\u0393->\u0391\u0393" out = remap_eq_to_ascii(eq) assert "\u0391" not in out and "\u0392" not in out and "\u0393" not in out assert "->" in out @@ -720,13 +794,50 @@ def test_remap_eq_to_ascii_unicode_labels(): def test_remap_eq_to_ascii_too_many_labels(): - """Equations with > 52 distinct labels raise.""" - chars = [chr(0x4E00 + i) for i in range(53)] # 53 distinct CJK chars + """Equations with more than 52 distinct labels raise.""" + chars = [chr(0x4E00 + i) for i in range(53)] eq = "".join(chars) + "->" + chars[0] with pytest.raises(ValueError, match="more than 52"): remap_eq_to_ascii(eq) +def test_einsum_torch_handles_more_than_52_pairwise_labels(): + """High-rank pairwise steps bypass torch.einsum's label limit.""" + chars = [chr(0x4E00 + i) for i in range(54)] + lhs_a = "".join(chars[:53]) + lhs_b = chars[52] + chars[53] + rhs = "".join(chars[:52]) + chars[53] + eq = f"{lhs_a},{lhs_b}->{rhs}" + a = torch.ones((1,) * 53, dtype=torch.float64) + b = torch.full((1, 1), 2.0, dtype=torch.float64) + out = _einsum_torch(eq, a, b) + assert out.shape == (1,) * 53 + assert torch.allclose(out, torch.full_like(out, 2.0)) + + +def test_normalize_prediction_handles_bad_weights(): + """Raw weights with roundoff pathologies still produce finite probs.""" + weights = torch.tensor( + [ + [float("nan"), 1.0], + [-1.0, -2.0], + [float("inf"), 1.0], + [float("nan"), float("nan")], + [float("inf"), float("inf")], + [0.25, 0.75], + ], + dtype=torch.float32, + ) + probs = _normalize_prediction(weights) + assert torch.isfinite(probs).all() + assert (probs > 0).all() + assert (probs <= 1).all() + assert torch.allclose(probs[1], torch.tensor([0.5, 0.5])) + assert torch.allclose(probs[3], torch.tensor([0.5, 0.5])) + assert torch.allclose(probs[4], torch.tensor([0.5, 0.5])) + assert torch.allclose(probs[5], torch.tensor([0.25, 0.75])) + + # -- logical_error_rate ------------------------------------------------------ @@ -876,8 +987,7 @@ def test_recovers_true_priors_within_tol(device): syn, flips, device=device, - dtype="float64", - execute="codegen") + dtype="float64") init_p = opt.noise_params[0].detach() logits = torch.logit(init_p).clone().requires_grad_(True) torch_opt = torch.optim.Adam([logits], lr=0.05) @@ -890,5 +1000,49 @@ def test_recovers_true_priors_within_tol(device): atol=0.02) +class _QuadraticOptimizer: + + def loss_fn(self, from_logits=True): + return lambda logits, syndrome_args: (logits**2).sum() + + def current_syndrome_args(self): + return () + + +class _NanStepOptimizer(torch.optim.Optimizer): + + def __init__(self, params): + super().__init__(params, defaults={}) + + def step(self, closure=None): + for group in self.param_groups: + for param in group["params"]: + with torch.no_grad(): + param.fill_(float("nan")) + + +def test_make_compiled_step_backtracks_large_loss_increase(): + logits = torch.tensor([1.0], requires_grad=True) + torch_opt = torch.optim.SGD([logits], lr=10.0) + step_fn = make_compiled_step(_QuadraticOptimizer(), + logits, + torch_opt, + max_backtracks=5) + before = float((logits**2).sum().detach()) + step_loss = step_fn() + after = float((logits**2).sum().detach()) + assert torch.isfinite(step_loss) + assert after < before + assert torch_opt.param_groups[0]["lr"] < 10.0 + + +def test_make_compiled_step_rejects_nonfinite_logits_after_step(): + logits = torch.tensor([1.0], requires_grad=True) + step_fn = make_compiled_step(_QuadraticOptimizer(), logits, + _NanStepOptimizer([logits])) + with pytest.raises(RuntimeError, match="Non-finite.*logits"): + step_fn() + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/libs/qec/python/tests/test_tensor_network_decoder.py b/libs/qec/python/tests/test_tensor_network_decoder.py index 4fae06c4..d3d9dd04 100644 --- a/libs/qec/python/tests/test_tensor_network_decoder.py +++ b/libs/qec/python/tests/test_tensor_network_decoder.py @@ -23,7 +23,8 @@ prepare_syndrome_data_batch, tensor_network_from_syndrome_batch, tensor_network_from_logical_observable) from cudaq_qec.plugins.decoders.tensor_network_utils.contractors import ( - optimize_path, cutn_contractor, ContractorConfig, contractor) + optimize_path, cutn_contractor, ContractorConfig, contractor, + oe_torch_contractor, oe_torch_compiled_contractor) from cudaq_qec.plugins.decoders.tensor_network_utils.noise_models import factorized_noise_model, error_pairs_noise_model pytestmark = pytest.mark.skipif(sys.version_info < (3, 11), @@ -528,40 +529,28 @@ def test_error_pairs_noise_model_default_tags(): assert "NOISE" in t.tags -def test_valid_numpy_cpu(): - cfg = ContractorConfig("numpy", "numpy", "cpu") - assert cfg.contractor_name == "numpy" - assert cfg.backend == "numpy" - assert cfg.device == "cpu" +@pytest.mark.parametrize( + "contractor_name, backend, device, expected_contractor", + [ + ("numpy", "numpy", "cpu", contractor), + ("torch", "torch", "cpu", contractor), + ("torch", "torch", "cuda:0", contractor), + ("oe_torch", "torch", "cpu", oe_torch_contractor), + ("oe_torch", "torch", "cuda:0", oe_torch_contractor), + ("oe_torch_compiled", "torch", "cpu", oe_torch_compiled_contractor), + ("oe_torch_compiled", "torch", "cuda:0", oe_torch_compiled_contractor), + ("cutensornet", "numpy", "cuda", cutn_contractor), + ("cutensornet", "torch", "cuda", cutn_contractor), + ], +) +def test_valid_contractor_config(contractor_name, backend, device, + expected_contractor): + cfg = ContractorConfig(contractor_name, backend, device) + assert cfg.contractor_name == contractor_name + assert cfg.backend == backend + assert cfg.device == device assert cfg.device_id == 0 - assert cfg.contractor is contractor - - -def test_valid_torch_cpu(): - cfg = ContractorConfig("torch", "torch", "cpu") - assert cfg.contractor_name == "torch" - assert cfg.backend == "torch" - assert cfg.device == "cpu" - assert cfg.device_id == 0 - assert cfg.contractor is contractor - - -def test_valid_cutensornet_numpy_cuda(): - cfg = ContractorConfig("cutensornet", "numpy", "cuda") - assert cfg.contractor_name == "cutensornet" - assert cfg.backend == "numpy" - assert cfg.device == "cuda" - assert cfg.device_id == 0 - assert cfg.contractor is cutn_contractor - - -def test_valid_cutensornet_torch_cuda(): - cfg = ContractorConfig("cutensornet", "torch", "cuda") - assert cfg.contractor_name == "cutensornet" - assert cfg.backend == "torch" - assert cfg.device == "cuda" - assert cfg.device_id == 0 - assert cfg.contractor is cutn_contractor + assert cfg.contractor is expected_contractor def test_cuda_device_id_parsing():