From 80de5f4ed93d0977f703e3c2b0c33bb772c4a755 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Tue, 26 May 2026 12:31:09 -0400 Subject: [PATCH 1/2] add cutn support Signed-off-by: vedika-saravanan --- .../tensor_network_utils/contractors.py | 104 ++++- .../tensor_network_utils/nm_optimizer.py | 369 +++++++++++++++++- libs/qec/python/tests/test_nm_optimizer.py | 148 +++++++ 3 files changed, 590 insertions(+), 31 deletions(-) diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py index 5ded4fb3..50fe5485 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/contractors.py @@ -7,11 +7,14 @@ # ============================================================================ # from __future__ import annotations +from collections import OrderedDict from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, ClassVar +from typing import Any, ClassVar, Optional, Union +import numpy.typing as npt import opt_einsum as oe +import torch from quimb.tensor import TensorNetwork @@ -33,6 +36,67 @@ def contractor(subscripts: str, return oe.contract(subscripts, *tensors, optimize=optimize) +def oe_torch_contractor(subscripts: str, + tensors: list[torch.Tensor], + optimize: str = "auto", + **_: Any) -> Any: + """Perform einsum contraction using opt_einsum with the torch backend. + + Combines opt_einsum's contraction-path optimisation with torch's + execution engine, giving autograd support and GPU acceleration in a + single call. Execution device follows the input tensors. + + Args: + subscripts: The einsum subscripts. + tensors: list of torch tensors to contract. + optimize: Optimization strategy passed to ``opt_einsum.contract``. + Defaults to ``"auto"``. + + Returns: + The contracted tensor. + """ + return oe.contract(subscripts, *tensors, optimize=optimize, backend="torch") + + +# TODO: move to per-decoder instance; module-global cache means unrelated decoders share and evict each other. +_OE_EXPR_CACHE_MAXSIZE = 32 +_oe_expr_cache: OrderedDict[tuple, Any] = OrderedDict() + + +def oe_torch_compiled_contractor(subscripts: str, + tensors: list[torch.Tensor], + optimize: str = "auto", + **_: Any) -> Any: + """Perform einsum contraction using a cached ``opt_einsum.contract_expression`` + with the torch backend. + + On the first call for a given ``(subscripts, shapes, optimize)`` + combination, builds and caches a :class:`opt_einsum.ContractExpression`. + Subsequent calls with the same key skip path search entirely and only + execute the pairwise tensor contractions via torch. + + Args: + subscripts: The einsum subscripts. + tensors: list of torch tensors to contract. + optimize: Optimization strategy passed to + ``opt_einsum.contract_expression``. Defaults to ``"auto"``. + + Returns: + The contracted tensor. + """ + shapes = tuple(t.shape for t in tensors) + key = (subscripts, shapes, str(optimize)) + if key in _oe_expr_cache: + _oe_expr_cache.move_to_end(key) + else: + if len(_oe_expr_cache) >= _OE_EXPR_CACHE_MAXSIZE: + _oe_expr_cache.popitem(last=False) + _oe_expr_cache[key] = oe.contract_expression(subscripts, + *shapes, + optimize=optimize) + return _oe_expr_cache[key](*tensors, backend="torch") + + def cutn_contractor(subscripts: str, tensors: list[Any], optimize: Any | None = None, @@ -61,31 +125,43 @@ def cutn_contractor(subscripts: str, ) -def optimize_path(optimize: Any, output_inds: tuple[str, ...], - tn: TensorNetwork) -> tuple[Any, Any]: +def optimize_path(optimize: Any, + output_inds: tuple[str, ...], + tn: TensorNetwork, + network_options: Any = None) -> tuple[Any, Any]: """ Optimize the contraction path for a tensor network. Args: - optimize (Any): The optimization options to use. - If None or cuquantum.tensornet.OptimizerOptions, we use cuquantum.tensornet. - Else, Quimb interface at - https://quimb.readthedocs.io/en/latest/autoapi/quimb/tensor/tensor_core/index.html#quimb.tensor.tensor_core.TensorNetwork.contraction_info - output_inds (tuple[str, ...]): Output indices for the contraction. - tn (TensorNetwork): The tensor network. + optimize: The optimization options to use. + If ``None`` or a ``cuquantum.tensornet.OptimizerOptions`` + instance, dispatches to ``cuquantum.tensornet.contract_path``. + Otherwise routes through Quimb's + :meth:`TensorNetwork.contraction_info` (which accepts + opt_einsum string presets, :class:`PathOptimizer` + instances, :class:`cotengra.HyperOptimizer`, etc.). + output_inds: Output indices for the contraction. + tn: The tensor network. + network_options: Optional cuTensorNet ``NetworkOptions`` (or + equivalent dict). Forwarded as ``options=`` to + ``cuquantum.tensornet.contract_path``. Ignored for + non-cuTensorNet optimizers. Returns: - tuple[Any, Any]: The contraction path and optimizer info. + A ``(path, info)`` tuple. """ use_cutn = optimize is None or ( type(optimize).__module__.startswith("cuquantum") and type(optimize).__name__ == "OptimizerOptions") if use_cutn: from cuquantum import tensornet as cutn + kwargs: dict[str, Any] = {"optimize": optimize} + if network_options is not None: + kwargs["options"] = network_options path, info = cutn.contract_path( tn.get_equation(output_inds=output_inds), *tn.arrays, - optimize=optimize, + **kwargs, ) return path, info @@ -109,6 +185,10 @@ class ContractorConfig: _allowed_configs: ClassVar[tuple[tuple[str, str, str], ...]] = ( ("numpy", "numpy", "cpu"), ("torch", "torch", "cpu"), + ("oe_torch", "torch", "cpu"), + ("oe_torch", "torch", "cuda"), + ("oe_torch_compiled", "torch", "cpu"), + ("oe_torch_compiled", "torch", "cuda"), ("cutensornet", "numpy", "cuda"), ("cutensornet", "torch", "cuda"), ) @@ -116,6 +196,8 @@ class ContractorConfig: _contractors: ClassVar[dict[str, Callable]] = { "numpy": contractor, "torch": contractor, + "oe_torch": oe_torch_contractor, + "oe_torch_compiled": oe_torch_compiled_contractor, "cutensornet": cutn_contractor, } diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py index ae229545..6f18b016 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py @@ -27,6 +27,7 @@ from quimb.tensor import TensorNetwork from ..tensor_network_decoder import TensorNetworkDecoder +from .contractors import optimize_path as _optimize_path_dispatch from .tensor_network_factory import ( tensor_network_from_syndrome_batch, prepare_syndrome_data_batch, @@ -224,12 +225,17 @@ def __init__( execute: Literal["codegen", "unrolled", "opt_einsum"] = "codegen", compile_mode: str | None = None, dynamic_syndromes: bool = True, + precontract_noise: bool = False, ) -> None: if execute not in ("unrolled", "opt_einsum", "codegen"): raise ValueError(f"Invalid execute mode: {execute!r}") if dtype not in _SUPPORTED_DTYPES: raise ValueError(f"Invalid dtype {dtype!r}; expected one of " f"{list(_SUPPORTED_DTYPES)}.") + if precontract_noise and execute != "opt_einsum": + raise ValueError( + "precontract_noise=True requires execute='opt_einsum'; " + f"got {execute!r}") # Sanitise once so the base TN tensors and ``self._noise_probs`` # see identical values (see :func:`_validate_and_clamp_priors`). @@ -307,8 +313,10 @@ def __init__( self._execute_mode = execute self._torch_compile_mode = compile_mode self._dynamic_syndromes = dynamic_syndromes + self._precontract_noise = precontract_noise self._compiled_predict: Any | None = None self._syndrome_tuple: tuple[torch.Tensor, ...] = () + self.batch_slices: int = 1 self._snapshot_arrays_and_eq() self._suspend_loss_rebuild = False @@ -445,6 +453,8 @@ def _as_torch(x): if self.path_batch not in (None, "auto") else "auto", ) self._path_steps = None + if self._precontract_noise: + self._build_reduced_tn_state() else: self._oe_expr = None # Flatten the path into ``[(eq, idxs, sorted_desc), ...]``; @@ -469,15 +479,47 @@ def _as_torch(x): def _compile_predict(self) -> None: """Build ``self._predict_fn`` for the configured execute mode.""" - builders = { - "opt_einsum": self._build_predict_opt_einsum, - "unrolled": self._build_predict_unrolled, - "codegen": self._build_predict_codegen, - } - self._predict_fn = builders[self._execute_mode]() + if self._precontract_noise: + base_predict = self._build_predict_reduced() + else: + builders = { + "opt_einsum": self._build_predict_opt_einsum, + "unrolled": self._build_predict_unrolled, + "codegen": self._build_predict_codegen, + } + base_predict = builders[self._execute_mode]() + if self.batch_slices > 1: + base_predict = self._wrap_predict_batch_sliced(base_predict) + self._predict_fn = base_predict self._compiled_predict = self._maybe_torch_compile(self._predict_fn, kind="predict") + def _wrap_predict_batch_sliced(self, base_predict_fn): + """Wrap a predict function so it iterates over batch-axis chunks. + + Syndrome tensors are shape ``(2, batch)``; we split along axis 1 + into ``self.batch_slices`` chunks, call ``base_predict_fn`` per + chunk, then concatenate the per-chunk ``(chunk_batch, 2)`` + outputs along dim 0. ``torch.cat`` is differentiable so the + autograd graph through ``noise_probs`` is preserved. + """ + n_slices = self.batch_slices + + def _sliced_predict( + noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...], + ) -> torch.Tensor: + split = [ + torch.tensor_split(t, n_slices, dim=1) for t in syndrome_tuple + ] + outs = [] + for s in range(n_slices): + chunk = tuple(split[k][s] for k in range(len(split))) + outs.append(base_predict_fn(noise_probs, chunk)) + return torch.cat(outs, dim=0) + + return _sliced_predict + def _build_predict_opt_einsum(self): """opt_einsum-backed predict: reuse the cached contract expression.""" static_arrays = self._static_arrays @@ -504,6 +546,210 @@ def _predict(noise_probs: torch.Tensor, return _predict + def _build_reduced_tn_state(self) -> None: + """Build the reduced TN topology + per-error einsum specs. + + Equivalent of :meth:`TensorNetworkDecoder.init_noise_model` + ``contract=True``, but deferred: the per-error noise-into-checks + contractions become differentiable :func:`torch.einsum` calls + invoked per step, so the noise priors stay leaves of the + autograd graph while the main contraction runs on the + ``contract_noise_model=True`` topology. + """ + import cotengra as ctg + + error_inds_set = set(self.error_inds) + + survivor_lookup: dict = {} + doomed_lookup: dict = {} + for opt_pos, t in enumerate(self._tensors_ref): + key = (tuple(t.inds), frozenset(t.tags)) + if any(ind in error_inds_set for ind in t.inds): + doomed_lookup[key] = opt_pos + else: + survivor_lookup[key] = opt_pos + + reduced_tn = self.full_tn.copy() + recipes: list[dict[str, Any]] = [] + merged_id_to_recipe_idx: dict[int, int] = {} + + for error_idx, error_ind in enumerate(self.error_inds): + doomed = [t for t in reduced_tn.tensors if error_ind in t.inds] + check_ts = [t for t in doomed if 'NOISE' not in t.tags] + check_opt_positions = [ + doomed_lookup[(tuple(ct.inds), frozenset(ct.tags))] + for ct in check_ts + ] + ids_before = {id(t) for t in reduced_tn.tensors} + reduced_tn.contract_ind(error_ind) + new_ts = [t for t in reduced_tn.tensors if id(t) not in ids_before] + assert len(new_ts) == 1 + new_t = new_ts[0] + merged_id_to_recipe_idx[id(new_t)] = error_idx + + # Quimb's index order on the merged tensor -- the einsum + # output must match it so axes align in reduced_tn. + quimb_out_inds = tuple(new_t.inds) + mapping = {error_ind: 'e'} + next_code = ord('a') + for ind in quimb_out_inds: + while chr(next_code) == 'e': + next_code += 1 + mapping[ind] = chr(next_code) + next_code += 1 + noise_str = mapping[error_ind] + check_strs = [ + "".join(mapping[i] for i in ct.inds) for ct in check_ts + ] + out_str = "".join(mapping[i] for i in quimb_out_inds) + # Canonical order: ordered_check_opt_positions[axis] is the + # opt-array position of the check tensor whose non-error + # index is quimb_out_inds[axis]. This lets us batch every + # error in a signature class through a single torch.einsum. + ordered_check_opt_positions: list[int] = [None] * len( + check_ts) # type: ignore + for ct, ct_pos in zip(check_ts, check_opt_positions): + non_e = next(i for i in ct.inds if i != error_ind) + ordered_check_opt_positions[quimb_out_inds.index( + non_e)] = ct_pos + recipes.append({ + 'eq': ",".join([noise_str] + check_strs) + "->" + out_str, + 'check_opt_positions': check_opt_positions, + 'ordered_check_opt_positions': ordered_check_opt_positions, + 'k': len(check_ts), + }) + + reduced_eq = reduced_tn.get_equation( + output_inds=("batch_index", self.logical_obs_inds[0])) + reduced_shapes = tuple(t.shape for t in reduced_tn.tensors) + + reduced_static: dict[int, torch.Tensor] = {} + reduced_syndrome: list[tuple[int, int]] = [] + reduced_recipes: dict[int, int] = {} + syn_pos_to_idx = { + p: i for i, (p, _) in enumerate(self._syndrome_positions) + } + for pos, t in enumerate(reduced_tn.tensors): + if id(t) in merged_id_to_recipe_idx: + reduced_recipes[pos] = merged_id_to_recipe_idx[id(t)] + continue + key = (tuple(t.inds), frozenset(t.tags)) + opt_pos = survivor_lookup[key] + if opt_pos in self._static_arrays: + reduced_static[pos] = self._static_arrays[opt_pos] + elif opt_pos in syn_pos_to_idx: + reduced_syndrome.append((pos, syn_pos_to_idx[opt_pos])) + else: + raise AssertionError( + f"reduced_tn tensor at pos {pos} maps to opt_pos {opt_pos} " + "which isn't classified as static or syndrome") + + # Cotengra rather than opt_einsum's ``auto``: ``auto`` falls + # back to ``greedy`` on 200+-tensor networks, which isn't + # memory-aware and picks paths torch.einsum can't execute at + # large batch. + hyper = ctg.HyperOptimizer(max_repeats=8, parallel=False) + reduced_path, _info = oe.contract_path(reduced_eq, + *reduced_shapes, + shapes=True, + optimize=hyper) + reduced_oe_expr = oe.contract_expression(reduced_eq, + *reduced_shapes, + optimize=reduced_path) + + # Group errors by signature (number of checks) so each class + # runs through one batched torch.einsum instead of one per error. + from collections import defaultdict + recipe_to_reduced_pos = {ri: cp for cp, ri in reduced_recipes.items()} + groups_by_k: dict[int, list[int]] = defaultdict(list) + for ri, r in enumerate(recipes): + groups_by_k[r['k']].append(ri) + + batched_groups: list[dict[str, Any]] = [] + device = self.torch_device + for k, error_indices in sorted(groups_by_k.items()): + # 'n' = batched-error dim, 'e' = contracted error index, + # 'a'..'z' (skipping 'e' and 'n') = output axes. + out_letters: list[str] = [] + next_code = ord('a') + for _ in range(k): + while chr(next_code) in ('e', 'n'): + next_code += 1 + out_letters.append(chr(next_code)) + next_code += 1 + out_str = "".join(out_letters) + check_strs = [f"n{c}e" for c in out_letters] + eq = "ne," + ",".join(check_strs) + "->n" + out_str if k > 0 \ + else "ne->ne" + + stacked_checks = [] + for axis in range(k): + axis_arrays = [ + self._static_arrays[recipes[ri] + ['ordered_check_opt_positions'][axis]] + for ri in error_indices + ] + stacked_checks.append(torch.stack(axis_arrays, dim=0)) + + reduced_positions = [ + recipe_to_reduced_pos[ri] for ri in error_indices + ] + error_indices_t = torch.tensor(error_indices, + dtype=torch.long, + device=device) + + batched_groups.append({ + 'k': k, + 'eq': eq, + 'error_indices_t': error_indices_t, + 'stacked_checks': stacked_checks, + 'reduced_positions': reduced_positions, + }) + + self._reduced_tn = reduced_tn + self._per_error_einsums = recipes + self._batched_einsum_groups = batched_groups + self._reduced_static_positions = reduced_static + self._reduced_syndrome_positions = reduced_syndrome + self._reduced_recipe_positions = reduced_recipes + self._reduced_eq = reduced_eq + self._reduced_oe_expr = reduced_oe_expr + self._reduced_n_tensors = len(reduced_tn.tensors) + + def _build_predict_reduced(self): + """Predict using the reduced TN + per-step batched noise precontraction. + + See :meth:`_build_reduced_tn_state`. + """ + static_positions = self._reduced_static_positions + syndrome_positions = self._reduced_syndrome_positions + batched_groups = self._batched_einsum_groups + oe_expr = self._reduced_oe_expr + n = self._reduced_n_tensors + + def _predict(noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: + noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), + dim=-1) + arrays: list[torch.Tensor] = [None] * n # type: ignore + for pos, arr in static_positions.items(): + arrays[pos] = arr + for pos, syn_idx in syndrome_positions: + arrays[pos] = syndrome_tuple[syn_idx] + for group in batched_groups: + noise_batch = noise_stacked[group['error_indices_t']] + if group['k'] == 0: + out_batch = noise_batch + else: + out_batch = torch.einsum(group['eq'], noise_batch, + *group['stacked_checks']) + for i, pos in enumerate(group['reduced_positions']): + arrays[pos] = out_batch[i] + out = oe_expr(*arrays) + return out / out.sum(dim=1, keepdim=True) + + return _predict + def _build_predict_unrolled(self): """Unrolled predict: walk the cached pairwise contraction path.""" static_arrays = self._static_arrays @@ -592,7 +838,13 @@ def _compile_loss(self) -> None: Two variants are produced: one accepting logits (sigmoid applied inside) and one accepting probabilities directly. """ - if self._execute_mode == "codegen": + # The fused codegen loss bakes obs_idx_true/false as closure + # constants over the full batch; under batch-slicing predict is + # already chunked-and-concatenated, so wrapped CE composes + # correctly while fused CE would not. + use_codegen_loss = (self._execute_mode == "codegen" and + self.batch_slices == 1) + if use_codegen_loss: logits_fn, probs_fn = self._build_loss_codegen() else: logits_fn, probs_fn = self._build_loss_wrapped() @@ -1096,6 +1348,7 @@ def _update_data(self, if shape_changed: self.path_batch = None self.slicing_batch = tuple() + self.batch_slices = 1 try: self._snapshot_arrays_and_eq() finally: @@ -1137,27 +1390,103 @@ def update_dataset(self, self._batch_size = int(new_syndrome_data.shape[0]) self._update_data(syndrome_arrays, new_observable_flips, enforce_shape) - def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: - """Cache a contraction path via quimb and rebuild the JIT. - - Always routes through :meth:`TensorNetwork.contraction_info` so - the resulting path is compatible with :mod:`opt_einsum` and - manual unrolling -- unlike :meth:`TensorNetworkDecoder.optimize_path`, - which defaults to a cuTensorNet-only path. + def optimize_path(self, + optimize: Any = None, + batch_size: int = -1, + network_options: Any = None) -> Any: + """Cache a contraction path and rebuild the JIT. + + Dispatches on the type of ``optimize``: + + * ``None`` (default) or any string / :mod:`opt_einsum` + :class:`PathOptimizer` / :class:`cotengra.HyperOptimizer` -- + route through quimb's :meth:`TensorNetwork.contraction_info`. + ``None`` is treated as ``"auto"``. This is the CPU-safe + default and does not require :mod:`cuquantum`. + * :class:`cuquantum.tensornet.OptimizerOptions` -- route + through :func:`cuquantum.tensornet.contract_path` to use + cuTensorNet's hyper-optimiser. Useful on large networks + where opt_einsum's heuristics underperform. + + The path -- whether from cuTensorNet or quimb -- is a list of + ``(int, int)`` pairs and is consumed directly by + :mod:`opt_einsum` in the executor rebuild, so all three + execute modes (``opt_einsum`` / ``unrolled`` / ``codegen``) + work unchanged. + + .. note:: + + cuTensorNet may return a *sliced* path on memory-pressured + networks. The torch-backed executors used by + :class:`NMOptimizer` cannot honour slice descriptors, so a + sliced result raises :class:`NotImplementedError`. Pass + ``OptimizerOptions(slicing=SlicerOptions(disable_slicing=True))`` + to force an unsliced path, or fall back to ``optimize="auto"``. ``batch_size`` is part of the parent ``TensorNetworkDecoder`` signature (which rebuilds its TN around a fake batch); on the optimiser the syndrome TN is already batched at construction and resized in :meth:`update_dataset`, so this argument is ignored. Kept for Liskov substitution with the parent. + + Example (cuTensorNet path finder):: + + from cuquantum.tensornet.configuration import ( + OptimizerOptions, SlicerOptions, NetworkOptions) + opt.optimize_path( + optimize=OptimizerOptions( + slicing=SlicerOptions(disable_slicing=True)), + network_options=NetworkOptions(memory_limit='8GiB')) + + ``network_options`` is forwarded to + :func:`cuquantum.tensornet.contract_path` as ``options=``. """ del batch_size - info = self.full_tn.contraction_info( - output_inds=("batch_index", self.logical_obs_inds[0]), - optimize=optimize if optimize is not None else "auto", - ) - self.path_batch = info.path - self.slicing_batch = tuple() + + use_cutn = (optimize is not None and + type(optimize).__module__.startswith("cuquantum") and + type(optimize).__name__ == "OptimizerOptions") + + output_inds = ("batch_index", self.logical_obs_inds[0]) + batch_slices = 1 + if use_cutn: + path, info = _optimize_path_dispatch( + optimize, + output_inds, + self.full_tn, + network_options=network_options) + num_slices = getattr(info, "num_slices", 1) + if num_slices > 1: + sliced_modes = getattr(info, "sliced_modes", ()) + non_batch = [ + m for m in sliced_modes + if (m[0] if isinstance(m, tuple) else m) != "batch_index" + ] + if non_batch: + raise NotImplementedError( + "NMOptimizer's batch-dim slicing executor only " + "supports slicing the 'batch_index' mode; " + f"cuTensorNet sliced additional modes: " + f"{non_batch}. Pass OptimizerOptions(slicing=" + "SlicerOptions(disable_slicing=True)) to " + "suppress slicing.") + if not self._dynamic_syndromes: + raise NotImplementedError( + "Sliced contraction paths require " + "dynamic_syndromes=True; rebuild NMOptimizer " + "with dynamic_syndromes=True.") + batch_slices = num_slices + else: + info = self.full_tn.contraction_info( + output_inds=output_inds, + optimize=optimize if optimize is not None else "auto", + ) + path = info.path + + self.path_batch = path + self.slicing_batch = getattr(info, "sliced_modes", + tuple()) if use_cutn else tuple() + self.batch_slices = batch_slices self._snapshot_arrays_and_eq() return info diff --git a/libs/qec/python/tests/test_nm_optimizer.py b/libs/qec/python/tests/test_nm_optimizer.py index 54af6aec..3f4482b7 100644 --- a/libs/qec/python/tests/test_nm_optimizer.py +++ b/libs/qec/python/tests/test_nm_optimizer.py @@ -669,6 +669,154 @@ def test_optimize_path_with_cotengra(device): np.testing.assert_allclose(before, after, atol=1e-10, rtol=1e-10) +def test_optimize_path_with_cutn_optimizer_options(): + """``OptimizerOptions`` routes through cuTensorNet's path finder and + the resulting path is consumed correctly by the opt_einsum-backed + executor.""" + if not _gpu_available(): + pytest.skip("No GPU available; cuTensorNet path finder requires CUDA.") + cutn_cfg = pytest.importorskip("cuquantum.tensornet.configuration") + OptimizerOptions = cutn_cfg.OptimizerOptions + SlicerOptions = cutn_cfg.SlicerOptions + from cuquantum.tensornet.configuration import OptimizerInfo + + rng = np.random.default_rng(90) + H, logical = _nondegenerate_code() + priors = [0.1, 0.15, 0.25] + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=24, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device="cuda", + dtype="float64") + with torch.no_grad(): + before = opt.decoder_prediction().detach().cpu().numpy() + info = opt.optimize_path(optimize=OptimizerOptions(slicing=SlicerOptions( + disable_slicing=True))) + assert isinstance(info, OptimizerInfo) + assert info.num_slices == 1 + # cuTensorNet returns the path as a list of ``(int, int)`` pairs. + assert isinstance(opt.path_batch, (list, tuple)) and len(opt.path_batch) > 0 + assert all( + isinstance(step, tuple) and len(step) == 2 for step in opt.path_batch) + with torch.no_grad(): + after = opt.decoder_prediction().detach().cpu().numpy() + np.testing.assert_allclose(before, after, atol=1e-10, rtol=1e-10) + + +def test_optimize_path_with_cutn_sliced_forward_matches_unsliced(): + """A cuTensorNet-sliced path (forced via ``min_slices``) executes + through NMOptimizer's batch-slicing wrapper and produces the same + forward output as the unsliced path.""" + if not _gpu_available(): + pytest.skip("No GPU available; cuTensorNet path finder requires CUDA.") + cutn_cfg = pytest.importorskip("cuquantum.tensornet.configuration") + OptimizerOptions = cutn_cfg.OptimizerOptions + SlicerOptions = cutn_cfg.SlicerOptions + + rng = np.random.default_rng(91) + H, logical = _nondegenerate_code() + priors = [0.1, 0.15, 0.25] + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=24, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device="cuda", + dtype="float64") + with torch.no_grad(): + ref = opt.decoder_prediction().detach().cpu().numpy() + + # ``min_slices=4`` forces at least four slices on the batch_index + # mode; the wrapper splits, contracts, and concatenates the result. + forced = OptimizerOptions(slicing=SlicerOptions(min_slices=4)) + info = opt.optimize_path(optimize=forced) + assert info.num_slices >= 4 + assert opt.batch_slices == info.num_slices + + with torch.no_grad(): + sliced = opt.decoder_prediction().detach().cpu().numpy() + np.testing.assert_allclose(sliced, ref, atol=1e-10, rtol=1e-10) + + +def test_optimize_path_with_cutn_sliced_gradients_flow(): + """Gradients on the noise priors are non-zero after a sliced + contraction; ``torch.cat`` preserves the autograd graph.""" + if not _gpu_available(): + pytest.skip("No GPU available; cuTensorNet path finder requires CUDA.") + cutn_cfg = pytest.importorskip("cuquantum.tensornet.configuration") + OptimizerOptions = cutn_cfg.OptimizerOptions + SlicerOptions = cutn_cfg.SlicerOptions + + rng = np.random.default_rng(92) + H, logical = _nondegenerate_code() + priors = [0.1, 0.15, 0.25] + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=24, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device="cuda", + dtype="float64") + opt.optimize_path(optimize=OptimizerOptions(slicing=SlicerOptions( + min_slices=4))) + assert opt.batch_slices >= 4 + + loss = opt.loss_fn(from_logits=False)(opt._noise_probs, opt._syndrome_tuple) + loss.backward() + grad = opt._noise_probs.grad + assert grad is not None + assert torch.isfinite(grad).all() + assert grad.abs().sum().item() > 0.0 + + +def test_optimize_path_with_cutn_sliced_rejects_static_codegen(): + """Sliced paths require ``dynamic_syndromes=True`` because static + codegen bakes per-batch syndromes into the closure.""" + if not _gpu_available(): + pytest.skip("No GPU available; cuTensorNet path finder requires CUDA.") + cutn_cfg = pytest.importorskip("cuquantum.tensornet.configuration") + OptimizerOptions = cutn_cfg.OptimizerOptions + SlicerOptions = cutn_cfg.SlicerOptions + + rng = np.random.default_rng(93) + H, logical = _nondegenerate_code() + priors = [0.1, 0.15, 0.25] + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=24, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device="cuda", + dtype="float64", + execute="codegen", + dynamic_syndromes=False) + forced = OptimizerOptions(slicing=SlicerOptions(min_slices=2)) + with pytest.raises(NotImplementedError, match="dynamic_syndromes"): + opt.optimize_path(optimize=forced) + + # -- remap_eq_to_ascii ------------------------------------------------------- From adfbaeb8e1f46e375b5d7eca1a042d53ac426725 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 1 Jun 2026 12:46:37 -0400 Subject: [PATCH 2/2] enhance nm optimizer for larger distance Signed-off-by: vedika-saravanan --- .../tensor_network_utils/nm_optimizer.py | 258 +++++++++++------- 1 file changed, 155 insertions(+), 103 deletions(-) diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py index 6f18b016..b17438e3 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py @@ -24,6 +24,7 @@ import numpy.typing as npt import opt_einsum as oe import torch +from torch.utils.checkpoint import checkpoint as _checkpoint from quimb.tensor import TensorNetwork from ..tensor_network_decoder import TensorNetworkDecoder @@ -237,8 +238,6 @@ def __init__( "precontract_noise=True requires execute='opt_einsum'; " f"got {execute!r}") - # Sanitise once so the base TN tensors and ``self._noise_probs`` - # see identical values (see :func:`_validate_and_clamp_priors`). noise_model = _validate_and_clamp_priors(noise_model, dtype) super().__init__( @@ -254,10 +253,7 @@ def __init__( device=device, ) - # Force the torch backend so tensor data lives in the autograd - # graph (the base class would otherwise pick cutensornet/numpy - # on GPU). Contractions still go through codegen / cotengra - # below, not cuTensorNet. + # Force the torch backend so tensor data lives in the autograd graph. if self.contractor_config.contractor_name == "cutensornet" \ and self.contractor_config.backend != "torch": warnings.warn( @@ -275,7 +271,6 @@ def __init__( dtype, ) - # Swap the base's placeholder single-syndrome TN for a batched one. self._syndrome_tags = [f"SYN_{i}" for i in range(len(self.check_inds))] self.syndrome_tn = tensor_network_from_syndrome_batch( syndrome_data, @@ -285,7 +280,6 @@ def __init__( ) self._batch_size = syndrome_data.shape[0] - # Re-stitch ``full_tn`` around the batched syndrome TN. self.full_tn = TensorNetwork() self.full_tn = self.full_tn.combine(self.code_tn, virtual=True) self.full_tn = self.full_tn.combine(self.logical_tn, virtual=True) @@ -301,10 +295,9 @@ def __init__( device=self.torch_device, requires_grad=True, ) - # The base's noise tensors stay in ``full_tn`` as numpy - # placeholders: ``_snapshot_arrays_and_eq`` uses ``id()`` to - # locate their positions, then ``self._noise_probs`` (autograd - # live) is written into those slots. Do not strip them. + # Noise placeholders stay in ``full_tn``; ``_snapshot_arrays_and_eq`` + # locates them by ``id()`` and writes ``self._noise_probs`` into + # those slots. Do not strip them. self._suspend_loss_rebuild = True self.observable_flips = observable_flips @@ -314,6 +307,8 @@ def __init__( self._torch_compile_mode = compile_mode self._dynamic_syndromes = dynamic_syndromes self._precontract_noise = precontract_noise + self._reduced_optimize: Any = None + self._reduced_network_options: Any = None self._compiled_predict: Any | None = None self._syndrome_tuple: tuple[torch.Tensor, ...] = () self.batch_slices: int = 1 @@ -405,10 +400,9 @@ def _snapshot_arrays_and_eq(self) -> None: else: self._static_positions.append(i) - # Guard against a future quimb that copies tensors on virtual - # combine: every tensor in ``full_tn`` must classify into - # exactly one bucket, else the predict path rebuilds the - # operand list with a None slot or a misplaced placeholder. + # Every tensor in ``full_tn`` must classify into exactly one bucket; + # a future quimb that copies tensors on virtual combine would break + # this and put a None or misplaced placeholder in the operand list. n_classified = (len(self._noise_pos_for_error) + len(syndrome_positions_list) + len(self._static_positions)) @@ -439,29 +433,29 @@ def _as_torch(x): for i in syndrome_positions_list ] self._syndrome_tuple = tuple(self._syndrome_arrays) - # Used by :meth:`_update_data` to detect layout changes that - # invalidate the cached path / codegen / oe expr. + # Used by :meth:`_update_data` to detect layout changes. self._syndrome_shapes: tuple[tuple[int, ...], ...] = tuple( tuple(s.shape) for s in self._syndrome_arrays) if self._execute_mode == "opt_einsum": - shapes = tuple(t.shape for t in tensors) - self._oe_expr = oe.contract_expression( - self._eq_batch, - *shapes, - optimize=self.path_batch - if self.path_batch not in (None, "auto") else "auto", - ) - self._path_steps = None if self._precontract_noise: + # ``_predict`` uses ``_reduced_oe_expr``; skip full_tn build. + self._oe_expr = None self._build_reduced_tn_state() + else: + shapes = tuple(t.shape for t in tensors) + self._oe_expr = oe.contract_expression( + self._eq_batch, + *shapes, + optimize=self.path_batch + if self.path_batch not in (None, "auto") else "auto", + ) + self._path_steps = None else: self._oe_expr = None - # Flatten the path into ``[(eq, idxs, sorted_desc), ...]``; - # ``sorted_desc`` is precomputed for the unrolled-mode pop - # walk, and labels are remapped to ASCII because - # opt_einsum falls back to unicode past 52 indices and - # torch.einsum rejects those. + # Flatten to ``[(eq, idxs, sorted_desc), ...]``; sorted_desc is + # the unrolled-mode pop walk, ASCII remap dodges torch.einsum's + # rejection of opt_einsum's >52-index unicode fallback. shapes = tuple(t.shape for t in tensors) _, info = oe.contract_path( self._eq_batch, @@ -539,8 +533,6 @@ def _predict(noise_probs: torch.Tensor, arrays[pos] = arr for k, pos in enumerate(noise_pos_ordered): arrays[pos] = noise_stacked[k] - # Torch backend is auto-selected from the tensor type; - # avoids the per-call ``backend=`` dispatch. out = oe_expr(*arrays) return out / out.sum(dim=1, keepdim=True) @@ -587,8 +579,8 @@ def _build_reduced_tn_state(self) -> None: new_t = new_ts[0] merged_id_to_recipe_idx[id(new_t)] = error_idx - # Quimb's index order on the merged tensor -- the einsum - # output must match it so axes align in reduced_tn. + # Einsum output must match quimb's index order on the merged + # tensor so axes align in reduced_tn. quimb_out_inds = tuple(new_t.inds) mapping = {error_ind: 'e'} next_code = ord('a') @@ -602,10 +594,10 @@ def _build_reduced_tn_state(self) -> None: "".join(mapping[i] for i in ct.inds) for ct in check_ts ] out_str = "".join(mapping[i] for i in quimb_out_inds) - # Canonical order: ordered_check_opt_positions[axis] is the - # opt-array position of the check tensor whose non-error - # index is quimb_out_inds[axis]. This lets us batch every - # error in a signature class through a single torch.einsum. + # ordered_check_opt_positions[axis] holds the opt-array position + # of the check tensor whose non-error index is + # quimb_out_inds[axis] — needed for batching every error in a + # signature class through one torch.einsum. ordered_check_opt_positions: list[int] = [None] * len( check_ts) # type: ignore for ct, ct_pos in zip(check_ts, check_opt_positions): @@ -644,21 +636,90 @@ def _build_reduced_tn_state(self) -> None: f"reduced_tn tensor at pos {pos} maps to opt_pos {opt_pos} " "which isn't classified as static or syndrome") - # Cotengra rather than opt_einsum's ``auto``: ``auto`` falls - # back to ``greedy`` on 200+-tensor networks, which isn't - # memory-aware and picks paths torch.einsum can't execute at - # large batch. - hyper = ctg.HyperOptimizer(max_repeats=8, parallel=False) - reduced_path, _info = oe.contract_path(reduced_eq, - *reduced_shapes, - shapes=True, - optimize=hyper) + # Score a path for torch.einsum usability: max tensor rank across + # contraction steps. torch.einsum hard-caps at 25 dims per tensor. + def _max_step_rank(path: Any) -> int: + _, info = oe.contract_path(reduced_eq, + *reduced_shapes, + shapes=True, + optimize=path) + max_rank = 0 + for step in info.contraction_list: + eq = step[2] + if "->" in eq: + lhs, rhs = eq.split("->") + else: + lhs, rhs = eq, "" + for part in lhs.split(","): + max_rank = max(max_rank, len(set(part))) + max_rank = max(max_rank, len(set(rhs))) + return max_rank + + # Cotengra is stochastic; retry until we land an executable path + # (max_step_rank <= 25) or exhaust attempts. + cotengra_retries = 8 + ctg_path = ctg_info = None + for attempt in range(cotengra_retries): + hyper = ctg.HyperOptimizer(max_repeats=8, parallel=False) + p, info = oe.contract_path(reduced_eq, + *reduced_shapes, + shapes=True, + optimize=hyper) + rank = _max_step_rank(p) + if (ctg_info is None or + (_max_step_rank(ctg_path) > 25 and rank <= 25) or + (rank <= 25 and float(info.largest_intermediate) < float( + ctg_info.largest_intermediate))): + ctg_path, ctg_info = p, info + if rank <= 25: + break + candidates = [("cotengra", ctg_path, ctg_info)] + + if self._reduced_optimize is not None: + usr_path, usr_info = _optimize_path_dispatch( + self._reduced_optimize, + ("batch_index", self.logical_obs_inds[0]), + reduced_tn, + network_options=self._reduced_network_options, + ) + candidates.append(("user", usr_path, usr_info)) + + # Score by (unexecutable, largest_intermediate, rank). Unexecutable + # paths (rank > 25) always lose to executable ones. + def _score(c: tuple) -> tuple: + _tag, path, info = c + li = getattr(info, "largest_intermediate", None) + li = float("inf") if li is None else float(li) + rank = _max_step_rank(path) + return (rank > 25, li, rank) + + which, reduced_path, reduced_info = min(candidates, key=_score) + for c in candidates: + tag, _p, info = c + executable, li, rank = _score(c) + oc = getattr(info, "opt_cost", float("nan")) + warnings.warn( + f"reduced TN candidate ({tag}{'*' if tag == which else ''}): " + f"opt_cost={oc:.3e} largest_intermediate={li:.3e} " + f"max_step_rank={rank}" + + (" [unexecutable]" if executable else ""), + UserWarning, + stacklevel=2, + ) reduced_oe_expr = oe.contract_expression(reduced_eq, *reduced_shapes, optimize=reduced_path) - # Group errors by signature (number of checks) so each class - # runs through one batched torch.einsum instead of one per error. + _, step_info = oe.contract_path(reduced_eq, + *reduced_shapes, + shapes=True, + optimize=reduced_path) + reduced_path_steps = [(remap_eq_to_ascii(step[2]), tuple(step[0]), + tuple(sorted(step[0], reverse=True))) + for step in step_info.contraction_list] + + # Group errors by check count so each class runs through one + # batched torch.einsum instead of one per error. from collections import defaultdict recipe_to_reduced_pos = {ri: cp for cp, ri in reduced_recipes.items()} groups_by_k: dict[int, list[int]] = defaultdict(list) @@ -668,8 +729,7 @@ def _build_reduced_tn_state(self) -> None: batched_groups: list[dict[str, Any]] = [] device = self.torch_device for k, error_indices in sorted(groups_by_k.items()): - # 'n' = batched-error dim, 'e' = contracted error index, - # 'a'..'z' (skipping 'e' and 'n') = output axes. + # 'n' = batched-error dim, 'e' = contracted error index. out_letters: list[str] = [] next_code = ord('a') for _ in range(k): @@ -714,6 +774,7 @@ def _build_reduced_tn_state(self) -> None: self._reduced_recipe_positions = reduced_recipes self._reduced_eq = reduced_eq self._reduced_oe_expr = reduced_oe_expr + self._reduced_path_steps = reduced_path_steps self._reduced_n_tensors = len(reduced_tn.tensors) def _build_predict_reduced(self): @@ -745,7 +806,12 @@ def _predict(noise_probs: torch.Tensor, *group['stacked_checks']) for i, pos in enumerate(group['reduced_positions']): arrays[pos] = out_batch[i] - out = oe_expr(*arrays) + + # Gradient-checkpoint the main reduced-TN contraction. + if torch.is_grad_enabled() and noise_probs.requires_grad: + out = _checkpoint(oe_expr, *arrays, use_reentrant=False) + else: + out = oe_expr(*arrays) return out / out.sum(dim=1, keepdim=True) return _predict @@ -838,10 +904,9 @@ def _compile_loss(self) -> None: Two variants are produced: one accepting logits (sigmoid applied inside) and one accepting probabilities directly. """ - # The fused codegen loss bakes obs_idx_true/false as closure - # constants over the full batch; under batch-slicing predict is - # already chunked-and-concatenated, so wrapped CE composes - # correctly while fused CE would not. + # Codegen loss bakes obs_idx_true/false as closure constants; + # under batch-slicing predict is already chunked-and-concatenated, + # so the wrapped CE composes correctly while fused CE would not. use_codegen_loss = (self._execute_mode == "codegen" and self.batch_slices == 1) if use_codegen_loss: @@ -971,8 +1036,7 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, static_positions = sorted(static_arrays.keys()) noise_pos_set = set(noise_pos_ordered) syn_pos_set = set(syndrome_positions) - # O(1) reverse lookup tables (avoid repeated list.index() inside - # the per-step loop below — was O(N^2) for large path lengths). + # O(1) reverse lookups; the per-step list.index() was O(N^2). noise_pos_to_k = {pos: k for k, pos in enumerate(noise_pos_ordered)} syn_pos_to_sidx = { pos: sidx for sidx, pos in enumerate(syndrome_positions) @@ -1000,11 +1064,8 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, closure_vars: dict[str, torch.Tensor] = {} runtime_lines: list[str] = [] - # Names that the emitted source actually references and that must - # be available in the closure namespace. We track this - # structurally as we go instead of re-parsing the source string, - # which is both faster and immune to lexical false positives / - # negatives (e.g. if an einsum equation contained an underscore). + # Track referenced closure names structurally rather than parsing + # the emitted source — faster and immune to lexical false matches. used_static: set[str] = set() n_folded = 0 @@ -1026,9 +1087,8 @@ def _codegen_partial_eval(n, static_arrays, syndrome_positions, n_folded += 1 else: arg_names = [p[0] for p in picked] - # Track which closure names this line will read. ``_n*`` - # names are header-built locals (from noise_probs) so - # they are *not* closure values; everything else is. + # ``_n*`` are header-built locals; everything else is a + # closure value that must be wired into used_static. for name in arg_names: if name.startswith(("_C", "_P")): used_static.add(name) @@ -1068,8 +1128,7 @@ def _emit_noise_header(noise_pos_ordered, else: lines.append(" _p = noise_probs") lines.append(" _q = 1.0 - _p") - # One stack of shape (K, 2) instead of K stacks of shape (2,). - # ``dim=1`` makes ``_NS[k]`` a contiguous view of length 2. + # One (K, 2) stack; ``dim=1`` makes ``_NS[k]`` a contiguous slice. lines.append(" _NS = torch.stack((_q, _p), dim=1)") for k in range(len(noise_pos_ordered)): lines.append(f" _n{k} = _NS[{k}]") @@ -1123,9 +1182,7 @@ def _build_codegen_predict(cls, body.append("def _predict(noise_probs):") if fully_static: - # Degenerate case: contraction didn't depend on noise (or any - # other dynamic input). Pre-normalise the constant and - # return it unchanged on every call. + # Contraction didn't depend on noise; return the constant. with torch.no_grad(): normed = final_value / final_value.sum(dim=1, keepdim=True) closure_vars["_FINAL"] = normed @@ -1190,15 +1247,10 @@ def _build_codegen_loss(cls, body.append("def _loss(noise_probs):") if fully_static: - # The contraction is a constant; the loss it produces is - # also a constant. We still need the gradient wrt - # noise_probs to be zero (a real-valued zero tensor with a - # graph edge), so emit a no-op multiplication by - # noise_probs.sum() * 0 to keep autograd happy. + # Loss is constant; emit a 0 * noise_probs.sum() term so + # autograd still produces a zero gradient with a graph edge. with torch.no_grad(): normed = final_value / final_value.sum(dim=1, keepdim=True) - # Compute the loss eagerly; we can't fold it because - # autograd needs a path back to noise_probs. ce = (-torch.log(normed[obs_idx_true, 1]).sum() - torch.log(normed[obs_idx_false, 0]).sum()) closure_vars["_LOSS"] = ce @@ -1211,16 +1263,12 @@ def _build_codegen_loss(cls, cls._emit_syndrome_header(syndrome_positions, dynamic_syndromes)) body.extend(runtime_lines) - # Fused cross-entropy that skips the explicit - # ``_preds = _out / _out.sum(dim=1, keepdim=True)`` step. - # OBS_T and OBS_F partition the batch (every row is in exactly - # one), so: - # -log(p_T[:,1]).sum() - log(p_F[:,0]).sum() - # = log(Z).sum() - log(_out[OBS_T,1]).sum() - # - log(_out[OBS_F,0]).sum() - # where Z = _out[:,0] + _out[:,1]. Saves one (shots, 2) - # division + materialisation and the corresponding backward - # nodes -- ~2-5% per-step on CPU at d=3/r=3. + # Fused CE: with Z = _out[:,0] + _out[:,1] and OBS_T/OBS_F + # partitioning the batch, + # -log(p_T[:,1]).sum() - log(p_F[:,0]).sum() + # = log(Z).sum() - log(_out[OBS_T,1]).sum() + # - log(_out[OBS_F,0]).sum() + # — skips the explicit (shots, 2) normalisation step. body.append(f" _out = {final_name}") body.append(" _z0 = _out[:, 0]") body.append(" _z1 = _out[:, 1]") @@ -1307,8 +1355,8 @@ def _update_data(self, device, shape ``(syndrome_length, shots, 2)``). Public callers should use :meth:`update_dataset` instead. """ - # Patch syndrome tensor data in the quimb TN in place; the - # cotengra path is invalidated below if any shape changed. + # Patch syndrome data in the quimb TN in place; the cached path + # is invalidated below if any shape changed. for i, tag in enumerate(self._syndrome_tags): t = self.syndrome_tn.tensors[next( iter(self.syndrome_tn.tag_map[tag]))] @@ -1318,8 +1366,8 @@ def _update_data(self, f"{t.data.shape} vs {new_syndrome_arrays[i].shape}") t.modify(data=new_syndrome_arrays[i]) - # Suppress the loss rebuild the ``observable_flips`` setter - # would otherwise trigger; one of the branches below issues it. + # Suppress the rebuild the observable_flips setter would trigger; + # a branch below issues it. self._suspend_loss_rebuild = True self.observable_flips = new_observable_flips @@ -1338,12 +1386,9 @@ def _update_data(self, new_shapes.append(tuple(arr.shape)) new_shapes_tuple = tuple(new_shapes) - # Shape change: cached path / codegen / oe expression / compile - # guards are all stale. Drop the path and rebuild from scratch. - # Shapes unchanged: dynamic codegen and the unrolled / - # opt_einsum paths read syndromes per call — refreshing the - # cached tuple is enough. Static codegen baked the old tensors - # into the closure and still needs a full rebuild. + # Shape change ⇒ everything cached is stale; full rebuild. + # Same shapes ⇒ dynamic modes only need the tuple refreshed; + # static codegen baked the old tensors and still rebuilds. shape_changed = new_shapes_tuple != self._syndrome_shapes if shape_changed: self.path_batch = None @@ -1362,8 +1407,7 @@ def _update_data(self, finally: self._suspend_loss_rebuild = False else: - # The observable indices may have changed; the loss bakes - # them in, so it still needs a rebuild. + # Observable indices may have changed; loss bakes them in. self._suspend_loss_rebuild = False self._compile_loss() @@ -1443,6 +1487,14 @@ def optimize_path(self, """ del batch_size + if self._precontract_noise: + # Apply the user's path-finder to the reduced TN (not full_tn). + # Cached on the instance so update_dataset rebuilds reuse it. + self._reduced_optimize = optimize + self._reduced_network_options = network_options + self._snapshot_arrays_and_eq() + return None + use_cutn = (optimize is not None and type(optimize).__module__.startswith("cuquantum") and type(optimize).__name__ == "OptimizerOptions") @@ -1510,8 +1562,8 @@ def make_compiled_step(optimizer: NMOptimizer, logits: torch.Tensor, torch_optimizer: A ``torch.optim`` instance owning ``logits``. """ - # Re-fetch per call so a rebuild from update_dataset or the - # observable_flips setter is picked up; capturing would go stale. + # Re-fetch per call so update_dataset / observable_flips rebuilds + # are picked up; capturing would go stale. def _step(): torch_optimizer.zero_grad(set_to_none=True) loss = optimizer.loss_fn(from_logits=True)(