From fb395191a781f22a7ba338c2a1719b59348f3968 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Thu, 30 Apr 2026 13:05:55 -0400 Subject: [PATCH 1/5] noide model integration Signed-off-by: vedika-saravanan --- .github/workflows/all_libs.yaml | 2 +- .github/workflows/all_libs_release.yaml | 2 +- .github/workflows/lib_qec.yaml | 4 +- .../api/qec/tensor_network_decoder_api.rst | 183 ++- docs/sphinx/components/qec/introduction.rst | 21 + .../examples/qec/python/noise_learning.py | 161 +++ docs/sphinx/examples_rst/qec/decoders.rst | 25 + libs/qec/pyproject.toml.cu12 | 2 + libs/qec/pyproject.toml.cu13 | 2 + .../decoders/tensor_network_decoder.py | 14 + .../tensor_network_utils/noise_models.py | 1205 ++++++++++++++++- libs/qec/python/tests/test_tn_noise_models.py | 826 +++++++++++ 12 files changed, 2438 insertions(+), 9 deletions(-) create mode 100644 docs/sphinx/examples/qec/python/noise_learning.py create mode 100644 libs/qec/python/tests/test_tn_noise_models.py diff --git a/.github/workflows/all_libs.yaml b/.github/workflows/all_libs.yaml index 78a2c312..7cd082b1 100644 --- a/.github/workflows/all_libs.yaml +++ b/.github/workflows/all_libs.yaml @@ -113,7 +113,7 @@ jobs: # Install the correct torch first. cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//') pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot} - pip install numpy pytest onnxscript cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" + pip install numpy pytest onnxscript cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum cotengra nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py. if [ "$(uname -m)" == "x86_64" ]; then # Stim is not currently available on manylinux ARM wheels, so only diff --git a/.github/workflows/all_libs_release.yaml b/.github/workflows/all_libs_release.yaml index a5e4554b..4dd37c6c 100644 --- a/.github/workflows/all_libs_release.yaml +++ b/.github/workflows/all_libs_release.yaml @@ -162,7 +162,7 @@ jobs: # Install the correct torch first. cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//') pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot} - pip install numpy pytest onnxscript cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" + pip install numpy pytest onnxscript cuquantum-cu${{ steps.config.outputs.cuda_major }} lightning ml_collections mpi4py transformers quimb opt_einsum cotengra nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py. if [ "$(uname -m)" == "x86_64" ]; then # Stim is not currently available on manylinux ARM wheels, so only diff --git a/.github/workflows/lib_qec.yaml b/.github/workflows/lib_qec.yaml index 87103b82..fa93a0d9 100644 --- a/.github/workflows/lib_qec.yaml +++ b/.github/workflows/lib_qec.yaml @@ -104,7 +104,7 @@ jobs: # Install the correct torch first. cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//') pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot} - pip install numpy pytest onnxscript cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" + pip install numpy pytest onnxscript cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum cotengra nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py. if [ "$(uname -m)" == "x86_64" ]; then # Stim is not currently available on manylinux ARM wheels, so only @@ -216,7 +216,7 @@ jobs: run: | cuda_no_dot=$(echo ${{ matrix.cuda_version }} | sed 's/\.//') pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/cu${cuda_no_dot} - pip install numpy pytest cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" "custabilizer-cu${{ steps.config.outputs.cuda_major }}>=0.3.0" + pip install numpy pytest cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum cotengra nvidia-cublas "cuquantum-python-cu${{ steps.config.outputs.cuda_major }}>=26.3.0" "custabilizer-cu${{ steps.config.outputs.cuda_major }}>=0.3.0" - name: Run GPU C++ tests run: | diff --git a/docs/sphinx/api/qec/tensor_network_decoder_api.rst b/docs/sphinx/api/qec/tensor_network_decoder_api.rst index b5bbc600..16279de2 100644 --- a/docs/sphinx/api/qec/tensor_network_decoder_api.rst +++ b/docs/sphinx/api/qec/tensor_network_decoder_api.rst @@ -93,4 +93,185 @@ :param optimize: Optimization options or None :param batch_size: (int, optional) Batch size for optimization (default: -1, no batching) - :returns: Optimizer info object \ No newline at end of file + :returns: Optimizer info object + +.. class:: cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer + + Differentiable noise-model optimizer built on top of :class:`TensorNetworkDecoder`. + + Fits a factorised per-error noise model to a syndrome dataset by + backpropagating through a torch-backed tensor-network contraction. + The noise probabilities are maintained as ``torch`` tensors with + ``requires_grad=True`` so they can be updated with any ``torch.optim`` + optimizer. + + Requires Python 3.11 or higher and the same optional dependencies as + :class:`TensorNetworkDecoder` (``pip install cudaq-qec[tensor-network-decoder]``). + PyTorch must also be installed. + + .. note:: + Quick-start example (logit-space training; the loss has no ``log`` + guard, so direct probability training requires per-step clamping + into ``[eps, 1 - eps]``):: + + import numpy as np + import torch + from cudaq_qec.plugins.decoders.tensor_network_decoder import ( + NMOptimizer, make_compiled_step, + ) + + H = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.float64) + logical = np.array([[1, 0, 1]], dtype=np.float64) + priors = [0.1, 0.2, 0.3] + + opt = NMOptimizer(H, logical, priors, syndrome_data, obs_flips, + dtype="float64") + logits = torch.logit(opt.noise_params[0].detach()).requires_grad_() + adam = torch.optim.Adam([logits], lr=0.01) + step = make_compiled_step(opt, logits, adam) + for _ in range(100): + step() + + :param H: Parity check matrix (numpy.ndarray), shape (num_checks, num_errors) + :param logical_obs: Logical observable matrix (numpy.ndarray), shape (1, num_errors) + :param noise_model: Initial per-error probabilities, list of floats in (0, 1). + Values outside ``[eps, 1 - eps]`` are clamped at + construction with a ``UserWarning``; non-finite values + raise ``ValueError``. ``eps`` is ``1e-12`` for + ``"float64"`` and ``1e-6`` for ``"float32"``. + :param syndrome_data: Observed syndromes, numpy.ndarray of shape (num_shots, num_checks) + :param observable_flips: Observed logical flips, bool array of length num_shots + :param check_inds: (optional) List of check index names; defaults track the parent decoder. + :param error_inds: (optional) List of error index names; defaults track the parent decoder. + :param logical_inds: (optional) List of logical index names; defaults track the parent decoder. + :param logical_tags: (optional) List of logical tags; defaults track the parent decoder. + :param dtype: (str, optional) ``"float32"`` (default) or ``"float64"``; + other values raise ``ValueError``. + :param device: (str, optional) Torch device, e.g. ``"cpu"`` or ``"cuda"`` (default: ``"cuda"``) + :param compile: (bool, optional, keyword-only) If ``True``, wrap the forward + and loss in :func:`torch.compile`. Most useful with + ``execute="codegen"``. Defaults to ``False``. + :param execute: (str, optional, keyword-only) Forward backend. ``"codegen"`` + (default) partial-evaluates the contraction path into a flat + Python function with named locals; ``"unrolled"`` keeps an + interpretive einsum list; ``"opt_einsum"`` dispatches via + :func:`opt_einsum.contract_expression`. + :param compile_mode: (str, optional, keyword-only) Forwarded to + :func:`torch.compile` (e.g. ``"reduce-overhead"``, + ``"default"``); ignored when ``compile=False``. + :param dynamic_syndromes: (bool, optional, keyword-only) If ``True`` + (default), syndromes are runtime arguments to the + compiled forward, so :meth:`update_dataset` reuses + the codegen/``torch.compile`` artifact when shapes + are unchanged. ``False`` bakes syndromes into the + closure -- faster per call but every + :meth:`update_dataset` rebuilds the graph. Only + affects ``execute="codegen"``. + + **Attributes** + + .. attribute:: noise_params + + ``list[torch.Tensor]`` — the learnable noise-probability tensors; pass + directly to a ``torch.optim`` optimizer. + + .. attribute:: torch_device + + ``torch.device`` derived from the ``device`` constructor argument. + Read-only. + + .. attribute:: observable_flips + + Bool ``torch.Tensor`` of logical flip outcomes for the current + syndrome batch. Assigning a new value also rebuilds the fused + loss closure (the observable indices are baked into the codegen); + prefer :meth:`update_dataset` when swapping syndromes and flips + together. + + **Methods** + + .. method:: current_syndrome_args() + + Return the syndrome argument expected by the callable from + :meth:`loss_fn`: the live tuple when ``dynamic_syndromes=True``, + or ``()`` for static codegen (syndromes are closure-baked). + Re-fetch each step so an intervening :meth:`update_dataset` is + reflected. + + :returns: ``tuple[torch.Tensor, ...]`` + + .. method:: cross_entropy_loss() + + Compute the cross-entropy loss between the predicted logical-flip + probabilities and the observed ``observable_flips``. + + :returns: Scalar ``torch.Tensor`` (differentiable). + + .. method:: decoder_prediction() + + Run the forward pass and return per-shot probabilities. + + :returns: ``torch.Tensor`` of shape ``(num_shots, 2)`` where column 1 + is ``P(logical flip | syndrome)``. + + .. method:: logical_error_rate() + + Fraction of shots where ``argmax`` of :meth:`decoder_prediction` + disagrees with :attr:`observable_flips`. Not differentiable + (runs under :func:`torch.no_grad`). + + :returns: ``float`` in ``[0, 1]``. + + .. method:: loss_fn(from_logits=True) + + Return a compiled callable ``fn(params, syndrome_tuple) -> loss`` + suitable for use with external optimizers or ``torch.compile``. + + :param from_logits: If ``True`` (default), ``params`` are interpreted + as logits and passed through ``sigmoid`` before + contraction. If ``False``, ``params`` are + interpreted as probabilities already in ``[0, 1]``. + :returns: Compiled loss function. + + .. method:: optimize_path(optimize=None, batch_size=-1) + + Cache a contraction path via quimb / opt_einsum and rebuild the + compiled forward. Pass e.g. ``cotengra.HyperOptimizer()`` to run a + more expensive path search; ``None`` falls back to ``"auto"``. + + :param optimize: Optimization options (e.g. a ``cotengra.HyperOptimizer``) + or ``None``. + :param batch_size: Accepted for signature compatibility; ignored. + :returns: Contraction info object. + + .. method:: update_dataset(syndrome_data, observable_flips, enforce_shape=True) + + Swap in a new syndrome batch without rebuilding the tensor network. + If ``dynamic_syndromes=True`` and the batch size is unchanged, the + compiled contraction path is reused; a shape change triggers a full + rebuild. + + :param syndrome_data: numpy.ndarray of shape (num_shots, num_checks) + :param observable_flips: bool array of length num_shots + :param enforce_shape: (bool, optional, default ``True``) Assert + per-tensor shapes match the existing layout + before patching in place. A batch-size change + triggers a full rebuild regardless. + +.. function:: cudaq_qec.plugins.decoders.tensor_network_decoder.make_compiled_step(optimizer, logits, torch_optimizer) + + Build a no-arg callable that runs one Adam step and returns the loss. + + The returned ``step()`` callable zeros gradients, evaluates the + optimizer's fused ``loss_fn(from_logits=True)`` (sigmoid + contraction + + cross-entropy), backpropagates, and steps ``torch_optimizer``. Intended + for training in logit space; pair with :class:`NMOptimizer` constructed + with ``compile=True`` for a ``torch.compile``-d variant. + + :param optimizer: An :class:`NMOptimizer` instance providing the fused + inner loss. + :param logits: Trainable 1-D ``torch.Tensor`` of length + ``len(optimizer.error_inds)`` with ``requires_grad=True``. + :param torch_optimizer: A ``torch.optim`` instance owning ``logits``. + :returns: A no-arg callable that performs one optimization step and + returns the scalar loss as a ``torch.Tensor``. \ No newline at end of file diff --git a/docs/sphinx/components/qec/introduction.rst b/docs/sphinx/components/qec/introduction.rst index 99756c73..4088795a 100644 --- a/docs/sphinx/components/qec/introduction.rst +++ b/docs/sphinx/components/qec/introduction.rst @@ -899,6 +899,27 @@ The decoder returns the probability that the logical observable has flipped for that this GPU will not be supported by the Tensor Network Decoder when CUDA-Q 0.5.0 is released. +Learning the Noise Model from Data +"""""""""""""""""""""""""""""""""" + +When the true per-error noise rates are unknown (typical of real hardware), +the Tensor Network Decoder ships with ``NMOptimizer``, a differentiable +extension that **fits the noise model directly from observed syndromes and +logical-flip outcomes**. Noise probabilities are held as PyTorch tensors +with ``requires_grad=True``; backpropagating through the tensor-network +contraction yields gradients that any ``torch.optim`` optimizer (Adam, SGD, +etc.) can update. Starting from a uniform initial prior and a few hundred +Adam steps is usually enough to recover the per-error rates and beat a +static-uniform baseline on a held-out batch. + +This is offline -- training happens once on a representative syndrome +dataset, and the learned probabilities can then be used as a standard +static noise model for batch decoding. See +:ref:`tensor_network_decoder_api_python` for the ``NMOptimizer`` API and +the *Learning Noise Models with NMOptimizer* example in +:doc:`../../examples_rst/qec/decoders` for a runnable end-to-end demo on a +Stim repetition-code circuit. + Sliding Window Decoder ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/sphinx/examples/qec/python/noise_learning.py b/docs/sphinx/examples/qec/python/noise_learning.py new file mode 100644 index 00000000..6ab2784f --- /dev/null +++ b/docs/sphinx/examples/qec/python/noise_learning.py @@ -0,0 +1,161 @@ +# ============================================================================ # +# Copyright (c) 2026 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # +import sys +import platform +if platform.machine().lower() in ("arm64", "aarch64"): + print( + "Warning: stim is not supported on manylinux ARM64/aarch64. Skipping this example..." + ) + sys.exit(0) + +if sys.version_info < (3, 11): + print( + "Warning: The tensor network noise learner requires Python 3.11 or higher. Exiting..." + ) + sys.exit(0) + +# [Begin Documentation] +""" +Noise learning with NMOptimizer on a Stim repetition-code circuit. + +This script demonstrates how to use NMOptimizer to fit per-error noise +probabilities to syndrome data sampled from a Stim repetition-code memory +experiment. Starting from uniform initial priors, Adam optimization on +logits drives the cross-entropy loss down toward the true DEM error rates, +and a held-out evaluation compares the learned model's logical error rate +against the static uniform-prior baseline. + +Requirements: + pip install cudaq-qec[tensor-network-decoder] stim beliefmatching +""" + +import numpy as np +import torch +import stim +from beliefmatching.belief_matching import detector_error_model_to_check_matrices + +import cudaq_qec as qec +from cudaq_qec.plugins.decoders.tensor_network_decoder import ( + NMOptimizer, + make_compiled_step, +) + + +def parse_detector_error_model(dem): + matrices = detector_error_model_to_check_matrices(dem) + H = np.zeros(matrices.check_matrix.shape) + matrices.check_matrix.astype(np.float64).toarray(out=H) + L = np.zeros(matrices.observables_matrix.shape) + matrices.observables_matrix.astype(np.float64).toarray(out=L) + priors = [float(p) for p in matrices.priors] + return H, L, priors + + +def main(): + # Asymmetric noise (data 10x measurement) so the uniform initial + # prior is meaningfully wrong and the optimizer has signal to + # learn; with symmetric noise, uniform is already near-optimal. + circuit = stim.Circuit.generated( + "repetition_code:memory", + rounds=5, + distance=3, + before_round_data_depolarization=0.05, + before_measure_flip_probability=0.005, + ) + dem = circuit.detector_error_model(decompose_errors=True) + H, L, true_priors = parse_detector_error_model(dem) + true_probs = np.array(true_priors) + n_checks, n_errors = H.shape + + print(f"DEM: {n_checks} checks, {n_errors} errors") + print(f"True priors: mean={true_probs.mean():.4e} " + f"min={true_probs.min():.4e} max={true_probs.max():.4e} " + f"(spread {true_probs.max() / true_probs.min():.1f}x)") + + num_shots = 1000 + sampler = circuit.compile_detector_sampler() + det_events, obs_flips = sampler.sample(num_shots, separate_observables=True) + det_events = det_events.astype(float) + obs_flips = obs_flips.ravel().astype(bool) + + uniform = float(true_probs.mean()) + opt = NMOptimizer(H, + L, [uniform] * n_errors, + det_events, + obs_flips, + dtype="float64") + + # Optimize in logit space — numerically stabler than raw probs. + def _to_logits(p): + p = np.clip(p, 1e-7, 1 - 1e-7) + return -np.log(1.0 / p - 1.0) + + logits = torch.tensor( + _to_logits(np.full(n_errors, uniform)), + dtype=torch.float64, + device=opt.torch_device, + requires_grad=True, + ) + adam = torch.optim.Adam([logits], lr=1e-2) + step_fn = make_compiled_step(opt, logits, adam) + + iters = 300 + losses = [float(step_fn().detach().cpu()) for _ in range(iters)] + learned = torch.sigmoid(logits).detach().cpu().numpy() + + print(f"Loss: {losses[0]:.2f} -> {losses[-1]:.2f} " + f"({iters} Adam steps)") + print(f"True priors: mean={true_probs.mean():.4e} " + f"min={true_probs.min():.4e} max={true_probs.max():.4e}") + print(f"Learned priors: mean={learned.mean():.4e} " + f"min={learned.min():.4e} max={learned.max():.4e}") + + if losses[-1] >= losses[0]: + raise RuntimeError(f"Training did not reduce loss at all: " + f"{losses[0]:.2f} -> {losses[-1]:.2f}") + + # Held-out LER comparison is the real gate: a noise model is only + # useful if it decodes better than uniform priors. 20k shots keeps + # the per-run std of the (static - learned) difference around 0.001, + # so the +0.002 gate sits many sigmas below the expected gain even + # without a fixed RNG seed. + num_test = 20000 + test_events, test_flips = sampler.sample(num_test, + separate_observables=True) + test_events = test_events.astype(float) + test_flips_bool = test_flips.ravel().astype(bool) + + def _ler(noise: list[float]) -> float: + decoder = qec.get_decoder( + "tensor_network_decoder", + H, + logical_obs=L, + noise_model=noise, + contract_noise_model=True, + ) + res = decoder.decode_batch(test_events) + pred = np.array([r.result[0] > 0.5 for r in res], dtype=bool) + return float(np.mean(pred != test_flips_bool)) + + ler_static = _ler([uniform] * n_errors) + ler_learned = _ler(learned.tolist()) + + print(f"LER (static uniform priors): {ler_static:.4f} ({num_test} shots)") + print( + f"LER (learned priors): {ler_learned:.4f} ({num_test} shots)") + print(f"Absolute improvement: {ler_static - ler_learned:+.4f}") + + min_improvement = 0.002 + if ler_static - ler_learned < min_improvement: + raise RuntimeError( + f"Learned LER ({ler_learned:.4f}) did not beat the static " + f"baseline ({ler_static:.4f}) by at least {min_improvement:.4f}.") + + +if __name__ == "__main__": + main() diff --git a/docs/sphinx/examples_rst/qec/decoders.rst b/docs/sphinx/examples_rst/qec/decoders.rst index fe240ed1..c86f0a51 100644 --- a/docs/sphinx/examples_rst/qec/decoders.rst +++ b/docs/sphinx/examples_rst/qec/decoders.rst @@ -136,6 +136,31 @@ See Also: - ``cudaq_qec.plugins.decoders.tensor_network_decoder`` +Learning Noise Models with NMOptimizer ++++++++++++++++++++++++++++++++++++++++ + +:class:`~cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer` extends +the Tensor Network Decoder with differentiable noise probabilities. Given a +batch of observed syndromes and logical-flip outcomes, it fits per-error noise +rates by backpropagating through the tensor-network contraction using PyTorch. + +The following example builds a distance-3 repetition-code circuit with +**asymmetric** noise (data-qubit depolarization is 10x measurement-flip +probability), samples syndromes from Stim, and trains +:class:`NMOptimizer` from a uniform initial prior with 300 Adam steps in +logit space. It then compares the **logical error rate (LER)** of the +learned noise model against a static uniform-prior baseline on a 20k-shot +held-out batch — demonstrating that fitting per-error rates from data +decodes meaningfully better than assuming uniform noise: + +.. literalinclude:: ../../examples/qec/python/noise_learning.py + :language: python + :start-after: [Begin Documentation] + +See Also: + +- :ref:`tensor_network_decoder_api_python` + .. _deploying-ai-decoders: Deploying AI Decoders with TensorRT diff --git a/libs/qec/pyproject.toml.cu12 b/libs/qec/pyproject.toml.cu12 index 961bf938..ec0bed7d 100644 --- a/libs/qec/pyproject.toml.cu12 +++ b/libs/qec/pyproject.toml.cu12 @@ -58,6 +58,7 @@ write_to = "_version.py" tensor_network_decoder = [ "quimb", "opt_einsum", + "cotengra", "torch", "cuquantum-python-cu12>=26.3.0" ] @@ -67,6 +68,7 @@ trt_decoder = [ all = [ "quimb", "opt_einsum", + "cotengra", "torch", "cuquantum-python-cu12>=26.3.0", "tensorrt-cu12; platform_machine == 'x86_64'" diff --git a/libs/qec/pyproject.toml.cu13 b/libs/qec/pyproject.toml.cu13 index 2c23e661..5bd39d6a 100644 --- a/libs/qec/pyproject.toml.cu13 +++ b/libs/qec/pyproject.toml.cu13 @@ -58,6 +58,7 @@ write_to = "_version.py" tensor_network_decoder = [ "quimb", "opt_einsum", + "cotengra", "torch>=2.9.0", "cuquantum-python-cu13>=26.3.0" ] @@ -67,6 +68,7 @@ trt_decoder = [ all = [ "quimb", "opt_einsum", + "cotengra", "torch>=2.9.0", "cuquantum-python-cu13>=26.3.0", "tensorrt-cu13" diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py index 546c7b20..fdd71ab6 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py @@ -536,3 +536,17 @@ def _set_tensor_type(self, tn: TensorNetwork) -> None: dtype=to_backend_dtype(self._dtype, like=self.contractor_config.backend), )) + + +# Re-export the noise learner so callers can use the same module path +# as :class:`TensorNetworkDecoder`. Imported at the bottom because +# :mod:`noise_models` subclasses :class:`TensorNetworkDecoder`. +from .tensor_network_utils.noise_models import ( # noqa: E402 + NMOptimizer, make_compiled_step, +) + +__all__ = [ + "TensorNetworkDecoder", + "NMOptimizer", + "make_compiled_step", +] diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py index fc6bb900..1fb4346e 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py @@ -5,13 +5,44 @@ # This source code and the accompanying materials are made available under # # the terms of the Apache License 2.0 which accompanies this distribution. # # ============================================================================ # +"""Noise-model tensor-network builders and differentiable noise learning. + +Static noise models +------------------- +:func:`factorized_noise_model` and :func:`error_pairs_noise_model` return +:class:`quimb.tensor.TensorNetwork` objects whose open indices match the +error indices of the parent decoder. The networks are combined with the +code / logical / syndrome tensor networks inside +:class:`TensorNetworkDecoder`. + +Differentiable noise learning +------------------------------ +:class:`NMOptimizer` fits a factorised per-error noise model to a +syndrome dataset by backpropagating through a torch-backed tensor-network +contraction. :func:`make_compiled_step` is a convenience factory that +builds a no-arg callable for one Adam step in logit space. +""" from __future__ import annotations -from typing import Any +import warnings +from typing import Any, Literal import numpy as np +import numpy.typing as npt +import opt_einsum as oe +import torch from quimb import oset -from quimb.tensor import TensorNetwork, Tensor +from quimb.tensor import Tensor, TensorNetwork + +from ..tensor_network_decoder import TensorNetworkDecoder +from .tensor_network_factory import ( + tensor_network_from_syndrome_batch, + prepare_syndrome_data_batch, +) + +# --------------------------------------------------------------------------- +# Static noise-model builders +# --------------------------------------------------------------------------- def factorized_noise_model( @@ -24,7 +55,7 @@ def factorized_noise_model( Args: error_indices (list[str]): list of error index names. error_probabilities (Union[list[float], np.ndarray]): list or array of error probabilities for each error index. - tensors_tags (Optional[list[str]], optional): list of tags for each tensor. If None, default tags are used. + tensors_tags (list[str] | None, optional): list of tags for each tensor. If None, default tags are used. Returns: TensorNetwork: The tensor network representing the factorized noise model. @@ -70,7 +101,7 @@ def error_pairs_noise_model( Args: error_index_pairs (list[tuple[str, str]]): list of pairs of error index names. error_probabilities (list[np.ndarray]): list of 2x2 probability matrices for each error pair. - tensors_tags (Optional[list[str]], optional): list of tags for each tensor. If None, default tags are used. + tensors_tags (list[str] | None, optional): list of tags for each tensor. If None, default tags are used. Returns: TensorNetwork: The tensor network representing the error pairs noise model. @@ -103,3 +134,1169 @@ def error_pairs_noise_model( tags=oset([etag]), )) return TensorNetwork(tensors) + + +# --------------------------------------------------------------------------- +# Differentiable noise learning +# --------------------------------------------------------------------------- + +_ASCII_POOL = ("abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ") + +# Coarse for fp32 because ``1.0 - 1e-12`` rounds back to ``1.0``. +_PRIOR_EPS_BY_DTYPE: dict[str, float] = { + "float64": 1e-12, + "float32": 1e-6, +} +_SUPPORTED_DTYPES: tuple[str, ...] = ("float32", "float64") + + +def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]: + """Validate noise priors and clamp them into ``[eps, 1 - eps]``. + + The fused cross-entropy reduction in + :meth:`NMOptimizer.cross_entropy_loss` has no ``log`` guard, so a + prior of exactly ``0.0`` or ``1.0`` makes the contraction emit a + zero whose log is ``-inf`` and whose gradient is ``NaN``; training + silently diverges. Stim DEMs occasionally emit ``p=1.0`` + (deterministic detectors) or ``p<1e-15`` (underflow), so we + intercept here rather than force every caller to clamp. + + Behaviour mirrors :class:`torch.nn.BCELoss`-style stable wrappers: + + * Non-finite priors (``NaN`` / ``+/-inf``) raise ``ValueError`` - + these indicate caller bugs, not numerical fragility, and + silently coercing them would hide the real problem. + * Out-of-range priors (``p <= eps`` or ``p >= 1 - eps``) are + clamped into ``[eps, 1 - eps]`` and a single ``UserWarning`` + summarises the number of values changed. + * In-range priors pass through unchanged with no warning. + + Args: + noise_model: array-like of priors, length ``num_errors``. + dtype: contraction dtype string (``"float32"`` / ``"float64"``). + + Returns: + A plain ``list[float]`` so the base + :class:`TensorNetworkDecoder` keeps using its existing + list-based factorised noise model unchanged. + """ + arr = np.asarray(noise_model, dtype=np.float64) + if arr.ndim != 1: + raise ValueError(f"noise_model must be 1-D; got shape {arr.shape}") + if not np.all(np.isfinite(arr)): + bad = np.where(~np.isfinite(arr))[0] + raise ValueError( + f"All priors must be finite; got non-finite values at error " + f"indices {bad.tolist()}: {arr[bad].tolist()}") + + dtype_str = str(dtype) + if dtype_str not in _PRIOR_EPS_BY_DTYPE: + raise ValueError(f"Unsupported dtype {dtype_str!r}; " + f"expected one of {sorted(_PRIOR_EPS_BY_DTYPE)}.") + eps = _PRIOR_EPS_BY_DTYPE[dtype_str] + out_of_range = (arr < eps) | (arr > 1.0 - eps) + if np.any(out_of_range): + warnings.warn( + f"Clamped {int(out_of_range.sum())}/{len(arr)} NMOptimizer " + f"priors into [{eps}, {1.0 - eps}] for numerical stability; " + f"values at or outside the (0, 1) boundary produce -inf " + f"cross-entropy loss and NaN gradients in the fused codegen.", + UserWarning, + stacklevel=3, + ) + arr = np.clip(arr, eps, 1.0 - eps) + return arr.tolist() + + +def _remap_eq_to_ascii(eq: str) -> str: + """Rewrite an einsum equation so every label is in ``[a-zA-Z]``. + + Needed because :mod:`opt_einsum` falls back to non-ASCII unicode + labels once the total index count exceeds 52, but + :func:`torch.einsum` rejects them. Raises if a single step has more + than 52 distinct labels. + """ + if eq.isascii(): + return eq + if "->" in eq: + lhs, rhs = eq.split("->") + else: + lhs, rhs = eq, None + + mapping: dict[str, str] = {} + out_lhs_chars: list[str] = [] + for c in lhs: + if c == ",": + out_lhs_chars.append(c) + continue + if c not in mapping: + if len(mapping) >= len(_ASCII_POOL): + raise ValueError( + f"Einsum step '{eq}' has more than {len(_ASCII_POOL)} " + "distinct labels; cannot remap to ASCII.") + mapping[c] = _ASCII_POOL[len(mapping)] + out_lhs_chars.append(mapping[c]) + out_lhs = "".join(out_lhs_chars) + if rhs is None: + return out_lhs + out_rhs_chars: list[str] = [] + for c in rhs: + if c not in mapping: + raise ValueError( + f"Einsum step '{eq}' has output label {c!r} not present " + "on the LHS; cannot remap.") + out_rhs_chars.append(mapping[c]) + return f"{out_lhs}->{''.join(out_rhs_chars)}" + + +class NMOptimizer(TensorNetworkDecoder): + """Differentiable noise-model optimiser for the TN decoder. + + The factorised noise probabilities live in the torch autograd graph + and are fit to a fixed syndrome batch by minimising the cross-entropy + of the decoder's logical prediction against the observed flips. + + The forward pass is materialised once at construction and reused + across optimisation steps. Optionally call :meth:`optimize_path` + (e.g. with ``cotengra.HyperOptimizer()``) to pin a better contraction + path; the JIT is rebuilt automatically. + + .. warning:: + + Priors are clamped into ``[eps, 1 - eps]`` only at construction; + an unconstrained optimiser step on :attr:`noise_params` can push + them past the boundary, after which :meth:`cross_entropy_loss` + returns ``NaN`` gradients. Prefer logit-space training via + :func:`make_compiled_step` (shown below), or clamp the tensor + under :func:`torch.no_grad` after each step. + + Args: + H: Parity check matrix, shape ``(num_checks, num_errors)``. + logical_obs: Logical observable matrix, shape ``(1, num_errors)``. + noise_model: Initial per-error probabilities, length ``num_errors``. + Each value must be strictly in ``(0, 1)``; values at or + outside the boundary (``p <= eps`` or ``p >= 1 - eps``, + with ``eps`` dtype-dependent) are auto-clamped at + construction with a :class:`UserWarning`. Non-finite + priors raise :class:`ValueError`. + syndrome_data: Syndrome batch, shape ``(shots, num_checks)``. + observable_flips: Observable flip outcomes, shape ``(shots,)``. + check_inds, error_inds, logical_inds, logical_tags: Optional index + and tag names; defaults track the parent decoder. + dtype: Tensor data type (e.g. ``"float32"``). + device: ``"cuda"`` (default) or ``"cpu"``. + compile: If ``True``, wrap the forward in :func:`torch.compile`. + Most useful with ``execute="codegen"``. + execute: Forward backend. ``"codegen"`` (default) partial-evaluates + the path into a flat Python function; ``"unrolled"`` keeps an + interpretive einsum list; ``"opt_einsum"`` dispatches via + :func:`opt_einsum.contract_expression`. + compile_mode: Forwarded to :func:`torch.compile`; ignored when + ``compile=False``. + dynamic_syndromes: If ``True`` (default), syndromes are runtime + arguments to the compiled forward, so :meth:`update_dataset` + does not retrigger codegen / ``torch.compile`` (provided + shapes are unchanged). ``False`` bakes the syndromes into + the closure as constants -- fewer runtime einsums, but every + :meth:`update_dataset` call rebuilds the graph. Only affects + ``execute="codegen"``. + + Example (logit-space, no clamping needed):: + + opt = NMOptimizer(H, logical_obs, priors, + syndrome_data, obs_flips) + opt.optimize_path(optimize=ctg.HyperOptimizer()) + logits = torch.logit(opt.noise_params[0].detach()).requires_grad_() + torch_opt = torch.optim.Adam([logits], lr=0.01) + step = make_compiled_step(opt, logits, torch_opt) + for _ in range(100): + loss = step() + """ + + def __init__( + self, + H: npt.NDArray[Any], + logical_obs: npt.NDArray[Any], + noise_model: list[float], + syndrome_data: npt.NDArray[Any], + observable_flips: npt.NDArray[Any], + check_inds: list[str] | None = None, + error_inds: list[str] | None = None, + logical_inds: list[str] | None = None, + logical_tags: list[str] | None = None, + dtype: str = "float32", + device: str = "cuda", + *, + compile: bool = False, + execute: Literal["codegen", "unrolled", "opt_einsum"] = "codegen", + compile_mode: str | None = None, + dynamic_syndromes: bool = True, + ) -> None: + if execute not in ("unrolled", "opt_einsum", "codegen"): + raise ValueError(f"Invalid execute mode: {execute!r}") + if dtype not in _SUPPORTED_DTYPES: + raise ValueError(f"Invalid dtype {dtype!r}; expected one of " + f"{list(_SUPPORTED_DTYPES)}.") + + # Sanitise once so the base TN tensors and ``self._noise_probs`` + # see identical values (see :func:`_validate_and_clamp_priors`). + noise_model = _validate_and_clamp_priors(noise_model, dtype) + + super().__init__( + H, + logical_obs, + noise_model, + check_inds=check_inds, + error_inds=error_inds, + logical_inds=logical_inds, + logical_tags=logical_tags, + contract_noise_model=False, + dtype=dtype, + device=device, + ) + + # Force the torch backend so tensor data lives in the autograd + # graph (the base class would otherwise pick cutensornet/numpy + # on GPU). Contractions still go through codegen / cotengra + # below, not cuTensorNet. + if self.contractor_config.contractor_name == "cutensornet" \ + and self.contractor_config.backend != "torch": + warnings.warn( + "NMOptimizer requires the torch backend for autograd; " + f"switching contractor backend from " + f"{self.contractor_config.backend!r} to 'torch'. " + "Contractions are executed via codegen/opt_einsum, not " + "cuTensorNet.", + stacklevel=3, + ) + self._set_contractor( + "cutensornet", + self.contractor_config.device, + "torch", + dtype, + ) + + # Swap the base's placeholder single-syndrome TN for a batched one. + self._syndrome_tags = [f"SYN_{i}" for i in range(len(self.check_inds))] + self.syndrome_tn = tensor_network_from_syndrome_batch( + syndrome_data, + self.check_inds, + batch_index="batch_index", + tags=self._syndrome_tags, + ) + self._batch_size = syndrome_data.shape[0] + + # Re-stitch ``full_tn`` around the batched syndrome TN. + self.full_tn = TensorNetwork() + self.full_tn = self.full_tn.combine(self.code_tn, virtual=True) + self.full_tn = self.full_tn.combine(self.logical_tn, virtual=True) + self.full_tn = self.full_tn.combine(self.syndrome_tn, virtual=True) + self.full_tn = self.full_tn.combine(self.noise_model, virtual=True) + + self._set_tensor_type(self.syndrome_tn) + + torch_dtype = getattr(torch, self._dtype) + self._noise_probs = torch.tensor( + noise_model, + dtype=torch_dtype, + device=self.torch_device, + requires_grad=True, + ) + # The base's noise tensors stay in ``full_tn`` as numpy + # placeholders: ``_snapshot_arrays_and_eq`` uses ``id()`` to + # locate their positions, then ``self._noise_probs`` (autograd + # live) is written into those slots. Do not strip them. + + self._suspend_loss_rebuild = True + self.observable_flips = observable_flips + + self._use_torch_compile = compile + self._execute_mode = execute + self._torch_compile_mode = compile_mode + self._dynamic_syndromes = dynamic_syndromes + self._compiled_predict: Any | None = None + self._syndrome_tuple: tuple[torch.Tensor, ...] = () + self._snapshot_arrays_and_eq() + self._suspend_loss_rebuild = False + + @property + def torch_device(self) -> torch.device: + """The ``torch.device`` matching the contractor config.""" + if "cuda" in self.contractor_config.device: + return torch.device(f"cuda:{self.contractor_config.device_id}",) + return torch.device("cpu") + + def _set_tensor_type(self, tn: TensorNetwork) -> None: + """Move all tensor data in *tn* to torch on the configured device. + + Overrides the base ``autoray``-routed implementation so gradients + flow through the noise-model tensors. + """ + torch_dtype = getattr(torch, self._dtype) + dev = self.torch_device + + def _to_torch(x): + if isinstance(x, torch.Tensor): + return x.to(device=dev, dtype=torch_dtype) + return torch.tensor( + np.asarray(x), + dtype=torch_dtype, + device=dev, + ) + + tn.apply_to_arrays(_to_torch) + + @property + def observable_flips(self) -> torch.Tensor: + """Boolean tensor of observable flip outcomes.""" + return self._observable_flips + + @observable_flips.setter + def observable_flips(self, value: Any) -> None: + dev = self.torch_device + if not isinstance(value, torch.Tensor): + self._observable_flips = torch.tensor( + value, + dtype=torch.bool, + device=dev, + ) + else: + self._observable_flips = value.bool().to(dev) + self.obs_idx_true = torch.where(self._observable_flips)[0] + self.obs_idx_false = torch.where(~self._observable_flips)[0] + # The fused loss bakes ``obs_idx_true/false`` into its closure + # and must be rebuilt when they change. Skip when a full + # snapshot rebuild is already pending (gated by + # ``_suspend_loss_rebuild``) or before first ``__init__``. + if (getattr(self, "_compiled_predict", None) is not None and + not getattr(self, "_suspend_loss_rebuild", False)): + self._compile_loss() + + @property + def noise_params(self) -> list[torch.Tensor]: + """Trainable noise probabilities, ready for ``torch.optim``. + + Clamped to ``[eps, 1 - eps]`` only at construction; an + unconstrained step can push past the boundary and produce + ``NaN`` gradients on the next :meth:`cross_entropy_loss`. + See the class warning for safe training patterns. + """ + return [self._noise_probs] + + def _snapshot_arrays_and_eq(self) -> None: + self._eq_batch = self.full_tn.get_equation( + output_inds=("batch_index", self.logical_obs_inds[0])) + tensors = list(self.full_tn.tensors) + self._tensors_ref = tensors + + noise_ids = {id(t) for t in self.noise_model.tensors} + syndrome_ids = {id(t) for t in self.syndrome_tn.tensors} + + self._noise_pos_for_error: dict[str, int] = {} + syndrome_positions_list: list[int] = [] + self._static_positions: list[int] = [] + + for i, t in enumerate(tensors): + if id(t) in noise_ids: + self._noise_pos_for_error[t.inds[0]] = i + elif id(t) in syndrome_ids: + syndrome_positions_list.append(i) + else: + self._static_positions.append(i) + + # Guard against a future quimb that copies tensors on virtual + # combine: every tensor in ``full_tn`` must classify into + # exactly one bucket, else the predict path rebuilds the + # operand list with a None slot or a misplaced placeholder. + n_classified = (len(self._noise_pos_for_error) + + len(syndrome_positions_list) + + len(self._static_positions)) + assert n_classified == len(tensors) + assert len(self._noise_pos_for_error) == len(self.error_inds) + + self._syndrome_positions: list[tuple[int, None]] = [ + (i, None) for i in syndrome_positions_list + ] + + self._noise_pos_ordered = tuple( + self._noise_pos_for_error[ei] for ei in self.error_inds) + + torch_dtype = getattr(torch, self._dtype) + dev = self.torch_device + + def _as_torch(x): + if isinstance(x, torch.Tensor): + return x.detach().to(device=dev, dtype=torch_dtype) + return torch.as_tensor(np.asarray(x), dtype=torch_dtype, device=dev) + + self._static_arrays: dict[int, torch.Tensor] = { + i: _as_torch(self._tensors_ref[i].data) + for i in self._static_positions + } + self._syndrome_arrays: list[torch.Tensor] = [ + _as_torch(self._tensors_ref[i].data) + for i in syndrome_positions_list + ] + self._syndrome_tuple = tuple(self._syndrome_arrays) + # Used by :meth:`_update_data` to detect layout changes that + # invalidate the cached path / codegen / oe expr. + self._syndrome_shapes: tuple[tuple[int, ...], ...] = tuple( + tuple(s.shape) for s in self._syndrome_arrays) + + if self._execute_mode == "opt_einsum": + shapes = tuple(t.shape for t in tensors) + self._oe_expr = oe.contract_expression( + self._eq_batch, + *shapes, + optimize=self.path_batch + if self.path_batch not in (None, "auto") else "auto", + ) + self._path_steps = None + else: + self._oe_expr = None + # Flatten the path into ``[(eq, idxs, sorted_desc), ...]``; + # ``sorted_desc`` is precomputed for the unrolled-mode pop + # walk, and labels are remapped to ASCII because + # opt_einsum falls back to unicode past 52 indices and + # torch.einsum rejects those. + shapes = tuple(t.shape for t in tensors) + _, info = oe.contract_path( + self._eq_batch, + *shapes, + shapes=True, + optimize=self.path_batch + if self.path_batch not in (None, "auto") else "auto", + ) + self._path_steps = [(_remap_eq_to_ascii(step[2]), tuple(step[0]), + tuple(sorted(step[0], reverse=True))) + for step in info.contraction_list] + + self._compile_predict() + self._compile_loss() + + def _compile_predict(self) -> None: + """Build ``self._predict_fn`` for the configured execute mode.""" + builders = { + "opt_einsum": self._build_predict_opt_einsum, + "unrolled": self._build_predict_unrolled, + "codegen": self._build_predict_codegen, + } + self._predict_fn = builders[self._execute_mode]() + self._compiled_predict = self._maybe_torch_compile(self._predict_fn, + kind="predict") + + def _build_predict_opt_einsum(self): + """opt_einsum-backed predict: reuse the cached contract expression.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + oe_expr = self._oe_expr + + def _predict(noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: + noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), + dim=-1) + arrays: list[torch.Tensor] = [None] * n # type: ignore + for pos, arr in static_arrays.items(): + arrays[pos] = arr + for pos, arr in zip(syndrome_positions, syndrome_tuple): + arrays[pos] = arr + for k, pos in enumerate(noise_pos_ordered): + arrays[pos] = noise_stacked[k] + # Torch backend is auto-selected from the tensor type; + # avoids the per-call ``backend=`` dispatch. + out = oe_expr(*arrays) + return out / out.sum(dim=1, keepdim=True) + + return _predict + + def _build_predict_unrolled(self): + """Unrolled predict: walk the cached pairwise contraction path.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + path_steps = self._path_steps + + def _predict(noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: + noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), + dim=-1) + ops: list[torch.Tensor] = [None] * n # type: ignore + for pos, arr in static_arrays.items(): + ops[pos] = arr + for pos, arr in zip(syndrome_positions, syndrome_tuple): + ops[pos] = arr + for k, pos in enumerate(noise_pos_ordered): + ops[pos] = noise_stacked[k] + for eq_str, idxs, sorted_idxs in path_steps: + picked = [ops[i] for i in idxs] + for i in sorted_idxs: + ops.pop(i) + ops.append(torch.einsum(eq_str, *picked)) + out = ops[0] + return out / out.sum(dim=1, keepdim=True) + + return _predict + + def _build_predict_codegen(self): + """Codegen predict: partial-eval'd flat Python with named locals.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + syndrome_tensors = list(self._syndrome_arrays) + codegen_fn = self._build_codegen_predict( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + self._path_steps, + syndrome_tensors, + dynamic_syndromes=self._dynamic_syndromes, + ) + self._codegen_fn = codegen_fn + self._codegen_n_folded = getattr(codegen_fn, "_n_folded", 0) + self._codegen_n_runtime = getattr(codegen_fn, "_n_runtime", 0) + + if self._dynamic_syndromes: + return codegen_fn + + # Static mode bakes syndromes into the closure and returns a + # 1-arg callable; wrap to match the public 2-arg signature. + def _predict_static( + noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...] = () + ) -> torch.Tensor: + return codegen_fn(noise_probs) + + return _predict_static + + def _maybe_torch_compile(self, fn, *, kind: str): + """Wrap ``fn`` with :func:`torch.compile` if requested. + + On any compile failure, warn and fall back to eager. ``kind`` + is included in the warning to disambiguate predict vs loss. + """ + if not self._use_torch_compile: + return fn + try: + kwargs = self._torch_compile_kwargs() + return torch.compile(fn, **kwargs) + except Exception as exc: # pragma: no cover + warnings.warn( + f"torch.compile {kind} failed ({exc!r}); " + "falling back to eager.", + RuntimeWarning, + stacklevel=2, + ) + return fn + + def _compile_loss(self) -> None: + """Build the ``(input, syndromes) -> scalar_loss`` callables. + + Two variants are produced: one accepting logits (sigmoid applied + inside) and one accepting probabilities directly. + """ + if self._execute_mode == "codegen": + logits_fn, probs_fn = self._build_loss_codegen() + else: + logits_fn, probs_fn = self._build_loss_wrapped() + + self._loss_from_logits_fn = logits_fn + self._loss_from_probs_fn = probs_fn + self._compiled_loss_from_logits = self._maybe_torch_compile(logits_fn, + kind="loss") + self._compiled_loss_from_probs = self._maybe_torch_compile(probs_fn, + kind="loss") + + def _build_loss_codegen(self): + """Codegen loss: fuse the CE reduction into the contraction graph.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + syndrome_tensors = list(self._syndrome_arrays) + + codegen_logits = self._build_codegen_loss( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + self._path_steps, + syndrome_tensors, + obs_idx_true=self.obs_idx_true, + obs_idx_false=self.obs_idx_false, + dynamic_syndromes=self._dynamic_syndromes, + from_logits=True, + ) + codegen_probs = self._build_codegen_loss( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + self._path_steps, + syndrome_tensors, + obs_idx_true=self.obs_idx_true, + obs_idx_false=self.obs_idx_false, + dynamic_syndromes=self._dynamic_syndromes, + from_logits=False, + ) + + if self._dynamic_syndromes: + return codegen_logits, codegen_probs + + # Static codegen bakes syndromes into the closure and returns a + # 1-arg callable; wrap to match the public 2-arg signature. + def _loss_from_logits_static( + logits: torch.Tensor, syndrome_tuple: tuple[torch.Tensor, ...] = () + ) -> torch.Tensor: + return codegen_logits(logits) + + def _loss_from_probs_static( + noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...] = () + ) -> torch.Tensor: + return codegen_probs(noise_probs) + + return _loss_from_logits_static, _loss_from_probs_static + + def _build_loss_wrapped(self): + """opt_einsum / unrolled loss: wrap CE around ``self._predict_fn``.""" + obs_t = self.obs_idx_true + obs_f = self.obs_idx_false + predict_fn = self._predict_fn + + if self._dynamic_syndromes: + + def _loss_from_probs(noise_probs, syndromes): + p = predict_fn(noise_probs, syndromes) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + + def _loss_from_logits(logits, syndromes): + p = predict_fn(torch.sigmoid(logits), syndromes) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + else: + + def _loss_from_probs(noise_probs, syndromes=()): + p = predict_fn(noise_probs, ()) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + + def _loss_from_logits(logits, syndromes=()): + p = predict_fn(torch.sigmoid(logits), ()) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + + return _loss_from_logits, _loss_from_probs + + def _torch_compile_kwargs(self) -> dict[str, Any]: + """Build kwargs for :func:`torch.compile`. + + Defaults to ``mode="reduce-overhead"`` on CUDA so kernel-launch + overhead is amortised via CUDA Graphs; a ``compile_mode=...`` + passed to the constructor overrides this. + """ + kwargs: dict[str, Any] = {"dynamic": False} + if self._torch_compile_mode is not None: + kwargs["mode"] = self._torch_compile_mode + elif self.torch_device.type == "cuda": + kwargs["mode"] = "reduce-overhead" + return kwargs + + @staticmethod + def _codegen_partial_eval(n, static_arrays, syndrome_positions, + noise_pos_ordered, path_steps, syndrome_tensors, + dynamic_syndromes: bool): + """Partial-evaluate ``path_steps``; return the codegen building blocks. + + Steps whose inputs are all static are evaluated eagerly under + ``torch.no_grad`` and become closure constants; the remaining + steps become source lines. + + Returns ``(runtime_lines, closure_vars, used_static, final_state, + n_folded)``: emitted source lines, name -> tensor map for the + function namespace, the subset of names actually referenced, the + single surviving state slot ``(name, is_dynamic, value_or_None)``, + and the count of folded steps. + """ + static_positions = sorted(static_arrays.keys()) + noise_pos_set = set(noise_pos_ordered) + syn_pos_set = set(syndrome_positions) + # O(1) reverse lookup tables (avoid repeated list.index() inside + # the per-step loop below — was O(N^2) for large path lengths). + noise_pos_to_k = {pos: k for k, pos in enumerate(noise_pos_ordered)} + syn_pos_to_sidx = { + pos: sidx for sidx, pos in enumerate(syndrome_positions) + } + static_pos_to_sidx = { + pos: sidx for sidx, pos in enumerate(static_positions) + } + + # state[pos] = (var_name, is_dynamic, concrete_value_or_None) + state: list[tuple[str, bool, torch.Tensor | None]] = [] + for pos in range(n): + if pos in noise_pos_set: + k = noise_pos_to_k[pos] + state.append((f"_n{k}", True, None)) + elif pos in syn_pos_set: + sidx = syn_pos_to_sidx[pos] + if dynamic_syndromes: + state.append((f"_S{sidx}", True, None)) + else: + state.append((f"_S{sidx}", False, syndrome_tensors[sidx])) + else: + sidx = static_pos_to_sidx[pos] + state.append( + (f"_C{sidx}", False, static_arrays[static_positions[sidx]])) + + closure_vars: dict[str, torch.Tensor] = {} + runtime_lines: list[str] = [] + # Names that the emitted source actually references and that must + # be available in the closure namespace. We track this + # structurally as we go instead of re-parsing the source string, + # which is both faster and immune to lexical false positives / + # negatives (e.g. if an einsum equation contained an underscore). + used_static: set[str] = set() + n_folded = 0 + + for step_idx, step in enumerate(path_steps): + eq_str = step[0] + idxs = step[1] + picked = [state[i] for i in idxs] + for i in sorted(idxs, reverse=True): + state.pop(i) + any_dynamic = any(p[1] for p in picked) + out_name = f"_r{step_idx}" + if not any_dynamic: + arrs = [p[2] for p in picked] + with torch.no_grad(): + result = torch.einsum(eq_str, *arrs).contiguous() + static_name = f"_P{step_idx}" + closure_vars[static_name] = result + state.append((static_name, False, result)) + n_folded += 1 + else: + arg_names = [p[0] for p in picked] + # Track which closure names this line will read. ``_n*`` + # names are header-built locals (from noise_probs) so + # they are *not* closure values; everything else is. + for name in arg_names: + if name.startswith(("_C", "_P")): + used_static.add(name) + elif name.startswith("_S") and not dynamic_syndromes: + used_static.add(name) + runtime_lines.append( + f" {out_name} = torch.einsum({eq_str!r}, " + f"{', '.join(arg_names)})") + state.append((out_name, True, None)) + + assert len(state) == 1 + for name in used_static: + if name in closure_vars: + continue + if name.startswith("_C"): + sidx = int(name[2:]) + closure_vars[name] = static_arrays[static_positions[sidx]] + elif name.startswith("_S"): # static-syndromes mode only + sidx = int(name[2:]) + closure_vars[name] = syndrome_tensors[sidx] + + return runtime_lines, closure_vars, used_static, state[0], n_folded + + @staticmethod + def _emit_noise_header(noise_pos_ordered, + transform: str = "identity") -> list[str]: + """Emit source lines materialising ``_n0 .. _n{K-1}``. + + ``transform="identity"`` treats the input as probabilities; + ``"sigmoid"`` treats it as logits and applies ``torch.sigmoid`` + first. A single ``(K, 2)`` stack is built and then sliced, which + keeps the autograd graph compact. + """ + lines: list[str] = [] + if transform == "sigmoid": + lines.append(" _p = torch.sigmoid(noise_probs)") + else: + lines.append(" _p = noise_probs") + lines.append(" _q = 1.0 - _p") + # One stack of shape (K, 2) instead of K stacks of shape (2,). + # ``dim=1`` makes ``_NS[k]`` a contiguous view of length 2. + lines.append(" _NS = torch.stack((_q, _p), dim=1)") + for k in range(len(noise_pos_ordered)): + lines.append(f" _n{k} = _NS[{k}]") + return lines + + @staticmethod + def _emit_syndrome_header(syndrome_positions, + dynamic_syndromes: bool) -> list[str]: + """Emit source lines binding ``_S0 .. _S{S-1}`` to runtime + ``syndromes`` arguments; empty in static mode.""" + if not dynamic_syndromes: + return [] + return [ + f" _S{sidx} = syndromes[{sidx}]" + for sidx in range(len(syndrome_positions)) + ] + + @classmethod + def _build_codegen_predict(cls, + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + dynamic_syndromes: bool = True): + """Generate ``_predict(noise_probs[, syndromes]) -> (shots, 2)``. + + With ``dynamic_syndromes=False`` syndromes are folded into the + closure, which maximises partial evaluation but forces a rebuild + on every :meth:`update_dataset` call. With ``True`` (default) + syndromes stay runtime arguments and dataset swaps are free. + """ + runtime_lines, closure_vars, _used, final_state, n_folded = ( + cls._codegen_partial_eval( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + dynamic_syndromes, + )) + final_name, is_final_dyn, final_value = final_state + fully_static = not is_final_dyn + + body: list[str] = [] + if dynamic_syndromes: + body.append("def _predict(noise_probs, syndromes):") + else: + body.append("def _predict(noise_probs):") + + if fully_static: + # Degenerate case: contraction didn't depend on noise (or any + # other dynamic input). Pre-normalise the constant and + # return it unchanged on every call. + with torch.no_grad(): + normed = final_value / final_value.sum(dim=1, keepdim=True) + closure_vars["_FINAL"] = normed + body.append(" return _FINAL") + runtime_lines = [] + else: + body.extend( + cls._emit_noise_header(noise_pos_ordered, transform="identity")) + body.extend( + cls._emit_syndrome_header(syndrome_positions, + dynamic_syndromes)) + body.extend(runtime_lines) + body.append(f" _out = {final_name}") + body.append(" return _out / _out.sum(dim=1, keepdim=True)") + + return cls._compile_codegen_source(body, closure_vars, n_folded, + len(runtime_lines), "predict") + + @classmethod + def _build_codegen_loss(cls, + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + obs_idx_true: torch.Tensor, + obs_idx_false: torch.Tensor, + dynamic_syndromes: bool = True, + from_logits: bool = True): + """Generate a fused ``(input, syndromes) -> scalar`` loss callable. + + Pipes the contraction output straight into the cross-entropy + reduction so the whole pipeline (optional sigmoid, contraction, + normalisation, cross-entropy) is a single autograd graph. + + Args: + from_logits: If ``True`` (default), apply ``torch.sigmoid`` to + the input; if ``False``, the input must already be in + ``[0, 1]``. + """ + runtime_lines, closure_vars, _used, final_state, n_folded = ( + cls._codegen_partial_eval( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + dynamic_syndromes, + )) + final_name, is_final_dyn, final_value = final_state + fully_static = not is_final_dyn + + closure_vars["_OBS_T"] = obs_idx_true + closure_vars["_OBS_F"] = obs_idx_false + + body: list[str] = [] + if dynamic_syndromes: + body.append("def _loss(noise_probs, syndromes):") + else: + body.append("def _loss(noise_probs):") + + if fully_static: + # The contraction is a constant; the loss it produces is + # also a constant. We still need the gradient wrt + # noise_probs to be zero (a real-valued zero tensor with a + # graph edge), so emit a no-op multiplication by + # noise_probs.sum() * 0 to keep autograd happy. + with torch.no_grad(): + normed = final_value / final_value.sum(dim=1, keepdim=True) + # Compute the loss eagerly; we can't fold it because + # autograd needs a path back to noise_probs. + ce = (-torch.log(normed[obs_idx_true, 1]).sum() - + torch.log(normed[obs_idx_false, 0]).sum()) + closure_vars["_LOSS"] = ce + body.append(" return _LOSS + 0.0 * noise_probs.sum()") + runtime_lines = [] + else: + transform = "sigmoid" if from_logits else "identity" + body.extend(cls._emit_noise_header(noise_pos_ordered, transform)) + body.extend( + cls._emit_syndrome_header(syndrome_positions, + dynamic_syndromes)) + body.extend(runtime_lines) + # Fused cross-entropy that skips the explicit + # ``_preds = _out / _out.sum(dim=1, keepdim=True)`` step. + # OBS_T and OBS_F partition the batch (every row is in exactly + # one), so: + # -log(p_T[:,1]).sum() - log(p_F[:,0]).sum() + # = log(Z).sum() - log(_out[OBS_T,1]).sum() + # - log(_out[OBS_F,0]).sum() + # where Z = _out[:,0] + _out[:,1]. Saves one (shots, 2) + # division + materialisation and the corresponding backward + # nodes -- ~2-5% per-step on CPU at d=3/r=3. + body.append(f" _out = {final_name}") + body.append(" _z0 = _out[:, 0]") + body.append(" _z1 = _out[:, 1]") + body.append(" return (torch.log(_z0 + _z1).sum() " + "- torch.log(_z1[_OBS_T]).sum() " + "- torch.log(_z0[_OBS_F]).sum())") + + return cls._compile_codegen_source(body, closure_vars, n_folded, + len(runtime_lines), "loss") + + @staticmethod + def _compile_codegen_source(body: list[str], + closure_vars: dict[str, torch.Tensor], + n_folded: int, n_runtime: int, kind: str): + """Compile the assembled function source and return the callable.""" + source = "\n".join(body) + ns: dict[str, Any] = {"torch": torch} + ns.update(closure_vars) + fn_name = "_loss" if kind == "loss" else "_predict" + exec(compile(source, f"", "exec"), ns) + fn = ns[fn_name] + fn._n_folded = n_folded # type: ignore[attr-defined] + fn._n_runtime = n_runtime # type: ignore[attr-defined] + return fn + + def decoder_prediction(self) -> torch.Tensor: + """Run the forward pass; returns ``(shots, 2)`` predictions.""" + return self._compiled_predict(self._noise_probs, self._syndrome_tuple) + + def cross_entropy_loss(self) -> torch.Tensor: + """Cross-entropy loss over the syndrome batch. + + Returns a differentiable scalar; call ``.backward()`` to obtain + gradients w.r.t. :attr:`noise_params`. The fused codegen omits + the ``log`` guard, so a prior at ``0`` or ``1`` yields ``NaN`` + gradients — see :attr:`noise_params` for safe training patterns. + """ + return self._compiled_loss_from_probs(self._noise_probs, + self._syndrome_tuple) + + def current_syndrome_args(self) -> tuple[torch.Tensor, ...]: + """Return the syndrome argument expected by :meth:`loss_fn`. + + Returns ``()`` when syndromes are baked into the closure + (``execute="codegen"`` and ``dynamic_syndromes=False``), else + the current live tuple. Re-fetch each step so an intervening + :meth:`update_dataset` is reflected. + """ + if self._execute_mode == "codegen" and not self._dynamic_syndromes: + return () + return self._syndrome_tuple + + def loss_fn(self, from_logits: bool = True): + """Return a fused ``(input, syndromes) -> scalar`` loss callable. + + Useful when training in logit space (``from_logits=True``, the + default) or when feeding in an externally managed probability + tensor (``from_logits=False``). Compared to + :meth:`cross_entropy_loss`, the parameter is supplied explicitly + per call instead of being read from :attr:`noise_params`. + """ + return (self._compiled_loss_from_logits + if from_logits else self._compiled_loss_from_probs) + + def logical_error_rate(self) -> float: + """Fraction of shots decoded incorrectly. + + Uses a hard argmax threshold; **not** differentiable. + """ + with torch.no_grad(): + predictions = self.decoder_prediction() + pred = predictions[:, 1] > predictions[:, 0] + return float(1 - (pred == self._observable_flips).sum() / + len(self._observable_flips)) + + def _update_data(self, + new_syndrome_arrays: torch.Tensor, + new_observable_flips: npt.NDArray[Any], + enforce_shape: bool = True) -> None: + """In-place dataset swap on already-prepared syndrome tensors. + + ``new_syndrome_arrays`` must be in the internal layout (the + output of :func:`prepare_syndrome_data_batch`, on the right + device, shape ``(syndrome_length, shots, 2)``). Public callers + should use :meth:`update_dataset` instead. + """ + # Patch syndrome tensor data in the quimb TN in place; the + # cotengra path is invalidated below if any shape changed. + for i, tag in enumerate(self._syndrome_tags): + t = self.syndrome_tn.tensors[next( + iter(self.syndrome_tn.tag_map[tag]))] + if enforce_shape: + assert t.data.shape == new_syndrome_arrays[i].shape, ( + f"Shape mismatch for {tag}: " + f"{t.data.shape} vs {new_syndrome_arrays[i].shape}") + t.modify(data=new_syndrome_arrays[i]) + + # Suppress the loss rebuild the ``observable_flips`` setter + # would otherwise trigger; one of the branches below issues it. + self._suspend_loss_rebuild = True + self.observable_flips = new_observable_flips + + torch_dtype = getattr(torch, self._dtype) + dev = self.torch_device + new_shapes: list[tuple[int, ...]] = [] + for k, (pos, _tag) in enumerate(self._syndrome_positions): + data = self._tensors_ref[pos].data + if isinstance(data, torch.Tensor): + arr = data.detach().to(device=dev, dtype=torch_dtype) + else: + arr = torch.as_tensor(np.asarray(data), + dtype=torch_dtype, + device=dev) + self._syndrome_arrays[k] = arr + new_shapes.append(tuple(arr.shape)) + new_shapes_tuple = tuple(new_shapes) + + # Shape change: cached path / codegen / oe expression / compile + # guards are all stale. Drop the path and rebuild from scratch. + # Shapes unchanged: dynamic codegen and the unrolled / + # opt_einsum paths read syndromes per call — refreshing the + # cached tuple is enough. Static codegen baked the old tensors + # into the closure and still needs a full rebuild. + shape_changed = new_shapes_tuple != self._syndrome_shapes + if shape_changed: + self.path_batch = None + self.slicing_batch = tuple() + try: + self._snapshot_arrays_and_eq() + finally: + self._suspend_loss_rebuild = False + return + + self._syndrome_tuple = tuple(self._syndrome_arrays) + if self._execute_mode == "codegen" and not self._dynamic_syndromes: + try: + self._snapshot_arrays_and_eq() + finally: + self._suspend_loss_rebuild = False + else: + # The observable indices may have changed; the loss bakes + # them in, so it still needs a rebuild. + self._suspend_loss_rebuild = False + self._compile_loss() + + def update_dataset(self, + new_syndrome_data: npt.NDArray[Any], + new_observable_flips: npt.NDArray[Any], + enforce_shape: bool = True) -> None: + """Replace the syndrome batch and observable flips. + + Args: + new_syndrome_data: Shape ``(shots, num_checks)``. + new_observable_flips: Shape ``(shots,)``. + enforce_shape: Assert that per-tensor shapes match. A + changing batch size triggers a full rebuild of the + cached contraction path and codegen. + """ + syndrome_arrays = prepare_syndrome_data_batch(new_syndrome_data) + torch_dtype = getattr(torch, self._dtype) + syndrome_arrays = torch.tensor( + syndrome_arrays, + dtype=torch_dtype, + device=self.torch_device, + ).transpose(1, 2) + self._batch_size = int(new_syndrome_data.shape[0]) + self._update_data(syndrome_arrays, new_observable_flips, enforce_shape) + + def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: + """Cache a contraction path via quimb and rebuild the JIT. + + Always routes through :meth:`TensorNetwork.contraction_info` so + the resulting path is compatible with :mod:`opt_einsum` and + manual unrolling -- unlike :meth:`TensorNetworkDecoder.optimize_path`, + which defaults to a cuTensorNet-only path. + + ``batch_size`` is part of the parent ``TensorNetworkDecoder`` + signature (which rebuilds its TN around a fake batch); on the + optimiser the syndrome TN is already batched at construction + and resized in :meth:`update_dataset`, so this argument is + ignored. Kept for Liskov substitution with the parent. + """ + del batch_size + info = self.full_tn.contraction_info( + output_inds=("batch_index", self.logical_obs_inds[0]), + optimize=optimize if optimize is not None else "auto", + ) + self.path_batch = info.path + self.slicing_batch = tuple() + self._snapshot_arrays_and_eq() + return info + + +def make_compiled_step(optimizer: NMOptimizer, logits: torch.Tensor, + torch_optimizer: torch.optim.Optimizer): + """Build a no-arg callable that runs one Adam step and returns the loss. + + The step zeros grads, calls the optimizer's compiled + ``loss_fn(from_logits=True)`` (sigmoid + contraction + cross-entropy + fused), backwards, and steps ``torch_optimizer``. Use this when + training in logit space. + + Args: + optimizer: The :class:`NMOptimizer` providing the fused + inner loss; pass ``compile=True`` at the + :class:`NMOptimizer` constructor for the + ``torch.compile``-d variant. + logits: Trainable 1-D tensor of length ``len(optimizer.error_inds)`` + with ``requires_grad=True``. + torch_optimizer: A ``torch.optim`` instance owning ``logits``. + """ + + # Re-fetch per call so a rebuild from update_dataset or the + # observable_flips setter is picked up; capturing would go stale. + def _step(): + torch_optimizer.zero_grad(set_to_none=True) + loss = optimizer.loss_fn(from_logits=True)( + logits, optimizer.current_syndrome_args()) + loss.backward() + torch_optimizer.step() + return loss + + return _step diff --git a/libs/qec/python/tests/test_tn_noise_models.py b/libs/qec/python/tests/test_tn_noise_models.py new file mode 100644 index 00000000..5ff75a21 --- /dev/null +++ b/libs/qec/python/tests/test_tn_noise_models.py @@ -0,0 +1,826 @@ +# ============================================================================ # +# Copyright (c) 2026 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # +"""Tests for :class:`NMOptimizer` and friends. + +Each test is parameterised over ``device in ("cpu", "cuda")``; CUDA +cases are skipped when no GPU is available. +""" + +import sys +import warnings + +import numpy as np +import pytest + +torch = pytest.importorskip( + "torch", reason="torch not installed; skipping TN noise-learning tests") + +import cudaq_qec as qec # noqa: E402 + +if sys.version_info >= (3, 11): + from cudaq_qec.plugins.decoders.tensor_network_utils.noise_models import ( + NMOptimizer, + _PRIOR_EPS_BY_DTYPE, + _remap_eq_to_ascii, + _validate_and_clamp_priors, + make_compiled_step, + ) + +pytestmark = pytest.mark.skipif(sys.version_info < (3, 11), + reason="Requires Python >= 3.11") + + +def _gpu_available() -> bool: + return torch.cuda.is_available() + + +def _device_params(): + """``device`` parametrize values; cuda is skipped when unavailable.""" + out = ["cpu"] + if _gpu_available(): + out.append("cuda") + return out + + +_EXECUTE_MODES = ("codegen", "unrolled", "opt_einsum") + +# -- fixtures / helpers ------------------------------------------------------- + + +def _simple_repetition_code(): + """[[3,1]] repetition-code-like fixture with a single logical observable.""" + H = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.float64) + logical = np.array([[1, 0, 1]], dtype=np.float64) + priors = [0.1, 0.2, 0.3] + return H, logical, priors + + +def _nondegenerate_code(): + """3-error code where ``P(l|s)`` genuinely depends on the priors. + + ``L = [1,1,1]`` is **not** in the GF(2) row span of ``H``, so every + syndrome admits error patterns of both logical values and the + gradient w.r.t. the noise priors is non-trivial. Use this whenever + a test needs to exercise the autograd path; the + :func:`_simple_repetition_code` fixture has ``L`` in ``row(H)``, + which makes ``P(l|s)`` deterministic and zeroes the gradient. + """ + H = np.array([[1, 1, 0], [1, 0, 1]], dtype=np.float64) + logical = np.array([[1, 1, 1]], dtype=np.float64) + return H, logical + + +def _random_code(rng: np.random.Generator, + n_checks: int = 5, + n_errors: int = 8): + H = rng.integers(0, 2, size=(n_checks, n_errors)).astype(np.float64) + logical = rng.integers(0, 2, size=(1, n_errors)).astype(np.float64) + priors = rng.uniform(0.02, 0.2, size=n_errors).astype(np.float64).tolist() + return H, logical, priors + + +def _sample_synthetic_dataset(H: np.ndarray, logical_obs: np.ndarray, + priors: list[float], num_shots: int, + rng: np.random.Generator): + """Sample errors from a Bernoulli noise model and derive (syn, flips).""" + n_errors = H.shape[1] + p = np.asarray(priors, dtype=np.float64) + errors = (rng.random((num_shots, n_errors)) < p).astype(np.uint8) + syndromes = (errors @ H.T) % 2 + flips = (errors @ logical_obs.T).reshape(-1) % 2 + return syndromes.astype(np.float64), flips.astype(bool) + + +def _make_opt(H, logical, priors, syn, flips, **kwargs): + """Thin wrapper that forwards kwargs; collapses 7-line constructions.""" + return NMOptimizer(H, logical, priors, syn, flips, **kwargs) + + +def _naive_cross_entropy(opt: "NMOptimizer") -> torch.Tensor: + """Reference cross-entropy: predict, then ``-log p`` per shot. + + Mirrors the pre-fusion implementation; used to verify the codegen + loss in :func:`test_fused_loss_matches_naive`. + """ + preds = opt.decoder_prediction() + obs_t = opt.obs_idx_true + obs_f = opt.obs_idx_false + return (-torch.log(preds[obs_t, 1]).sum() - + torch.log(preds[obs_f, 0]).sum()) + + +# -- construction ------------------------------------------------------------ + + +@pytest.mark.parametrize("device", _device_params()) +def test_construction_basic(device): + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=8, + rng=np.random.default_rng(0)) + opt = _make_opt(H, logical, priors, syn, flips, device=device) + assert opt._batch_size == 8 + assert opt._noise_probs.requires_grad + assert len(opt.noise_params) == 1 + assert opt.noise_params[0] is opt._noise_probs + np_probs = opt._noise_probs.detach().cpu().numpy() + assert np_probs.shape == (3,) + np.testing.assert_allclose(np_probs, priors, atol=1e-6) + assert np.all((np_probs >= 0.0) & (np_probs <= 1.0)) + + +def test_public_reexport_from_tensor_network_decoder_module(): + """``NMOptimizer`` is re-exported from the TN decoder plugin module.""" + from cudaq_qec.plugins.decoders import tensor_network_decoder as tnd + from cudaq_qec.plugins.decoders.tensor_network_utils import (noise_models as + nl) + assert tnd.NMOptimizer is nl.NMOptimizer + assert tnd.make_compiled_step is nl.make_compiled_step + assert "NMOptimizer" in tnd.__all__ + assert "make_compiled_step" in tnd.__all__ + + +@pytest.mark.parametrize("device", _device_params()) +def test_invalid_execute_mode_rejected(device): + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=4, + rng=np.random.default_rng(1)) + with pytest.raises(ValueError, match="Invalid execute mode"): + _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + execute="bogus") + + +# -- forward pass / gradient ------------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +def test_decoder_prediction_shape_and_range(device): + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=16, + rng=np.random.default_rng(2)) + opt = _make_opt(H, logical, priors, syn, flips, device=device) + pred = opt.decoder_prediction() + assert pred.shape == (16, 2) + s = pred.sum(dim=1) + assert torch.allclose(s, torch.ones_like(s), atol=1e-5) + assert torch.all(pred >= -1e-6) and torch.all(pred <= 1.0 + 1e-6) + + +@pytest.mark.parametrize("device", _device_params()) +@pytest.mark.parametrize("execute", _EXECUTE_MODES) +def test_gradient_flows(device, execute): + """Backward populates a non-zero gradient on ``_noise_probs``. + + Uses :func:`_nondegenerate_code` plus mismatched init priors so the + loss surface has a non-trivial gradient. Parametrized over every + ``execute`` mode so unrolled and opt_einsum paths can't silently + regress on the autograd graph. + """ + rng = np.random.default_rng(3) + H, logical = _nondegenerate_code() + true_priors = [0.1, 0.15, 0.25] + init_priors = [0.5, 0.5, 0.5] + syn, flips = _sample_synthetic_dataset(H, + logical, + true_priors, + num_shots=64, + rng=rng) + opt = _make_opt(H, + logical, + init_priors, + syn, + flips, + device=device, + dtype="float64", + execute=execute) + opt._noise_probs.grad = None + loss = opt.cross_entropy_loss() + loss.backward() + assert opt._noise_probs.grad is not None + assert torch.any(opt._noise_probs.grad != 0.0) + + +# -- fused-loss correctness -------------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +@pytest.mark.parametrize("execute", _EXECUTE_MODES) +def test_fused_loss_matches_naive(device, execute): + """``cross_entropy_loss`` == predict + manual CE in every execute mode. + + Codegen mode fuses the CE reduction into the contraction graph; + unrolled/opt_einsum wrap ``predict_fn``. All three must agree with + the naive reference up to fp64 round-off. + """ + rng = np.random.default_rng(11) + H, logical = _nondegenerate_code() + init_priors = [0.2, 0.3, 0.4] + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.15, 0.25], + num_shots=48, + rng=rng) + opt = _make_opt(H, + logical, + init_priors, + syn, + flips, + device=device, + dtype="float64", + execute=execute) + with torch.no_grad(): + fused = opt.cross_entropy_loss() + naive = _naive_cross_entropy(opt) + assert torch.isfinite(fused) and torch.isfinite(naive) + assert torch.allclose(fused, naive, atol=1e-8, rtol=1e-8) + + +@pytest.mark.parametrize("device", _device_params()) +def test_fused_loss_matches_naive_static_codegen(device): + """Static codegen (``dynamic_syndromes=False``) numerical correctness.""" + rng = np.random.default_rng(13) + H, logical = _nondegenerate_code() + init_priors = [0.2, 0.3, 0.4] + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.15, 0.25], + num_shots=40, + rng=rng) + opt = _make_opt(H, + logical, + init_priors, + syn, + flips, + device=device, + dtype="float64", + execute="codegen", + dynamic_syndromes=False) + with torch.no_grad(): + fused = opt.cross_entropy_loss() + naive = _naive_cross_entropy(opt) + assert torch.isfinite(fused) and torch.isfinite(naive) + assert torch.allclose(fused, naive, atol=1e-8, rtol=1e-8) + loss_probs = opt.loss_fn(from_logits=False) + loss_logits = opt.loss_fn(from_logits=True) + probs = opt._noise_probs.detach().clone().requires_grad_(False) + logits = torch.log(probs / (1.0 - probs)) + with torch.no_grad(): + v_probs = loss_probs(probs, ()) + v_logits = loss_logits(logits, ()) + assert torch.allclose(v_probs, fused, atol=1e-8, rtol=1e-8) + assert torch.allclose(v_logits, fused, atol=1e-8, rtol=1e-8) + + +@pytest.mark.parametrize("device", _device_params()) +@pytest.mark.parametrize("execute", _EXECUTE_MODES) +def test_loss_fn_from_logits_and_probs(device, execute): + """``loss_fn(from_logits=True)`` matches ``loss_fn(from_logits=False) o sigmoid``, + and both agree with ``cross_entropy_loss`` on the optimiser's own probs. + + Parametrized over execute modes so the logit-vs-probs equivalence is + validated on every supported backend. + """ + rng = np.random.default_rng(12) + H, logical = _nondegenerate_code() + init_priors = [0.2, 0.3, 0.4] + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.15, 0.25], + num_shots=32, + rng=rng) + opt = _make_opt(H, + logical, + init_priors, + syn, + flips, + device=device, + dtype="float64", + execute=execute) + loss_probs = opt.loss_fn(from_logits=False) + loss_logits = opt.loss_fn(from_logits=True) + probs = opt._noise_probs.detach().clone().requires_grad_(False) + logits = torch.log(probs / (1.0 - probs)) + with torch.no_grad(): + v_probs = loss_probs(probs, opt._syndrome_tuple) + v_logits = loss_logits(logits, opt._syndrome_tuple) + v_self = opt.cross_entropy_loss() + assert torch.allclose(v_probs, v_logits, atol=1e-8, rtol=1e-8) + assert torch.allclose(v_probs, v_self, atol=1e-8, rtol=1e-8) + + +# -- numerical guards -------------------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +def test_small_priors_finite_loss(device): + """Realistic small priors (``1e-3``) pass through unclamped and yield finite loss.""" + H, logical, _ = _simple_repetition_code() + small_priors = [1e-3, 0.5, 1.0 - 1e-3] + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.2, 0.3], + num_shots=8, + rng=np.random.default_rng(4)) + for dtype in ("float32", "float64"): + opt = _make_opt(H, + logical, + small_priors, + syn, + flips, + device=device, + dtype=dtype) + assert torch.all(torch.isfinite(opt._noise_probs)) + loss = opt.cross_entropy_loss() + assert torch.isfinite(loss), ( + f"non-finite loss at dtype={dtype}: {loss}") + + +@pytest.mark.parametrize("device", _device_params()) +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_boundary_priors_clamped_with_warning(device, dtype): + """Priors at the (0, 1) boundary are clamped into ``[eps, 1 - eps]`` + with a single :class:`UserWarning`; loss stays finite downstream.""" + H, logical, _ = _simple_repetition_code() + boundary_priors = [0.0, 0.5, 1.0] + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.2, 0.3], + num_shots=8, + rng=np.random.default_rng(20)) + eps = _PRIOR_EPS_BY_DTYPE[dtype] + with pytest.warns(UserWarning, match=r"Clamped \d+/\d+"): + opt = _make_opt(H, + logical, + boundary_priors, + syn, + flips, + device=device, + dtype=dtype) + probs = opt._noise_probs.detach().cpu().numpy() + assert np.all(probs >= eps - 1e-9) + assert np.all(probs <= 1.0 - eps + 1e-9) + np.testing.assert_allclose(probs[1], 0.5, atol=1e-6) + loss = opt.cross_entropy_loss() + assert torch.isfinite(loss) + + +@pytest.mark.parametrize("device", _device_params()) +def test_non_finite_priors_raise(device): + """Non-finite priors are caller bugs, not stability concerns - raise.""" + H, logical, _ = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, [0.1, 0.2, 0.3], + num_shots=8, + rng=np.random.default_rng(21)) + for bad_priors in ([0.1, np.nan, 0.3], [0.1, np.inf, 0.3]): + with pytest.raises(ValueError, match="All priors must be finite"): + _make_opt(H, logical, bad_priors, syn, flips, device=device) + + +def test_in_range_priors_no_warning(): + """In-range priors must pass through with zero warnings.""" + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=4, + rng=np.random.default_rng(22)) + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + _make_opt(H, logical, priors, syn, flips, device="cpu") + + +def test_validate_and_clamp_priors_helper(): + """Unit-test the helper directly: shape, unknown-dtype rejection, idempotence.""" + with pytest.raises(ValueError, match="must be 1-D"): + _validate_and_clamp_priors(np.zeros((2, 3)) + 0.5, "float64") + with pytest.raises(ValueError, match="Unsupported dtype"): + _validate_and_clamp_priors([0.1, 0.5, 0.9], "float128_unknown") + out = _validate_and_clamp_priors([0.1, 0.5, 0.9], "float64") + assert out == [0.1, 0.5, 0.9] + out = _validate_and_clamp_priors(out, "float64") + assert out == [0.1, 0.5, 0.9] + + +# -- current_syndrome_args --------------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +def test_current_syndrome_args_dynamic_returns_live_tuple(device): + """Dynamic mode: returns the live syndrome tuple.""" + rng = np.random.default_rng(101) + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=12, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + execute="codegen", + dynamic_syndromes=True) + args = opt.current_syndrome_args() + assert args is opt._syndrome_tuple + assert len(args) > 0 + assert torch.isfinite( + opt.loss_fn(from_logits=False)(opt.noise_params[0], args)) + + +@pytest.mark.parametrize("device", _device_params()) +def test_current_syndrome_args_static_returns_empty(device): + """Static codegen mode: returns ``()`` (syndromes are closure-baked).""" + rng = np.random.default_rng(102) + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=12, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + execute="codegen", + dynamic_syndromes=False) + assert opt.current_syndrome_args() == () + assert torch.isfinite( + opt.loss_fn(from_logits=False)(opt.noise_params[0], ())) + + +# -- dataset swap ------------------------------------------------------------ + + +@pytest.mark.parametrize("device", _device_params()) +def test_update_dataset_dynamic_keeps_predict_fn(device): + """Dynamic mode: predict function identity unchanged across swaps.""" + rng = np.random.default_rng(5) + H, logical, priors = _simple_repetition_code() + syn1, flips1 = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=10, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn1, + flips1, + device=device, + dynamic_syndromes=True) + fn_before = opt._predict_fn + syn2, flips2 = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=10, + rng=rng) + opt.update_dataset(syn2, flips2) + assert opt._predict_fn is fn_before + + +@pytest.mark.parametrize("device", _device_params()) +def test_update_dataset_static_rebuilds_predict_fn(device): + """Static mode: predict function is re-codegened on swap.""" + rng = np.random.default_rng(6) + H, logical, priors = _simple_repetition_code() + syn1, flips1 = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=10, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn1, + flips1, + device=device, + dynamic_syndromes=False) + fn_before = opt._predict_fn + syn2, flips2 = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=10, + rng=rng) + opt.update_dataset(syn2, flips2) + assert opt._predict_fn is not fn_before + + +@pytest.mark.parametrize("device", _device_params()) +@pytest.mark.parametrize("dynamic_syndromes", [True, False]) +def test_update_dataset_shape_change_rebuilds_and_decodes( + device, dynamic_syndromes): + """A different batch size triggers a full rebuild; loss stays finite + and matches a freshly constructed optimiser to fp64 precision.""" + rng = np.random.default_rng(77) + H, logical = _nondegenerate_code() + init_priors = [0.1, 0.15, 0.25] + syn1, flips1 = _sample_synthetic_dataset(H, + logical, + init_priors, + num_shots=16, + rng=rng) + opt = _make_opt(H, + logical, + init_priors, + syn1, + flips1, + device=device, + dtype="float64", + dynamic_syndromes=dynamic_syndromes) + syn2, flips2 = _sample_synthetic_dataset(H, + logical, + init_priors, + num_shots=33, + rng=rng) + opt.update_dataset(syn2, flips2, enforce_shape=False) + assert opt._batch_size == 33 + pred = opt.decoder_prediction() + assert pred.shape == (33, 2) + loss = opt.cross_entropy_loss() + assert torch.isfinite(loss) + + ref = _make_opt(H, + logical, + init_priors, + syn2, + flips2, + device=device, + dtype="float64", + dynamic_syndromes=dynamic_syndromes) + with torch.no_grad(): + ref_loss = ref.cross_entropy_loss() + assert torch.allclose(loss, ref_loss, atol=1e-8, rtol=1e-8) + + +# -- optimize_path ----------------------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +def test_optimize_path_default_preserves_forward(device): + """``optimize_path()`` with the default ``"auto"`` finder rebuilds the + JIT but does not change the numerical forward output.""" + rng = np.random.default_rng(88) + H, logical = _nondegenerate_code() + priors = [0.1, 0.15, 0.25] + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=24, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + dtype="float64") + with torch.no_grad(): + before = opt.decoder_prediction().detach().cpu().numpy() + opt.optimize_path() + with torch.no_grad(): + after = opt.decoder_prediction().detach().cpu().numpy() + np.testing.assert_allclose(before, after, atol=1e-10, rtol=1e-10) + + +@pytest.mark.parametrize("device", _device_params()) +def test_optimize_path_with_cotengra(device): + """A user-supplied ``cotengra.HyperOptimizer`` is accepted by + ``optimize_path`` and the forward stays numerically consistent.""" + ctg = pytest.importorskip("cotengra") + rng = np.random.default_rng(89) + H, logical = _nondegenerate_code() + priors = [0.1, 0.15, 0.25] + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=24, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + dtype="float64") + with torch.no_grad(): + before = opt.decoder_prediction().detach().cpu().numpy() + info = opt.optimize_path( + optimize=ctg.HyperOptimizer(max_repeats=2, parallel=False)) + assert info is not None + with torch.no_grad(): + after = opt.decoder_prediction().detach().cpu().numpy() + np.testing.assert_allclose(before, after, atol=1e-10, rtol=1e-10) + + +# -- _remap_eq_to_ascii ------------------------------------------------------- + + +def test_remap_eq_to_ascii_simple(): + eq = "ab,bc->ac" + out = _remap_eq_to_ascii(eq) + # ASCII input is returned unchanged via the ``isascii()`` fast path. + assert out == "ab,bc->ac" + + +def test_remap_eq_to_ascii_unicode_labels(): + """Synthetic equation with non-ASCII labels is remapped to a-zA-Z.""" + eq = "\u0391\u0392,\u0392\u0393->\u0391\u0393" # greek letters + out = _remap_eq_to_ascii(eq) + assert "\u0391" not in out and "\u0392" not in out and "\u0393" not in out + assert "->" in out + lhs, rhs = out.split("->") + assert all(c.isascii() and c.isalpha() or c == "," for c in lhs) + assert all(c.isascii() and c.isalpha() for c in rhs) + + +def test_remap_eq_to_ascii_too_many_labels(): + """Equations with > 52 distinct labels raise.""" + chars = [chr(0x4E00 + i) for i in range(53)] # 53 distinct CJK chars + eq = "".join(chars) + "->" + chars[0] + with pytest.raises(ValueError, match="more than 52"): + _remap_eq_to_ascii(eq) + + +# -- logical_error_rate ------------------------------------------------------ + + +@pytest.mark.parametrize("device", _device_params()) +def test_logical_error_rate_matches_argmax(device): + """``logical_error_rate`` equals ``mean(argmax != observable_flips)``.""" + rng = np.random.default_rng(202) + H, logical = _nondegenerate_code() + true_priors = [0.05, 0.15, 0.10] + syn, flips = _sample_synthetic_dataset(H, + logical, + true_priors, + num_shots=256, + rng=rng) + opt = _make_opt(H, + logical, + true_priors, + syn, + flips, + device=device, + dtype="float64") + ler = opt.logical_error_rate() + assert isinstance(ler, float) + assert 0.0 <= ler <= 1.0 + + with torch.no_grad(): + preds = opt.decoder_prediction() + argmax_pred = (preds[:, 1] > preds[:, 0]).cpu().numpy() + expected = float(np.mean(argmax_pred != flips.astype(bool))) + assert abs(ler - expected) < 1e-12 + + +@pytest.mark.parametrize("device", _device_params()) +def test_logical_error_rate_improves_with_better_priors(device): + """Decoding with true Bernoulli rates beats (or matches) uniform priors.""" + rng = np.random.default_rng(303) + H, logical = _nondegenerate_code() + true_priors = [0.03, 0.18, 0.07] + syn, flips = _sample_synthetic_dataset(H, + logical, + true_priors, + num_shots=4096, + rng=rng) + uniform = [float(np.mean(true_priors))] * H.shape[1] + opt_true = _make_opt(H, + logical, + true_priors, + syn, + flips, + device=device, + dtype="float64") + opt_uniform = _make_opt(H, + logical, + uniform, + syn, + flips, + device=device, + dtype="float64") + assert opt_true.logical_error_rate( + ) <= opt_uniform.logical_error_rate() + 1e-6 + + +# -- parity vs base TN decoder ----------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +def test_forward_parity_with_tn_decoder(device): + """Forward with frozen probs agrees with the base TN decoder's batch.""" + rng = np.random.default_rng(123) + H, logical, priors = _random_code(rng, n_checks=4, n_errors=6) + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=12, + rng=rng) + opt = _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + dtype="float64") + with torch.no_grad(): + pred = opt.decoder_prediction().detach().cpu().numpy() + ref = qec.get_decoder("tensor_network_decoder", + H, + logical_obs=logical, + noise_model=priors, + dtype="float64") + res = ref.decode_batch(syn) + ref_p_flip = np.array([r.result[0] for r in res], dtype=np.float64) + np.testing.assert_allclose(pred[:, 1], ref_p_flip, atol=1e-4, rtol=1e-4) + + +# -- CPU/GPU parity ---------------------------------------------------------- + + +@pytest.mark.skipif(not _gpu_available(), + reason="CUDA not available; CPU/GPU parity test skipped") +def test_cpu_gpu_parity_forward(): + """Forward with the same seed agrees CPU vs GPU to atol 1e-4.""" + rng = np.random.default_rng(7) + H, logical, priors = _random_code(rng, n_checks=4, n_errors=6) + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=16, + rng=rng) + opt_cpu = _make_opt(H, + logical, + priors, + syn, + flips, + device="cpu", + dtype="float64") + opt_gpu = _make_opt(H, + logical, + priors, + syn, + flips, + device="cuda", + dtype="float64") + with torch.no_grad(): + p_cpu = opt_cpu.decoder_prediction().detach().cpu().numpy() + p_gpu = opt_gpu.decoder_prediction().detach().cpu().numpy() + np.testing.assert_allclose(p_cpu, p_gpu, atol=1e-4, rtol=1e-4) + + +# -- truth-data convergence -------------------------------------------------- + + +@pytest.mark.parametrize("device", _device_params()) +def test_recovers_true_priors_within_tol(device): + """Fitted priors converge to the Bernoulli rates that sampled the data.""" + rng = np.random.default_rng(0xC0DE) + H, logical = _nondegenerate_code() + true_priors = [0.03, 0.12, 0.08] + syn, flips = _sample_synthetic_dataset(H, + logical, + true_priors, + num_shots=2000, + rng=rng) + init_priors = [0.10] * H.shape[1] + opt = _make_opt(H, + logical, + init_priors, + syn, + flips, + device=device, + dtype="float64", + execute="codegen") + init_p = opt.noise_params[0].detach() + logits = torch.logit(init_p).clone().requires_grad_(True) + torch_opt = torch.optim.Adam([logits], lr=0.05) + step_fn = make_compiled_step(opt, logits, torch_opt) + for _ in range(500): + step_fn() + fitted = torch.sigmoid(logits).detach().cpu().numpy() + np.testing.assert_allclose(fitted, + np.asarray(true_priors, dtype=np.float64), + atol=0.02) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 2900586c053aea2e13c0e264badcf03c66e33e90 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Tue, 19 May 2026 17:05:05 -0400 Subject: [PATCH 2/5] chore(qec): drop narrative Sphinx docs from noise-learning integration PR Restore API/introduction/decoders RST to main so documentation can ship in a follow-up PR; keep docs/sphinx/examples/qec/python/noise_learning.py here. Signed-off-by: vedika-saravanan --- .../api/qec/tensor_network_decoder_api.rst | 183 +----------------- docs/sphinx/components/qec/introduction.rst | 21 -- docs/sphinx/examples_rst/qec/decoders.rst | 25 --- 3 files changed, 1 insertion(+), 228 deletions(-) diff --git a/docs/sphinx/api/qec/tensor_network_decoder_api.rst b/docs/sphinx/api/qec/tensor_network_decoder_api.rst index 16279de2..b5bbc600 100644 --- a/docs/sphinx/api/qec/tensor_network_decoder_api.rst +++ b/docs/sphinx/api/qec/tensor_network_decoder_api.rst @@ -93,185 +93,4 @@ :param optimize: Optimization options or None :param batch_size: (int, optional) Batch size for optimization (default: -1, no batching) - :returns: Optimizer info object - -.. class:: cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer - - Differentiable noise-model optimizer built on top of :class:`TensorNetworkDecoder`. - - Fits a factorised per-error noise model to a syndrome dataset by - backpropagating through a torch-backed tensor-network contraction. - The noise probabilities are maintained as ``torch`` tensors with - ``requires_grad=True`` so they can be updated with any ``torch.optim`` - optimizer. - - Requires Python 3.11 or higher and the same optional dependencies as - :class:`TensorNetworkDecoder` (``pip install cudaq-qec[tensor-network-decoder]``). - PyTorch must also be installed. - - .. note:: - Quick-start example (logit-space training; the loss has no ``log`` - guard, so direct probability training requires per-step clamping - into ``[eps, 1 - eps]``):: - - import numpy as np - import torch - from cudaq_qec.plugins.decoders.tensor_network_decoder import ( - NMOptimizer, make_compiled_step, - ) - - H = np.array([[1, 1, 0], [0, 1, 1]], dtype=np.float64) - logical = np.array([[1, 0, 1]], dtype=np.float64) - priors = [0.1, 0.2, 0.3] - - opt = NMOptimizer(H, logical, priors, syndrome_data, obs_flips, - dtype="float64") - logits = torch.logit(opt.noise_params[0].detach()).requires_grad_() - adam = torch.optim.Adam([logits], lr=0.01) - step = make_compiled_step(opt, logits, adam) - for _ in range(100): - step() - - :param H: Parity check matrix (numpy.ndarray), shape (num_checks, num_errors) - :param logical_obs: Logical observable matrix (numpy.ndarray), shape (1, num_errors) - :param noise_model: Initial per-error probabilities, list of floats in (0, 1). - Values outside ``[eps, 1 - eps]`` are clamped at - construction with a ``UserWarning``; non-finite values - raise ``ValueError``. ``eps`` is ``1e-12`` for - ``"float64"`` and ``1e-6`` for ``"float32"``. - :param syndrome_data: Observed syndromes, numpy.ndarray of shape (num_shots, num_checks) - :param observable_flips: Observed logical flips, bool array of length num_shots - :param check_inds: (optional) List of check index names; defaults track the parent decoder. - :param error_inds: (optional) List of error index names; defaults track the parent decoder. - :param logical_inds: (optional) List of logical index names; defaults track the parent decoder. - :param logical_tags: (optional) List of logical tags; defaults track the parent decoder. - :param dtype: (str, optional) ``"float32"`` (default) or ``"float64"``; - other values raise ``ValueError``. - :param device: (str, optional) Torch device, e.g. ``"cpu"`` or ``"cuda"`` (default: ``"cuda"``) - :param compile: (bool, optional, keyword-only) If ``True``, wrap the forward - and loss in :func:`torch.compile`. Most useful with - ``execute="codegen"``. Defaults to ``False``. - :param execute: (str, optional, keyword-only) Forward backend. ``"codegen"`` - (default) partial-evaluates the contraction path into a flat - Python function with named locals; ``"unrolled"`` keeps an - interpretive einsum list; ``"opt_einsum"`` dispatches via - :func:`opt_einsum.contract_expression`. - :param compile_mode: (str, optional, keyword-only) Forwarded to - :func:`torch.compile` (e.g. ``"reduce-overhead"``, - ``"default"``); ignored when ``compile=False``. - :param dynamic_syndromes: (bool, optional, keyword-only) If ``True`` - (default), syndromes are runtime arguments to the - compiled forward, so :meth:`update_dataset` reuses - the codegen/``torch.compile`` artifact when shapes - are unchanged. ``False`` bakes syndromes into the - closure -- faster per call but every - :meth:`update_dataset` rebuilds the graph. Only - affects ``execute="codegen"``. - - **Attributes** - - .. attribute:: noise_params - - ``list[torch.Tensor]`` — the learnable noise-probability tensors; pass - directly to a ``torch.optim`` optimizer. - - .. attribute:: torch_device - - ``torch.device`` derived from the ``device`` constructor argument. - Read-only. - - .. attribute:: observable_flips - - Bool ``torch.Tensor`` of logical flip outcomes for the current - syndrome batch. Assigning a new value also rebuilds the fused - loss closure (the observable indices are baked into the codegen); - prefer :meth:`update_dataset` when swapping syndromes and flips - together. - - **Methods** - - .. method:: current_syndrome_args() - - Return the syndrome argument expected by the callable from - :meth:`loss_fn`: the live tuple when ``dynamic_syndromes=True``, - or ``()`` for static codegen (syndromes are closure-baked). - Re-fetch each step so an intervening :meth:`update_dataset` is - reflected. - - :returns: ``tuple[torch.Tensor, ...]`` - - .. method:: cross_entropy_loss() - - Compute the cross-entropy loss between the predicted logical-flip - probabilities and the observed ``observable_flips``. - - :returns: Scalar ``torch.Tensor`` (differentiable). - - .. method:: decoder_prediction() - - Run the forward pass and return per-shot probabilities. - - :returns: ``torch.Tensor`` of shape ``(num_shots, 2)`` where column 1 - is ``P(logical flip | syndrome)``. - - .. method:: logical_error_rate() - - Fraction of shots where ``argmax`` of :meth:`decoder_prediction` - disagrees with :attr:`observable_flips`. Not differentiable - (runs under :func:`torch.no_grad`). - - :returns: ``float`` in ``[0, 1]``. - - .. method:: loss_fn(from_logits=True) - - Return a compiled callable ``fn(params, syndrome_tuple) -> loss`` - suitable for use with external optimizers or ``torch.compile``. - - :param from_logits: If ``True`` (default), ``params`` are interpreted - as logits and passed through ``sigmoid`` before - contraction. If ``False``, ``params`` are - interpreted as probabilities already in ``[0, 1]``. - :returns: Compiled loss function. - - .. method:: optimize_path(optimize=None, batch_size=-1) - - Cache a contraction path via quimb / opt_einsum and rebuild the - compiled forward. Pass e.g. ``cotengra.HyperOptimizer()`` to run a - more expensive path search; ``None`` falls back to ``"auto"``. - - :param optimize: Optimization options (e.g. a ``cotengra.HyperOptimizer``) - or ``None``. - :param batch_size: Accepted for signature compatibility; ignored. - :returns: Contraction info object. - - .. method:: update_dataset(syndrome_data, observable_flips, enforce_shape=True) - - Swap in a new syndrome batch without rebuilding the tensor network. - If ``dynamic_syndromes=True`` and the batch size is unchanged, the - compiled contraction path is reused; a shape change triggers a full - rebuild. - - :param syndrome_data: numpy.ndarray of shape (num_shots, num_checks) - :param observable_flips: bool array of length num_shots - :param enforce_shape: (bool, optional, default ``True``) Assert - per-tensor shapes match the existing layout - before patching in place. A batch-size change - triggers a full rebuild regardless. - -.. function:: cudaq_qec.plugins.decoders.tensor_network_decoder.make_compiled_step(optimizer, logits, torch_optimizer) - - Build a no-arg callable that runs one Adam step and returns the loss. - - The returned ``step()`` callable zeros gradients, evaluates the - optimizer's fused ``loss_fn(from_logits=True)`` (sigmoid + contraction + - cross-entropy), backpropagates, and steps ``torch_optimizer``. Intended - for training in logit space; pair with :class:`NMOptimizer` constructed - with ``compile=True`` for a ``torch.compile``-d variant. - - :param optimizer: An :class:`NMOptimizer` instance providing the fused - inner loss. - :param logits: Trainable 1-D ``torch.Tensor`` of length - ``len(optimizer.error_inds)`` with ``requires_grad=True``. - :param torch_optimizer: A ``torch.optim`` instance owning ``logits``. - :returns: A no-arg callable that performs one optimization step and - returns the scalar loss as a ``torch.Tensor``. \ No newline at end of file + :returns: Optimizer info object \ No newline at end of file diff --git a/docs/sphinx/components/qec/introduction.rst b/docs/sphinx/components/qec/introduction.rst index 4088795a..99756c73 100644 --- a/docs/sphinx/components/qec/introduction.rst +++ b/docs/sphinx/components/qec/introduction.rst @@ -899,27 +899,6 @@ The decoder returns the probability that the logical observable has flipped for that this GPU will not be supported by the Tensor Network Decoder when CUDA-Q 0.5.0 is released. -Learning the Noise Model from Data -"""""""""""""""""""""""""""""""""" - -When the true per-error noise rates are unknown (typical of real hardware), -the Tensor Network Decoder ships with ``NMOptimizer``, a differentiable -extension that **fits the noise model directly from observed syndromes and -logical-flip outcomes**. Noise probabilities are held as PyTorch tensors -with ``requires_grad=True``; backpropagating through the tensor-network -contraction yields gradients that any ``torch.optim`` optimizer (Adam, SGD, -etc.) can update. Starting from a uniform initial prior and a few hundred -Adam steps is usually enough to recover the per-error rates and beat a -static-uniform baseline on a held-out batch. - -This is offline -- training happens once on a representative syndrome -dataset, and the learned probabilities can then be used as a standard -static noise model for batch decoding. See -:ref:`tensor_network_decoder_api_python` for the ``NMOptimizer`` API and -the *Learning Noise Models with NMOptimizer* example in -:doc:`../../examples_rst/qec/decoders` for a runnable end-to-end demo on a -Stim repetition-code circuit. - Sliding Window Decoder ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/sphinx/examples_rst/qec/decoders.rst b/docs/sphinx/examples_rst/qec/decoders.rst index c86f0a51..fe240ed1 100644 --- a/docs/sphinx/examples_rst/qec/decoders.rst +++ b/docs/sphinx/examples_rst/qec/decoders.rst @@ -136,31 +136,6 @@ See Also: - ``cudaq_qec.plugins.decoders.tensor_network_decoder`` -Learning Noise Models with NMOptimizer -+++++++++++++++++++++++++++++++++++++++ - -:class:`~cudaq_qec.plugins.decoders.tensor_network_decoder.NMOptimizer` extends -the Tensor Network Decoder with differentiable noise probabilities. Given a -batch of observed syndromes and logical-flip outcomes, it fits per-error noise -rates by backpropagating through the tensor-network contraction using PyTorch. - -The following example builds a distance-3 repetition-code circuit with -**asymmetric** noise (data-qubit depolarization is 10x measurement-flip -probability), samples syndromes from Stim, and trains -:class:`NMOptimizer` from a uniform initial prior with 300 Adam steps in -logit space. It then compares the **logical error rate (LER)** of the -learned noise model against a static uniform-prior baseline on a 20k-shot -held-out batch — demonstrating that fitting per-error rates from data -decodes meaningfully better than assuming uniform noise: - -.. literalinclude:: ../../examples/qec/python/noise_learning.py - :language: python - :start-after: [Begin Documentation] - -See Also: - -- :ref:`tensor_network_decoder_api_python` - .. _deploying-ai-decoders: Deploying AI Decoders with TensorRT From e818c7f85c1cad48654d096f907361dbad51d4df Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Tue, 19 May 2026 20:12:21 -0400 Subject: [PATCH 3/5] pr cleanup Signed-off-by: vedika-saravanan --- ...noise_learning.py => tn_noise_learning.py} | 5 +- libs/qec/python/cudaq_qec/__init__.py | 11 + .../decoders/tensor_network_decoder.py | 14 - .../tensor_network_utils/nm_optimizer.py | 1194 +++++++++++++++++ .../tensor_network_utils/noise_models.py | 1194 +---------------- ...n_noise_models.py => test_nm_optimizer.py} | 101 +- 6 files changed, 1279 insertions(+), 1240 deletions(-) rename docs/sphinx/examples/qec/python/{noise_learning.py => tn_noise_learning.py} (98%) create mode 100644 libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py rename libs/qec/python/tests/{test_tn_noise_models.py => test_nm_optimizer.py} (90%) diff --git a/docs/sphinx/examples/qec/python/noise_learning.py b/docs/sphinx/examples/qec/python/tn_noise_learning.py similarity index 98% rename from docs/sphinx/examples/qec/python/noise_learning.py rename to docs/sphinx/examples/qec/python/tn_noise_learning.py index 6ab2784f..f1631b93 100644 --- a/docs/sphinx/examples/qec/python/noise_learning.py +++ b/docs/sphinx/examples/qec/python/tn_noise_learning.py @@ -40,10 +40,7 @@ from beliefmatching.belief_matching import detector_error_model_to_check_matrices import cudaq_qec as qec -from cudaq_qec.plugins.decoders.tensor_network_decoder import ( - NMOptimizer, - make_compiled_step, -) +from cudaq_qec import NMOptimizer, make_compiled_step def parse_detector_error_model(dem): diff --git a/libs/qec/python/cudaq_qec/__init__.py b/libs/qec/python/cudaq_qec/__init__.py index c0973497..ed3646fc 100644 --- a/libs/qec/python/cudaq_qec/__init__.py +++ b/libs/qec/python/cudaq_qec/__init__.py @@ -106,6 +106,17 @@ def iter_namespace(ns_pkg): except (ModuleNotFoundError, ImportError) as e: pass +# Surface the TN noise learner at the top level when its optional +# dependencies (torch, quimb, opt_einsum) are installed; mirrors the +# silent-skip pattern used by the plugin loaders above. +try: + from .plugins.decoders.tensor_network_utils.nm_optimizer import ( + NMOptimizer, + make_compiled_step, + ) +except (ModuleNotFoundError, ImportError): + pass + import cudaq from .loader import qec_set_target_callback diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py index fdd71ab6..546c7b20 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_decoder.py @@ -536,17 +536,3 @@ def _set_tensor_type(self, tn: TensorNetwork) -> None: dtype=to_backend_dtype(self._dtype, like=self.contractor_config.backend), )) - - -# Re-export the noise learner so callers can use the same module path -# as :class:`TensorNetworkDecoder`. Imported at the bottom because -# :mod:`noise_models` subclasses :class:`TensorNetworkDecoder`. -from .tensor_network_utils.noise_models import ( # noqa: E402 - NMOptimizer, make_compiled_step, -) - -__all__ = [ - "TensorNetworkDecoder", - "NMOptimizer", - "make_compiled_step", -] diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py new file mode 100644 index 00000000..ae229545 --- /dev/null +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/nm_optimizer.py @@ -0,0 +1,1194 @@ +# ============================================================================ # +# Copyright (c) 2026 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # +"""Differentiable noise learning for the tensor-network decoder. + +:class:`NMOptimizer` fits a factorised per-error noise model to a +syndrome dataset by backpropagating through a torch-backed tensor-network +contraction. :func:`make_compiled_step` is a convenience factory that +builds a no-arg callable for one Adam step in logit space. + +The static noise-model builders (:func:`factorized_noise_model`, +:func:`error_pairs_noise_model`) live in :mod:`.noise_models`. +""" +from __future__ import annotations + +import warnings +from typing import Any, Literal + +import numpy as np +import numpy.typing as npt +import opt_einsum as oe +import torch +from quimb.tensor import TensorNetwork + +from ..tensor_network_decoder import TensorNetworkDecoder +from .tensor_network_factory import ( + tensor_network_from_syndrome_batch, + prepare_syndrome_data_batch, +) + +_ASCII_POOL = ("abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ") + +# Coarse for fp32 because ``1.0 - 1e-12`` rounds back to ``1.0``. +_PRIOR_EPS_BY_DTYPE: dict[str, float] = { + "float64": 1e-12, + "float32": 1e-6, +} +_SUPPORTED_DTYPES: tuple[str, ...] = ("float32", "float64") + + +def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]: + """Validate noise priors and clamp them into ``[eps, 1 - eps]``. + + The fused cross-entropy reduction in + :meth:`NMOptimizer.cross_entropy_loss` has no ``log`` guard, so a + prior of exactly ``0.0`` or ``1.0`` makes the contraction emit a + zero whose log is ``-inf`` and whose gradient is ``NaN``; training + silently diverges. Stim DEMs occasionally emit ``p=1.0`` + (deterministic detectors) or ``p<1e-15`` (underflow), so we + intercept here rather than force every caller to clamp. + + Behaviour mirrors :class:`torch.nn.BCELoss`-style stable wrappers: + + * Non-finite priors (``NaN`` / ``+/-inf``) raise ``ValueError`` - + these indicate caller bugs, not numerical fragility, and + silently coercing them would hide the real problem. + * Out-of-range priors (``p <= eps`` or ``p >= 1 - eps``) are + clamped into ``[eps, 1 - eps]`` and a single ``UserWarning`` + summarises the number of values changed. + * In-range priors pass through unchanged with no warning. + + Args: + noise_model: array-like of priors, length ``num_errors``. + dtype: contraction dtype string (``"float32"`` / ``"float64"``). + + Returns: + A plain ``list[float]`` so the base + :class:`TensorNetworkDecoder` keeps using its existing + list-based factorised noise model unchanged. + """ + arr = np.asarray(noise_model, dtype=np.float64) + if arr.ndim != 1: + raise ValueError(f"noise_model must be 1-D; got shape {arr.shape}") + if not np.all(np.isfinite(arr)): + bad = np.where(~np.isfinite(arr))[0] + raise ValueError( + f"All priors must be finite; got non-finite values at error " + f"indices {bad.tolist()}: {arr[bad].tolist()}") + + dtype_str = str(dtype) + if dtype_str not in _PRIOR_EPS_BY_DTYPE: + raise ValueError(f"Unsupported dtype {dtype_str!r}; " + f"expected one of {sorted(_PRIOR_EPS_BY_DTYPE)}.") + eps = _PRIOR_EPS_BY_DTYPE[dtype_str] + out_of_range = (arr < eps) | (arr > 1.0 - eps) + if np.any(out_of_range): + warnings.warn( + f"Clamped {int(out_of_range.sum())}/{len(arr)} NMOptimizer " + f"priors into [{eps}, {1.0 - eps}] for numerical stability; " + f"values at or outside the (0, 1) boundary produce -inf " + f"cross-entropy loss and NaN gradients in the fused codegen.", + UserWarning, + stacklevel=3, + ) + arr = np.clip(arr, eps, 1.0 - eps) + return arr.tolist() + + +def remap_eq_to_ascii(eq: str) -> str: + """Rewrite an einsum equation so every label is in ``[a-zA-Z]``. + + Needed because :mod:`opt_einsum` falls back to non-ASCII unicode + labels once the total index count exceeds 52, but + :func:`torch.einsum` rejects them. Raises if a single step has more + than 52 distinct labels. + """ + if eq.isascii(): + return eq + if "->" in eq: + lhs, rhs = eq.split("->") + else: + lhs, rhs = eq, None + + mapping: dict[str, str] = {} + out_lhs_chars: list[str] = [] + for c in lhs: + if c == ",": + out_lhs_chars.append(c) + continue + if c not in mapping: + if len(mapping) >= len(_ASCII_POOL): + raise ValueError( + f"Einsum step '{eq}' has more than {len(_ASCII_POOL)} " + "distinct labels; cannot remap to ASCII.") + mapping[c] = _ASCII_POOL[len(mapping)] + out_lhs_chars.append(mapping[c]) + out_lhs = "".join(out_lhs_chars) + if rhs is None: + return out_lhs + out_rhs_chars: list[str] = [] + for c in rhs: + if c not in mapping: + raise ValueError( + f"Einsum step '{eq}' has output label {c!r} not present " + "on the LHS; cannot remap.") + out_rhs_chars.append(mapping[c]) + return f"{out_lhs}->{''.join(out_rhs_chars)}" + + +class NMOptimizer(TensorNetworkDecoder): + """Differentiable noise-model optimiser for the TN decoder. + + The factorised noise probabilities live in the torch autograd graph + and are fit to a fixed syndrome batch by minimising the cross-entropy + of the decoder's logical prediction against the observed flips. + + The forward pass is materialised once at construction and reused + across optimisation steps. Optionally call :meth:`optimize_path` + (e.g. with ``cotengra.HyperOptimizer()``) to pin a better contraction + path; the JIT is rebuilt automatically. + + .. warning:: + + Priors are clamped into ``[eps, 1 - eps]`` only at construction; + an unconstrained optimiser step on :attr:`noise_params` can push + them past the boundary, after which :meth:`cross_entropy_loss` + returns ``NaN`` gradients. Prefer logit-space training via + :func:`make_compiled_step` (shown below), or clamp the tensor + under :func:`torch.no_grad` after each step. + + Args: + H: Parity check matrix, shape ``(num_checks, num_errors)``. + logical_obs: Logical observable matrix, shape ``(1, num_errors)``. + noise_model: Initial per-error probabilities, length ``num_errors``. + Each value must be strictly in ``(0, 1)``; values at or + outside the boundary (``p <= eps`` or ``p >= 1 - eps``, + with ``eps`` dtype-dependent) are auto-clamped at + construction with a :class:`UserWarning`. Non-finite + priors raise :class:`ValueError`. + syndrome_data: Syndrome batch, shape ``(shots, num_checks)``. + observable_flips: Observable flip outcomes, shape ``(shots,)``. + check_inds, error_inds, logical_inds, logical_tags: Optional index + and tag names; defaults track the parent decoder. + dtype: Tensor data type (e.g. ``"float32"``). + device: ``"cuda"`` (default) or ``"cpu"``. + compile: If ``True``, wrap the forward in :func:`torch.compile`. + Most useful with ``execute="codegen"``. + execute: Forward backend. ``"codegen"`` (default) partial-evaluates + the path into a flat Python function; ``"unrolled"`` keeps an + interpretive einsum list; ``"opt_einsum"`` dispatches via + :func:`opt_einsum.contract_expression`. + compile_mode: Forwarded to :func:`torch.compile`; ignored when + ``compile=False``. + dynamic_syndromes: If ``True`` (default), syndromes are runtime + arguments to the compiled forward, so :meth:`update_dataset` + does not retrigger codegen / ``torch.compile`` (provided + shapes are unchanged). ``False`` bakes the syndromes into + the closure as constants -- fewer runtime einsums, but every + :meth:`update_dataset` call rebuilds the graph. Only affects + ``execute="codegen"``. + + Example (logit-space, no clamping needed):: + + opt = NMOptimizer(H, logical_obs, priors, + syndrome_data, obs_flips) + opt.optimize_path(optimize=ctg.HyperOptimizer()) + logits = torch.logit(opt.noise_params[0].detach()).requires_grad_() + torch_opt = torch.optim.Adam([logits], lr=0.01) + step = make_compiled_step(opt, logits, torch_opt) + for _ in range(100): + loss = step() + """ + + def __init__( + self, + H: npt.NDArray[Any], + logical_obs: npt.NDArray[Any], + noise_model: list[float], + syndrome_data: npt.NDArray[Any], + observable_flips: npt.NDArray[Any], + check_inds: list[str] | None = None, + error_inds: list[str] | None = None, + logical_inds: list[str] | None = None, + logical_tags: list[str] | None = None, + dtype: str = "float32", + device: str = "cuda", + *, + compile: bool = False, + execute: Literal["codegen", "unrolled", "opt_einsum"] = "codegen", + compile_mode: str | None = None, + dynamic_syndromes: bool = True, + ) -> None: + if execute not in ("unrolled", "opt_einsum", "codegen"): + raise ValueError(f"Invalid execute mode: {execute!r}") + if dtype not in _SUPPORTED_DTYPES: + raise ValueError(f"Invalid dtype {dtype!r}; expected one of " + f"{list(_SUPPORTED_DTYPES)}.") + + # Sanitise once so the base TN tensors and ``self._noise_probs`` + # see identical values (see :func:`_validate_and_clamp_priors`). + noise_model = _validate_and_clamp_priors(noise_model, dtype) + + super().__init__( + H, + logical_obs, + noise_model, + check_inds=check_inds, + error_inds=error_inds, + logical_inds=logical_inds, + logical_tags=logical_tags, + contract_noise_model=False, + dtype=dtype, + device=device, + ) + + # Force the torch backend so tensor data lives in the autograd + # graph (the base class would otherwise pick cutensornet/numpy + # on GPU). Contractions still go through codegen / cotengra + # below, not cuTensorNet. + if self.contractor_config.contractor_name == "cutensornet" \ + and self.contractor_config.backend != "torch": + warnings.warn( + "NMOptimizer requires the torch backend for autograd; " + f"switching contractor backend from " + f"{self.contractor_config.backend!r} to 'torch'. " + "Contractions are executed via codegen/opt_einsum, not " + "cuTensorNet.", + stacklevel=3, + ) + self._set_contractor( + "cutensornet", + self.contractor_config.device, + "torch", + dtype, + ) + + # Swap the base's placeholder single-syndrome TN for a batched one. + self._syndrome_tags = [f"SYN_{i}" for i in range(len(self.check_inds))] + self.syndrome_tn = tensor_network_from_syndrome_batch( + syndrome_data, + self.check_inds, + batch_index="batch_index", + tags=self._syndrome_tags, + ) + self._batch_size = syndrome_data.shape[0] + + # Re-stitch ``full_tn`` around the batched syndrome TN. + self.full_tn = TensorNetwork() + self.full_tn = self.full_tn.combine(self.code_tn, virtual=True) + self.full_tn = self.full_tn.combine(self.logical_tn, virtual=True) + self.full_tn = self.full_tn.combine(self.syndrome_tn, virtual=True) + self.full_tn = self.full_tn.combine(self.noise_model, virtual=True) + + self._set_tensor_type(self.syndrome_tn) + + torch_dtype = getattr(torch, self._dtype) + self._noise_probs = torch.tensor( + noise_model, + dtype=torch_dtype, + device=self.torch_device, + requires_grad=True, + ) + # The base's noise tensors stay in ``full_tn`` as numpy + # placeholders: ``_snapshot_arrays_and_eq`` uses ``id()`` to + # locate their positions, then ``self._noise_probs`` (autograd + # live) is written into those slots. Do not strip them. + + self._suspend_loss_rebuild = True + self.observable_flips = observable_flips + + self._use_torch_compile = compile + self._execute_mode = execute + self._torch_compile_mode = compile_mode + self._dynamic_syndromes = dynamic_syndromes + self._compiled_predict: Any | None = None + self._syndrome_tuple: tuple[torch.Tensor, ...] = () + self._snapshot_arrays_and_eq() + self._suspend_loss_rebuild = False + + @property + def torch_device(self) -> torch.device: + """The ``torch.device`` matching the contractor config.""" + if "cuda" in self.contractor_config.device: + return torch.device(f"cuda:{self.contractor_config.device_id}",) + return torch.device("cpu") + + def _set_tensor_type(self, tn: TensorNetwork) -> None: + """Move all tensor data in *tn* to torch on the configured device. + + Overrides the base ``autoray``-routed implementation so gradients + flow through the noise-model tensors. + """ + torch_dtype = getattr(torch, self._dtype) + dev = self.torch_device + + def _to_torch(x): + if isinstance(x, torch.Tensor): + return x.to(device=dev, dtype=torch_dtype) + return torch.tensor( + np.asarray(x), + dtype=torch_dtype, + device=dev, + ) + + tn.apply_to_arrays(_to_torch) + + @property + def observable_flips(self) -> torch.Tensor: + """Boolean tensor of observable flip outcomes.""" + return self._observable_flips + + @observable_flips.setter + def observable_flips(self, value: Any) -> None: + dev = self.torch_device + if not isinstance(value, torch.Tensor): + self._observable_flips = torch.tensor( + value, + dtype=torch.bool, + device=dev, + ) + else: + self._observable_flips = value.bool().to(dev) + self.obs_idx_true = torch.where(self._observable_flips)[0] + self.obs_idx_false = torch.where(~self._observable_flips)[0] + # The fused loss bakes ``obs_idx_true/false`` into its closure + # and must be rebuilt when they change. Skip when a full + # snapshot rebuild is already pending (gated by + # ``_suspend_loss_rebuild``) or before first ``__init__``. + if (getattr(self, "_compiled_predict", None) is not None and + not getattr(self, "_suspend_loss_rebuild", False)): + self._compile_loss() + + @property + def noise_params(self) -> list[torch.Tensor]: + """Trainable noise probabilities, ready for ``torch.optim``. + + Clamped to ``[eps, 1 - eps]`` only at construction; an + unconstrained step can push past the boundary and produce + ``NaN`` gradients on the next :meth:`cross_entropy_loss`. + See the class warning for safe training patterns. + """ + return [self._noise_probs] + + def _snapshot_arrays_and_eq(self) -> None: + self._eq_batch = self.full_tn.get_equation( + output_inds=("batch_index", self.logical_obs_inds[0])) + tensors = list(self.full_tn.tensors) + self._tensors_ref = tensors + + noise_ids = {id(t) for t in self.noise_model.tensors} + syndrome_ids = {id(t) for t in self.syndrome_tn.tensors} + + self._noise_pos_for_error: dict[str, int] = {} + syndrome_positions_list: list[int] = [] + self._static_positions: list[int] = [] + + for i, t in enumerate(tensors): + if id(t) in noise_ids: + self._noise_pos_for_error[t.inds[0]] = i + elif id(t) in syndrome_ids: + syndrome_positions_list.append(i) + else: + self._static_positions.append(i) + + # Guard against a future quimb that copies tensors on virtual + # combine: every tensor in ``full_tn`` must classify into + # exactly one bucket, else the predict path rebuilds the + # operand list with a None slot or a misplaced placeholder. + n_classified = (len(self._noise_pos_for_error) + + len(syndrome_positions_list) + + len(self._static_positions)) + assert n_classified == len(tensors) + assert len(self._noise_pos_for_error) == len(self.error_inds) + + self._syndrome_positions: list[tuple[int, None]] = [ + (i, None) for i in syndrome_positions_list + ] + + self._noise_pos_ordered = tuple( + self._noise_pos_for_error[ei] for ei in self.error_inds) + + torch_dtype = getattr(torch, self._dtype) + dev = self.torch_device + + def _as_torch(x): + if isinstance(x, torch.Tensor): + return x.detach().to(device=dev, dtype=torch_dtype) + return torch.as_tensor(np.asarray(x), dtype=torch_dtype, device=dev) + + self._static_arrays: dict[int, torch.Tensor] = { + i: _as_torch(self._tensors_ref[i].data) + for i in self._static_positions + } + self._syndrome_arrays: list[torch.Tensor] = [ + _as_torch(self._tensors_ref[i].data) + for i in syndrome_positions_list + ] + self._syndrome_tuple = tuple(self._syndrome_arrays) + # Used by :meth:`_update_data` to detect layout changes that + # invalidate the cached path / codegen / oe expr. + self._syndrome_shapes: tuple[tuple[int, ...], ...] = tuple( + tuple(s.shape) for s in self._syndrome_arrays) + + if self._execute_mode == "opt_einsum": + shapes = tuple(t.shape for t in tensors) + self._oe_expr = oe.contract_expression( + self._eq_batch, + *shapes, + optimize=self.path_batch + if self.path_batch not in (None, "auto") else "auto", + ) + self._path_steps = None + else: + self._oe_expr = None + # Flatten the path into ``[(eq, idxs, sorted_desc), ...]``; + # ``sorted_desc`` is precomputed for the unrolled-mode pop + # walk, and labels are remapped to ASCII because + # opt_einsum falls back to unicode past 52 indices and + # torch.einsum rejects those. + shapes = tuple(t.shape for t in tensors) + _, info = oe.contract_path( + self._eq_batch, + *shapes, + shapes=True, + optimize=self.path_batch + if self.path_batch not in (None, "auto") else "auto", + ) + self._path_steps = [(remap_eq_to_ascii(step[2]), tuple(step[0]), + tuple(sorted(step[0], reverse=True))) + for step in info.contraction_list] + + self._compile_predict() + self._compile_loss() + + def _compile_predict(self) -> None: + """Build ``self._predict_fn`` for the configured execute mode.""" + builders = { + "opt_einsum": self._build_predict_opt_einsum, + "unrolled": self._build_predict_unrolled, + "codegen": self._build_predict_codegen, + } + self._predict_fn = builders[self._execute_mode]() + self._compiled_predict = self._maybe_torch_compile(self._predict_fn, + kind="predict") + + def _build_predict_opt_einsum(self): + """opt_einsum-backed predict: reuse the cached contract expression.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + oe_expr = self._oe_expr + + def _predict(noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: + noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), + dim=-1) + arrays: list[torch.Tensor] = [None] * n # type: ignore + for pos, arr in static_arrays.items(): + arrays[pos] = arr + for pos, arr in zip(syndrome_positions, syndrome_tuple): + arrays[pos] = arr + for k, pos in enumerate(noise_pos_ordered): + arrays[pos] = noise_stacked[k] + # Torch backend is auto-selected from the tensor type; + # avoids the per-call ``backend=`` dispatch. + out = oe_expr(*arrays) + return out / out.sum(dim=1, keepdim=True) + + return _predict + + def _build_predict_unrolled(self): + """Unrolled predict: walk the cached pairwise contraction path.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + path_steps = self._path_steps + + def _predict(noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: + noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), + dim=-1) + ops: list[torch.Tensor] = [None] * n # type: ignore + for pos, arr in static_arrays.items(): + ops[pos] = arr + for pos, arr in zip(syndrome_positions, syndrome_tuple): + ops[pos] = arr + for k, pos in enumerate(noise_pos_ordered): + ops[pos] = noise_stacked[k] + for eq_str, idxs, sorted_idxs in path_steps: + picked = [ops[i] for i in idxs] + for i in sorted_idxs: + ops.pop(i) + ops.append(torch.einsum(eq_str, *picked)) + out = ops[0] + return out / out.sum(dim=1, keepdim=True) + + return _predict + + def _build_predict_codegen(self): + """Codegen predict: partial-eval'd flat Python with named locals.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + syndrome_tensors = list(self._syndrome_arrays) + codegen_fn = self._build_codegen_predict( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + self._path_steps, + syndrome_tensors, + dynamic_syndromes=self._dynamic_syndromes, + ) + self._codegen_fn = codegen_fn + self._codegen_n_folded = getattr(codegen_fn, "_n_folded", 0) + self._codegen_n_runtime = getattr(codegen_fn, "_n_runtime", 0) + + if self._dynamic_syndromes: + return codegen_fn + + # Static mode bakes syndromes into the closure and returns a + # 1-arg callable; wrap to match the public 2-arg signature. + def _predict_static( + noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...] = () + ) -> torch.Tensor: + return codegen_fn(noise_probs) + + return _predict_static + + def _maybe_torch_compile(self, fn, *, kind: str): + """Wrap ``fn`` with :func:`torch.compile` if requested. + + On any compile failure, warn and fall back to eager. ``kind`` + is included in the warning to disambiguate predict vs loss. + """ + if not self._use_torch_compile: + return fn + try: + kwargs = self._torch_compile_kwargs() + return torch.compile(fn, **kwargs) + except Exception as exc: # pragma: no cover + warnings.warn( + f"torch.compile {kind} failed ({exc!r}); " + "falling back to eager.", + RuntimeWarning, + stacklevel=2, + ) + return fn + + def _compile_loss(self) -> None: + """Build the ``(input, syndromes) -> scalar_loss`` callables. + + Two variants are produced: one accepting logits (sigmoid applied + inside) and one accepting probabilities directly. + """ + if self._execute_mode == "codegen": + logits_fn, probs_fn = self._build_loss_codegen() + else: + logits_fn, probs_fn = self._build_loss_wrapped() + + self._loss_from_logits_fn = logits_fn + self._loss_from_probs_fn = probs_fn + self._compiled_loss_from_logits = self._maybe_torch_compile(logits_fn, + kind="loss") + self._compiled_loss_from_probs = self._maybe_torch_compile(probs_fn, + kind="loss") + + def _build_loss_codegen(self): + """Codegen loss: fuse the CE reduction into the contraction graph.""" + static_arrays = self._static_arrays + syndrome_positions = tuple(p for p, _t in self._syndrome_positions) + noise_pos_ordered = self._noise_pos_ordered + n = len(self._tensors_ref) + syndrome_tensors = list(self._syndrome_arrays) + + codegen_logits = self._build_codegen_loss( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + self._path_steps, + syndrome_tensors, + obs_idx_true=self.obs_idx_true, + obs_idx_false=self.obs_idx_false, + dynamic_syndromes=self._dynamic_syndromes, + from_logits=True, + ) + codegen_probs = self._build_codegen_loss( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + self._path_steps, + syndrome_tensors, + obs_idx_true=self.obs_idx_true, + obs_idx_false=self.obs_idx_false, + dynamic_syndromes=self._dynamic_syndromes, + from_logits=False, + ) + + if self._dynamic_syndromes: + return codegen_logits, codegen_probs + + # Static codegen bakes syndromes into the closure and returns a + # 1-arg callable; wrap to match the public 2-arg signature. + def _loss_from_logits_static( + logits: torch.Tensor, syndrome_tuple: tuple[torch.Tensor, ...] = () + ) -> torch.Tensor: + return codegen_logits(logits) + + def _loss_from_probs_static( + noise_probs: torch.Tensor, + syndrome_tuple: tuple[torch.Tensor, ...] = () + ) -> torch.Tensor: + return codegen_probs(noise_probs) + + return _loss_from_logits_static, _loss_from_probs_static + + def _build_loss_wrapped(self): + """opt_einsum / unrolled loss: wrap CE around ``self._predict_fn``.""" + obs_t = self.obs_idx_true + obs_f = self.obs_idx_false + predict_fn = self._predict_fn + + if self._dynamic_syndromes: + + def _loss_from_probs(noise_probs, syndromes): + p = predict_fn(noise_probs, syndromes) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + + def _loss_from_logits(logits, syndromes): + p = predict_fn(torch.sigmoid(logits), syndromes) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + else: + + def _loss_from_probs(noise_probs, syndromes=()): + p = predict_fn(noise_probs, ()) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + + def _loss_from_logits(logits, syndromes=()): + p = predict_fn(torch.sigmoid(logits), ()) + return (-torch.log(p[obs_t, 1]).sum() - + torch.log(p[obs_f, 0]).sum()) + + return _loss_from_logits, _loss_from_probs + + def _torch_compile_kwargs(self) -> dict[str, Any]: + """Build kwargs for :func:`torch.compile`. + + Defaults to ``mode="reduce-overhead"`` on CUDA so kernel-launch + overhead is amortised via CUDA Graphs; a ``compile_mode=...`` + passed to the constructor overrides this. + """ + kwargs: dict[str, Any] = {"dynamic": False} + if self._torch_compile_mode is not None: + kwargs["mode"] = self._torch_compile_mode + elif self.torch_device.type == "cuda": + kwargs["mode"] = "reduce-overhead" + return kwargs + + @staticmethod + def _codegen_partial_eval(n, static_arrays, syndrome_positions, + noise_pos_ordered, path_steps, syndrome_tensors, + dynamic_syndromes: bool): + """Partial-evaluate ``path_steps``; return the codegen building blocks. + + Steps whose inputs are all static are evaluated eagerly under + ``torch.no_grad`` and become closure constants; the remaining + steps become source lines. + + Returns ``(runtime_lines, closure_vars, used_static, final_state, + n_folded)``: emitted source lines, name -> tensor map for the + function namespace, the subset of names actually referenced, the + single surviving state slot ``(name, is_dynamic, value_or_None)``, + and the count of folded steps. + """ + static_positions = sorted(static_arrays.keys()) + noise_pos_set = set(noise_pos_ordered) + syn_pos_set = set(syndrome_positions) + # O(1) reverse lookup tables (avoid repeated list.index() inside + # the per-step loop below — was O(N^2) for large path lengths). + noise_pos_to_k = {pos: k for k, pos in enumerate(noise_pos_ordered)} + syn_pos_to_sidx = { + pos: sidx for sidx, pos in enumerate(syndrome_positions) + } + static_pos_to_sidx = { + pos: sidx for sidx, pos in enumerate(static_positions) + } + + # state[pos] = (var_name, is_dynamic, concrete_value_or_None) + state: list[tuple[str, bool, torch.Tensor | None]] = [] + for pos in range(n): + if pos in noise_pos_set: + k = noise_pos_to_k[pos] + state.append((f"_n{k}", True, None)) + elif pos in syn_pos_set: + sidx = syn_pos_to_sidx[pos] + if dynamic_syndromes: + state.append((f"_S{sidx}", True, None)) + else: + state.append((f"_S{sidx}", False, syndrome_tensors[sidx])) + else: + sidx = static_pos_to_sidx[pos] + state.append( + (f"_C{sidx}", False, static_arrays[static_positions[sidx]])) + + closure_vars: dict[str, torch.Tensor] = {} + runtime_lines: list[str] = [] + # Names that the emitted source actually references and that must + # be available in the closure namespace. We track this + # structurally as we go instead of re-parsing the source string, + # which is both faster and immune to lexical false positives / + # negatives (e.g. if an einsum equation contained an underscore). + used_static: set[str] = set() + n_folded = 0 + + for step_idx, step in enumerate(path_steps): + eq_str = step[0] + idxs = step[1] + picked = [state[i] for i in idxs] + for i in sorted(idxs, reverse=True): + state.pop(i) + any_dynamic = any(p[1] for p in picked) + out_name = f"_r{step_idx}" + if not any_dynamic: + arrs = [p[2] for p in picked] + with torch.no_grad(): + result = torch.einsum(eq_str, *arrs).contiguous() + static_name = f"_P{step_idx}" + closure_vars[static_name] = result + state.append((static_name, False, result)) + n_folded += 1 + else: + arg_names = [p[0] for p in picked] + # Track which closure names this line will read. ``_n*`` + # names are header-built locals (from noise_probs) so + # they are *not* closure values; everything else is. + for name in arg_names: + if name.startswith(("_C", "_P")): + used_static.add(name) + elif name.startswith("_S") and not dynamic_syndromes: + used_static.add(name) + runtime_lines.append( + f" {out_name} = torch.einsum({eq_str!r}, " + f"{', '.join(arg_names)})") + state.append((out_name, True, None)) + + assert len(state) == 1 + for name in used_static: + if name in closure_vars: + continue + if name.startswith("_C"): + sidx = int(name[2:]) + closure_vars[name] = static_arrays[static_positions[sidx]] + elif name.startswith("_S"): # static-syndromes mode only + sidx = int(name[2:]) + closure_vars[name] = syndrome_tensors[sidx] + + return runtime_lines, closure_vars, used_static, state[0], n_folded + + @staticmethod + def _emit_noise_header(noise_pos_ordered, + transform: str = "identity") -> list[str]: + """Emit source lines materialising ``_n0 .. _n{K-1}``. + + ``transform="identity"`` treats the input as probabilities; + ``"sigmoid"`` treats it as logits and applies ``torch.sigmoid`` + first. A single ``(K, 2)`` stack is built and then sliced, which + keeps the autograd graph compact. + """ + lines: list[str] = [] + if transform == "sigmoid": + lines.append(" _p = torch.sigmoid(noise_probs)") + else: + lines.append(" _p = noise_probs") + lines.append(" _q = 1.0 - _p") + # One stack of shape (K, 2) instead of K stacks of shape (2,). + # ``dim=1`` makes ``_NS[k]`` a contiguous view of length 2. + lines.append(" _NS = torch.stack((_q, _p), dim=1)") + for k in range(len(noise_pos_ordered)): + lines.append(f" _n{k} = _NS[{k}]") + return lines + + @staticmethod + def _emit_syndrome_header(syndrome_positions, + dynamic_syndromes: bool) -> list[str]: + """Emit source lines binding ``_S0 .. _S{S-1}`` to runtime + ``syndromes`` arguments; empty in static mode.""" + if not dynamic_syndromes: + return [] + return [ + f" _S{sidx} = syndromes[{sidx}]" + for sidx in range(len(syndrome_positions)) + ] + + @classmethod + def _build_codegen_predict(cls, + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + dynamic_syndromes: bool = True): + """Generate ``_predict(noise_probs[, syndromes]) -> (shots, 2)``. + + With ``dynamic_syndromes=False`` syndromes are folded into the + closure, which maximises partial evaluation but forces a rebuild + on every :meth:`update_dataset` call. With ``True`` (default) + syndromes stay runtime arguments and dataset swaps are free. + """ + runtime_lines, closure_vars, _used, final_state, n_folded = ( + cls._codegen_partial_eval( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + dynamic_syndromes, + )) + final_name, is_final_dyn, final_value = final_state + fully_static = not is_final_dyn + + body: list[str] = [] + if dynamic_syndromes: + body.append("def _predict(noise_probs, syndromes):") + else: + body.append("def _predict(noise_probs):") + + if fully_static: + # Degenerate case: contraction didn't depend on noise (or any + # other dynamic input). Pre-normalise the constant and + # return it unchanged on every call. + with torch.no_grad(): + normed = final_value / final_value.sum(dim=1, keepdim=True) + closure_vars["_FINAL"] = normed + body.append(" return _FINAL") + runtime_lines = [] + else: + body.extend( + cls._emit_noise_header(noise_pos_ordered, transform="identity")) + body.extend( + cls._emit_syndrome_header(syndrome_positions, + dynamic_syndromes)) + body.extend(runtime_lines) + body.append(f" _out = {final_name}") + body.append(" return _out / _out.sum(dim=1, keepdim=True)") + + return cls._compile_codegen_source(body, closure_vars, n_folded, + len(runtime_lines), "predict") + + @classmethod + def _build_codegen_loss(cls, + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + obs_idx_true: torch.Tensor, + obs_idx_false: torch.Tensor, + dynamic_syndromes: bool = True, + from_logits: bool = True): + """Generate a fused ``(input, syndromes) -> scalar`` loss callable. + + Pipes the contraction output straight into the cross-entropy + reduction so the whole pipeline (optional sigmoid, contraction, + normalisation, cross-entropy) is a single autograd graph. + + Args: + from_logits: If ``True`` (default), apply ``torch.sigmoid`` to + the input; if ``False``, the input must already be in + ``[0, 1]``. + """ + runtime_lines, closure_vars, _used, final_state, n_folded = ( + cls._codegen_partial_eval( + n, + static_arrays, + syndrome_positions, + noise_pos_ordered, + path_steps, + syndrome_tensors, + dynamic_syndromes, + )) + final_name, is_final_dyn, final_value = final_state + fully_static = not is_final_dyn + + closure_vars["_OBS_T"] = obs_idx_true + closure_vars["_OBS_F"] = obs_idx_false + + body: list[str] = [] + if dynamic_syndromes: + body.append("def _loss(noise_probs, syndromes):") + else: + body.append("def _loss(noise_probs):") + + if fully_static: + # The contraction is a constant; the loss it produces is + # also a constant. We still need the gradient wrt + # noise_probs to be zero (a real-valued zero tensor with a + # graph edge), so emit a no-op multiplication by + # noise_probs.sum() * 0 to keep autograd happy. + with torch.no_grad(): + normed = final_value / final_value.sum(dim=1, keepdim=True) + # Compute the loss eagerly; we can't fold it because + # autograd needs a path back to noise_probs. + ce = (-torch.log(normed[obs_idx_true, 1]).sum() - + torch.log(normed[obs_idx_false, 0]).sum()) + closure_vars["_LOSS"] = ce + body.append(" return _LOSS + 0.0 * noise_probs.sum()") + runtime_lines = [] + else: + transform = "sigmoid" if from_logits else "identity" + body.extend(cls._emit_noise_header(noise_pos_ordered, transform)) + body.extend( + cls._emit_syndrome_header(syndrome_positions, + dynamic_syndromes)) + body.extend(runtime_lines) + # Fused cross-entropy that skips the explicit + # ``_preds = _out / _out.sum(dim=1, keepdim=True)`` step. + # OBS_T and OBS_F partition the batch (every row is in exactly + # one), so: + # -log(p_T[:,1]).sum() - log(p_F[:,0]).sum() + # = log(Z).sum() - log(_out[OBS_T,1]).sum() + # - log(_out[OBS_F,0]).sum() + # where Z = _out[:,0] + _out[:,1]. Saves one (shots, 2) + # division + materialisation and the corresponding backward + # nodes -- ~2-5% per-step on CPU at d=3/r=3. + body.append(f" _out = {final_name}") + body.append(" _z0 = _out[:, 0]") + body.append(" _z1 = _out[:, 1]") + body.append(" return (torch.log(_z0 + _z1).sum() " + "- torch.log(_z1[_OBS_T]).sum() " + "- torch.log(_z0[_OBS_F]).sum())") + + return cls._compile_codegen_source(body, closure_vars, n_folded, + len(runtime_lines), "loss") + + @staticmethod + def _compile_codegen_source(body: list[str], + closure_vars: dict[str, torch.Tensor], + n_folded: int, n_runtime: int, kind: str): + """Compile the assembled function source and return the callable.""" + source = "\n".join(body) + ns: dict[str, Any] = {"torch": torch} + ns.update(closure_vars) + fn_name = "_loss" if kind == "loss" else "_predict" + exec(compile(source, f"", "exec"), ns) + fn = ns[fn_name] + fn._n_folded = n_folded # type: ignore[attr-defined] + fn._n_runtime = n_runtime # type: ignore[attr-defined] + return fn + + def decoder_prediction(self) -> torch.Tensor: + """Run the forward pass; returns ``(shots, 2)`` predictions.""" + return self._compiled_predict(self._noise_probs, self._syndrome_tuple) + + def cross_entropy_loss(self) -> torch.Tensor: + """Cross-entropy loss over the syndrome batch. + + Returns a differentiable scalar; call ``.backward()`` to obtain + gradients w.r.t. :attr:`noise_params`. The fused codegen omits + the ``log`` guard, so a prior at ``0`` or ``1`` yields ``NaN`` + gradients — see :attr:`noise_params` for safe training patterns. + """ + return self._compiled_loss_from_probs(self._noise_probs, + self._syndrome_tuple) + + def current_syndrome_args(self) -> tuple[torch.Tensor, ...]: + """Return the syndrome argument expected by :meth:`loss_fn`. + + Returns ``()`` when syndromes are baked into the closure + (``execute="codegen"`` and ``dynamic_syndromes=False``), else + the current live tuple. Re-fetch each step so an intervening + :meth:`update_dataset` is reflected. + """ + if self._execute_mode == "codegen" and not self._dynamic_syndromes: + return () + return self._syndrome_tuple + + def loss_fn(self, from_logits: bool = True): + """Return a fused ``(input, syndromes) -> scalar`` loss callable. + + Useful when training in logit space (``from_logits=True``, the + default) or when feeding in an externally managed probability + tensor (``from_logits=False``). Compared to + :meth:`cross_entropy_loss`, the parameter is supplied explicitly + per call instead of being read from :attr:`noise_params`. + """ + return (self._compiled_loss_from_logits + if from_logits else self._compiled_loss_from_probs) + + def logical_error_rate(self) -> float: + """Fraction of shots decoded incorrectly. + + Uses a hard argmax threshold; **not** differentiable. + """ + with torch.no_grad(): + predictions = self.decoder_prediction() + pred = predictions[:, 1] > predictions[:, 0] + return float(1 - (pred == self._observable_flips).sum() / + len(self._observable_flips)) + + def _update_data(self, + new_syndrome_arrays: torch.Tensor, + new_observable_flips: npt.NDArray[Any], + enforce_shape: bool = True) -> None: + """In-place dataset swap on already-prepared syndrome tensors. + + ``new_syndrome_arrays`` must be in the internal layout (the + output of :func:`prepare_syndrome_data_batch`, on the right + device, shape ``(syndrome_length, shots, 2)``). Public callers + should use :meth:`update_dataset` instead. + """ + # Patch syndrome tensor data in the quimb TN in place; the + # cotengra path is invalidated below if any shape changed. + for i, tag in enumerate(self._syndrome_tags): + t = self.syndrome_tn.tensors[next( + iter(self.syndrome_tn.tag_map[tag]))] + if enforce_shape: + assert t.data.shape == new_syndrome_arrays[i].shape, ( + f"Shape mismatch for {tag}: " + f"{t.data.shape} vs {new_syndrome_arrays[i].shape}") + t.modify(data=new_syndrome_arrays[i]) + + # Suppress the loss rebuild the ``observable_flips`` setter + # would otherwise trigger; one of the branches below issues it. + self._suspend_loss_rebuild = True + self.observable_flips = new_observable_flips + + torch_dtype = getattr(torch, self._dtype) + dev = self.torch_device + new_shapes: list[tuple[int, ...]] = [] + for k, (pos, _tag) in enumerate(self._syndrome_positions): + data = self._tensors_ref[pos].data + if isinstance(data, torch.Tensor): + arr = data.detach().to(device=dev, dtype=torch_dtype) + else: + arr = torch.as_tensor(np.asarray(data), + dtype=torch_dtype, + device=dev) + self._syndrome_arrays[k] = arr + new_shapes.append(tuple(arr.shape)) + new_shapes_tuple = tuple(new_shapes) + + # Shape change: cached path / codegen / oe expression / compile + # guards are all stale. Drop the path and rebuild from scratch. + # Shapes unchanged: dynamic codegen and the unrolled / + # opt_einsum paths read syndromes per call — refreshing the + # cached tuple is enough. Static codegen baked the old tensors + # into the closure and still needs a full rebuild. + shape_changed = new_shapes_tuple != self._syndrome_shapes + if shape_changed: + self.path_batch = None + self.slicing_batch = tuple() + try: + self._snapshot_arrays_and_eq() + finally: + self._suspend_loss_rebuild = False + return + + self._syndrome_tuple = tuple(self._syndrome_arrays) + if self._execute_mode == "codegen" and not self._dynamic_syndromes: + try: + self._snapshot_arrays_and_eq() + finally: + self._suspend_loss_rebuild = False + else: + # The observable indices may have changed; the loss bakes + # them in, so it still needs a rebuild. + self._suspend_loss_rebuild = False + self._compile_loss() + + def update_dataset(self, + new_syndrome_data: npt.NDArray[Any], + new_observable_flips: npt.NDArray[Any], + enforce_shape: bool = True) -> None: + """Replace the syndrome batch and observable flips. + + Args: + new_syndrome_data: Shape ``(shots, num_checks)``. + new_observable_flips: Shape ``(shots,)``. + enforce_shape: Assert that per-tensor shapes match. A + changing batch size triggers a full rebuild of the + cached contraction path and codegen. + """ + syndrome_arrays = prepare_syndrome_data_batch(new_syndrome_data) + torch_dtype = getattr(torch, self._dtype) + syndrome_arrays = torch.tensor( + syndrome_arrays, + dtype=torch_dtype, + device=self.torch_device, + ).transpose(1, 2) + self._batch_size = int(new_syndrome_data.shape[0]) + self._update_data(syndrome_arrays, new_observable_flips, enforce_shape) + + def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: + """Cache a contraction path via quimb and rebuild the JIT. + + Always routes through :meth:`TensorNetwork.contraction_info` so + the resulting path is compatible with :mod:`opt_einsum` and + manual unrolling -- unlike :meth:`TensorNetworkDecoder.optimize_path`, + which defaults to a cuTensorNet-only path. + + ``batch_size`` is part of the parent ``TensorNetworkDecoder`` + signature (which rebuilds its TN around a fake batch); on the + optimiser the syndrome TN is already batched at construction + and resized in :meth:`update_dataset`, so this argument is + ignored. Kept for Liskov substitution with the parent. + """ + del batch_size + info = self.full_tn.contraction_info( + output_inds=("batch_index", self.logical_obs_inds[0]), + optimize=optimize if optimize is not None else "auto", + ) + self.path_batch = info.path + self.slicing_batch = tuple() + self._snapshot_arrays_and_eq() + return info + + +def make_compiled_step(optimizer: NMOptimizer, logits: torch.Tensor, + torch_optimizer: torch.optim.Optimizer): + """Build a no-arg callable that runs one Adam step and returns the loss. + + The step zeros grads, calls the optimizer's compiled + ``loss_fn(from_logits=True)`` (sigmoid + contraction + cross-entropy + fused), backwards, and steps ``torch_optimizer``. Use this when + training in logit space. + + Args: + optimizer: The :class:`NMOptimizer` providing the fused + inner loss; pass ``compile=True`` at the + :class:`NMOptimizer` constructor for the + ``torch.compile``-d variant. + logits: Trainable 1-D tensor of length ``len(optimizer.error_inds)`` + with ``requires_grad=True``. + torch_optimizer: A ``torch.optim`` instance owning ``logits``. + """ + + # Re-fetch per call so a rebuild from update_dataset or the + # observable_flips setter is picked up; capturing would go stale. + def _step(): + torch_optimizer.zero_grad(set_to_none=True) + loss = optimizer.loss_fn(from_logits=True)( + logits, optimizer.current_syndrome_args()) + loss.backward() + torch_optimizer.step() + return loss + + return _step diff --git a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py index 1fb4346e..4ac03aa5 100644 --- a/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py +++ b/libs/qec/python/cudaq_qec/plugins/decoders/tensor_network_utils/noise_models.py @@ -5,45 +5,23 @@ # This source code and the accompanying materials are made available under # # the terms of the Apache License 2.0 which accompanies this distribution. # # ============================================================================ # -"""Noise-model tensor-network builders and differentiable noise learning. +"""Static noise-model tensor-network builders. -Static noise models -------------------- :func:`factorized_noise_model` and :func:`error_pairs_noise_model` return :class:`quimb.tensor.TensorNetwork` objects whose open indices match the error indices of the parent decoder. The networks are combined with the code / logical / syndrome tensor networks inside :class:`TensorNetworkDecoder`. -Differentiable noise learning ------------------------------- -:class:`NMOptimizer` fits a factorised per-error noise model to a -syndrome dataset by backpropagating through a torch-backed tensor-network -contraction. :func:`make_compiled_step` is a convenience factory that -builds a no-arg callable for one Adam step in logit space. +For differentiable noise learning (:class:`NMOptimizer`, +:func:`make_compiled_step`), see :mod:`.nm_optimizer`. """ from __future__ import annotations -import warnings -from typing import Any, Literal - import numpy as np -import numpy.typing as npt -import opt_einsum as oe -import torch from quimb import oset from quimb.tensor import Tensor, TensorNetwork -from ..tensor_network_decoder import TensorNetworkDecoder -from .tensor_network_factory import ( - tensor_network_from_syndrome_batch, - prepare_syndrome_data_batch, -) - -# --------------------------------------------------------------------------- -# Static noise-model builders -# --------------------------------------------------------------------------- - def factorized_noise_model( error_indices: list[str], @@ -134,1169 +112,3 @@ def error_pairs_noise_model( tags=oset([etag]), )) return TensorNetwork(tensors) - - -# --------------------------------------------------------------------------- -# Differentiable noise learning -# --------------------------------------------------------------------------- - -_ASCII_POOL = ("abcdefghijklmnopqrstuvwxyz" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ") - -# Coarse for fp32 because ``1.0 - 1e-12`` rounds back to ``1.0``. -_PRIOR_EPS_BY_DTYPE: dict[str, float] = { - "float64": 1e-12, - "float32": 1e-6, -} -_SUPPORTED_DTYPES: tuple[str, ...] = ("float32", "float64") - - -def _validate_and_clamp_priors(noise_model: Any, dtype: str) -> list[float]: - """Validate noise priors and clamp them into ``[eps, 1 - eps]``. - - The fused cross-entropy reduction in - :meth:`NMOptimizer.cross_entropy_loss` has no ``log`` guard, so a - prior of exactly ``0.0`` or ``1.0`` makes the contraction emit a - zero whose log is ``-inf`` and whose gradient is ``NaN``; training - silently diverges. Stim DEMs occasionally emit ``p=1.0`` - (deterministic detectors) or ``p<1e-15`` (underflow), so we - intercept here rather than force every caller to clamp. - - Behaviour mirrors :class:`torch.nn.BCELoss`-style stable wrappers: - - * Non-finite priors (``NaN`` / ``+/-inf``) raise ``ValueError`` - - these indicate caller bugs, not numerical fragility, and - silently coercing them would hide the real problem. - * Out-of-range priors (``p <= eps`` or ``p >= 1 - eps``) are - clamped into ``[eps, 1 - eps]`` and a single ``UserWarning`` - summarises the number of values changed. - * In-range priors pass through unchanged with no warning. - - Args: - noise_model: array-like of priors, length ``num_errors``. - dtype: contraction dtype string (``"float32"`` / ``"float64"``). - - Returns: - A plain ``list[float]`` so the base - :class:`TensorNetworkDecoder` keeps using its existing - list-based factorised noise model unchanged. - """ - arr = np.asarray(noise_model, dtype=np.float64) - if arr.ndim != 1: - raise ValueError(f"noise_model must be 1-D; got shape {arr.shape}") - if not np.all(np.isfinite(arr)): - bad = np.where(~np.isfinite(arr))[0] - raise ValueError( - f"All priors must be finite; got non-finite values at error " - f"indices {bad.tolist()}: {arr[bad].tolist()}") - - dtype_str = str(dtype) - if dtype_str not in _PRIOR_EPS_BY_DTYPE: - raise ValueError(f"Unsupported dtype {dtype_str!r}; " - f"expected one of {sorted(_PRIOR_EPS_BY_DTYPE)}.") - eps = _PRIOR_EPS_BY_DTYPE[dtype_str] - out_of_range = (arr < eps) | (arr > 1.0 - eps) - if np.any(out_of_range): - warnings.warn( - f"Clamped {int(out_of_range.sum())}/{len(arr)} NMOptimizer " - f"priors into [{eps}, {1.0 - eps}] for numerical stability; " - f"values at or outside the (0, 1) boundary produce -inf " - f"cross-entropy loss and NaN gradients in the fused codegen.", - UserWarning, - stacklevel=3, - ) - arr = np.clip(arr, eps, 1.0 - eps) - return arr.tolist() - - -def _remap_eq_to_ascii(eq: str) -> str: - """Rewrite an einsum equation so every label is in ``[a-zA-Z]``. - - Needed because :mod:`opt_einsum` falls back to non-ASCII unicode - labels once the total index count exceeds 52, but - :func:`torch.einsum` rejects them. Raises if a single step has more - than 52 distinct labels. - """ - if eq.isascii(): - return eq - if "->" in eq: - lhs, rhs = eq.split("->") - else: - lhs, rhs = eq, None - - mapping: dict[str, str] = {} - out_lhs_chars: list[str] = [] - for c in lhs: - if c == ",": - out_lhs_chars.append(c) - continue - if c not in mapping: - if len(mapping) >= len(_ASCII_POOL): - raise ValueError( - f"Einsum step '{eq}' has more than {len(_ASCII_POOL)} " - "distinct labels; cannot remap to ASCII.") - mapping[c] = _ASCII_POOL[len(mapping)] - out_lhs_chars.append(mapping[c]) - out_lhs = "".join(out_lhs_chars) - if rhs is None: - return out_lhs - out_rhs_chars: list[str] = [] - for c in rhs: - if c not in mapping: - raise ValueError( - f"Einsum step '{eq}' has output label {c!r} not present " - "on the LHS; cannot remap.") - out_rhs_chars.append(mapping[c]) - return f"{out_lhs}->{''.join(out_rhs_chars)}" - - -class NMOptimizer(TensorNetworkDecoder): - """Differentiable noise-model optimiser for the TN decoder. - - The factorised noise probabilities live in the torch autograd graph - and are fit to a fixed syndrome batch by minimising the cross-entropy - of the decoder's logical prediction against the observed flips. - - The forward pass is materialised once at construction and reused - across optimisation steps. Optionally call :meth:`optimize_path` - (e.g. with ``cotengra.HyperOptimizer()``) to pin a better contraction - path; the JIT is rebuilt automatically. - - .. warning:: - - Priors are clamped into ``[eps, 1 - eps]`` only at construction; - an unconstrained optimiser step on :attr:`noise_params` can push - them past the boundary, after which :meth:`cross_entropy_loss` - returns ``NaN`` gradients. Prefer logit-space training via - :func:`make_compiled_step` (shown below), or clamp the tensor - under :func:`torch.no_grad` after each step. - - Args: - H: Parity check matrix, shape ``(num_checks, num_errors)``. - logical_obs: Logical observable matrix, shape ``(1, num_errors)``. - noise_model: Initial per-error probabilities, length ``num_errors``. - Each value must be strictly in ``(0, 1)``; values at or - outside the boundary (``p <= eps`` or ``p >= 1 - eps``, - with ``eps`` dtype-dependent) are auto-clamped at - construction with a :class:`UserWarning`. Non-finite - priors raise :class:`ValueError`. - syndrome_data: Syndrome batch, shape ``(shots, num_checks)``. - observable_flips: Observable flip outcomes, shape ``(shots,)``. - check_inds, error_inds, logical_inds, logical_tags: Optional index - and tag names; defaults track the parent decoder. - dtype: Tensor data type (e.g. ``"float32"``). - device: ``"cuda"`` (default) or ``"cpu"``. - compile: If ``True``, wrap the forward in :func:`torch.compile`. - Most useful with ``execute="codegen"``. - execute: Forward backend. ``"codegen"`` (default) partial-evaluates - the path into a flat Python function; ``"unrolled"`` keeps an - interpretive einsum list; ``"opt_einsum"`` dispatches via - :func:`opt_einsum.contract_expression`. - compile_mode: Forwarded to :func:`torch.compile`; ignored when - ``compile=False``. - dynamic_syndromes: If ``True`` (default), syndromes are runtime - arguments to the compiled forward, so :meth:`update_dataset` - does not retrigger codegen / ``torch.compile`` (provided - shapes are unchanged). ``False`` bakes the syndromes into - the closure as constants -- fewer runtime einsums, but every - :meth:`update_dataset` call rebuilds the graph. Only affects - ``execute="codegen"``. - - Example (logit-space, no clamping needed):: - - opt = NMOptimizer(H, logical_obs, priors, - syndrome_data, obs_flips) - opt.optimize_path(optimize=ctg.HyperOptimizer()) - logits = torch.logit(opt.noise_params[0].detach()).requires_grad_() - torch_opt = torch.optim.Adam([logits], lr=0.01) - step = make_compiled_step(opt, logits, torch_opt) - for _ in range(100): - loss = step() - """ - - def __init__( - self, - H: npt.NDArray[Any], - logical_obs: npt.NDArray[Any], - noise_model: list[float], - syndrome_data: npt.NDArray[Any], - observable_flips: npt.NDArray[Any], - check_inds: list[str] | None = None, - error_inds: list[str] | None = None, - logical_inds: list[str] | None = None, - logical_tags: list[str] | None = None, - dtype: str = "float32", - device: str = "cuda", - *, - compile: bool = False, - execute: Literal["codegen", "unrolled", "opt_einsum"] = "codegen", - compile_mode: str | None = None, - dynamic_syndromes: bool = True, - ) -> None: - if execute not in ("unrolled", "opt_einsum", "codegen"): - raise ValueError(f"Invalid execute mode: {execute!r}") - if dtype not in _SUPPORTED_DTYPES: - raise ValueError(f"Invalid dtype {dtype!r}; expected one of " - f"{list(_SUPPORTED_DTYPES)}.") - - # Sanitise once so the base TN tensors and ``self._noise_probs`` - # see identical values (see :func:`_validate_and_clamp_priors`). - noise_model = _validate_and_clamp_priors(noise_model, dtype) - - super().__init__( - H, - logical_obs, - noise_model, - check_inds=check_inds, - error_inds=error_inds, - logical_inds=logical_inds, - logical_tags=logical_tags, - contract_noise_model=False, - dtype=dtype, - device=device, - ) - - # Force the torch backend so tensor data lives in the autograd - # graph (the base class would otherwise pick cutensornet/numpy - # on GPU). Contractions still go through codegen / cotengra - # below, not cuTensorNet. - if self.contractor_config.contractor_name == "cutensornet" \ - and self.contractor_config.backend != "torch": - warnings.warn( - "NMOptimizer requires the torch backend for autograd; " - f"switching contractor backend from " - f"{self.contractor_config.backend!r} to 'torch'. " - "Contractions are executed via codegen/opt_einsum, not " - "cuTensorNet.", - stacklevel=3, - ) - self._set_contractor( - "cutensornet", - self.contractor_config.device, - "torch", - dtype, - ) - - # Swap the base's placeholder single-syndrome TN for a batched one. - self._syndrome_tags = [f"SYN_{i}" for i in range(len(self.check_inds))] - self.syndrome_tn = tensor_network_from_syndrome_batch( - syndrome_data, - self.check_inds, - batch_index="batch_index", - tags=self._syndrome_tags, - ) - self._batch_size = syndrome_data.shape[0] - - # Re-stitch ``full_tn`` around the batched syndrome TN. - self.full_tn = TensorNetwork() - self.full_tn = self.full_tn.combine(self.code_tn, virtual=True) - self.full_tn = self.full_tn.combine(self.logical_tn, virtual=True) - self.full_tn = self.full_tn.combine(self.syndrome_tn, virtual=True) - self.full_tn = self.full_tn.combine(self.noise_model, virtual=True) - - self._set_tensor_type(self.syndrome_tn) - - torch_dtype = getattr(torch, self._dtype) - self._noise_probs = torch.tensor( - noise_model, - dtype=torch_dtype, - device=self.torch_device, - requires_grad=True, - ) - # The base's noise tensors stay in ``full_tn`` as numpy - # placeholders: ``_snapshot_arrays_and_eq`` uses ``id()`` to - # locate their positions, then ``self._noise_probs`` (autograd - # live) is written into those slots. Do not strip them. - - self._suspend_loss_rebuild = True - self.observable_flips = observable_flips - - self._use_torch_compile = compile - self._execute_mode = execute - self._torch_compile_mode = compile_mode - self._dynamic_syndromes = dynamic_syndromes - self._compiled_predict: Any | None = None - self._syndrome_tuple: tuple[torch.Tensor, ...] = () - self._snapshot_arrays_and_eq() - self._suspend_loss_rebuild = False - - @property - def torch_device(self) -> torch.device: - """The ``torch.device`` matching the contractor config.""" - if "cuda" in self.contractor_config.device: - return torch.device(f"cuda:{self.contractor_config.device_id}",) - return torch.device("cpu") - - def _set_tensor_type(self, tn: TensorNetwork) -> None: - """Move all tensor data in *tn* to torch on the configured device. - - Overrides the base ``autoray``-routed implementation so gradients - flow through the noise-model tensors. - """ - torch_dtype = getattr(torch, self._dtype) - dev = self.torch_device - - def _to_torch(x): - if isinstance(x, torch.Tensor): - return x.to(device=dev, dtype=torch_dtype) - return torch.tensor( - np.asarray(x), - dtype=torch_dtype, - device=dev, - ) - - tn.apply_to_arrays(_to_torch) - - @property - def observable_flips(self) -> torch.Tensor: - """Boolean tensor of observable flip outcomes.""" - return self._observable_flips - - @observable_flips.setter - def observable_flips(self, value: Any) -> None: - dev = self.torch_device - if not isinstance(value, torch.Tensor): - self._observable_flips = torch.tensor( - value, - dtype=torch.bool, - device=dev, - ) - else: - self._observable_flips = value.bool().to(dev) - self.obs_idx_true = torch.where(self._observable_flips)[0] - self.obs_idx_false = torch.where(~self._observable_flips)[0] - # The fused loss bakes ``obs_idx_true/false`` into its closure - # and must be rebuilt when they change. Skip when a full - # snapshot rebuild is already pending (gated by - # ``_suspend_loss_rebuild``) or before first ``__init__``. - if (getattr(self, "_compiled_predict", None) is not None and - not getattr(self, "_suspend_loss_rebuild", False)): - self._compile_loss() - - @property - def noise_params(self) -> list[torch.Tensor]: - """Trainable noise probabilities, ready for ``torch.optim``. - - Clamped to ``[eps, 1 - eps]`` only at construction; an - unconstrained step can push past the boundary and produce - ``NaN`` gradients on the next :meth:`cross_entropy_loss`. - See the class warning for safe training patterns. - """ - return [self._noise_probs] - - def _snapshot_arrays_and_eq(self) -> None: - self._eq_batch = self.full_tn.get_equation( - output_inds=("batch_index", self.logical_obs_inds[0])) - tensors = list(self.full_tn.tensors) - self._tensors_ref = tensors - - noise_ids = {id(t) for t in self.noise_model.tensors} - syndrome_ids = {id(t) for t in self.syndrome_tn.tensors} - - self._noise_pos_for_error: dict[str, int] = {} - syndrome_positions_list: list[int] = [] - self._static_positions: list[int] = [] - - for i, t in enumerate(tensors): - if id(t) in noise_ids: - self._noise_pos_for_error[t.inds[0]] = i - elif id(t) in syndrome_ids: - syndrome_positions_list.append(i) - else: - self._static_positions.append(i) - - # Guard against a future quimb that copies tensors on virtual - # combine: every tensor in ``full_tn`` must classify into - # exactly one bucket, else the predict path rebuilds the - # operand list with a None slot or a misplaced placeholder. - n_classified = (len(self._noise_pos_for_error) + - len(syndrome_positions_list) + - len(self._static_positions)) - assert n_classified == len(tensors) - assert len(self._noise_pos_for_error) == len(self.error_inds) - - self._syndrome_positions: list[tuple[int, None]] = [ - (i, None) for i in syndrome_positions_list - ] - - self._noise_pos_ordered = tuple( - self._noise_pos_for_error[ei] for ei in self.error_inds) - - torch_dtype = getattr(torch, self._dtype) - dev = self.torch_device - - def _as_torch(x): - if isinstance(x, torch.Tensor): - return x.detach().to(device=dev, dtype=torch_dtype) - return torch.as_tensor(np.asarray(x), dtype=torch_dtype, device=dev) - - self._static_arrays: dict[int, torch.Tensor] = { - i: _as_torch(self._tensors_ref[i].data) - for i in self._static_positions - } - self._syndrome_arrays: list[torch.Tensor] = [ - _as_torch(self._tensors_ref[i].data) - for i in syndrome_positions_list - ] - self._syndrome_tuple = tuple(self._syndrome_arrays) - # Used by :meth:`_update_data` to detect layout changes that - # invalidate the cached path / codegen / oe expr. - self._syndrome_shapes: tuple[tuple[int, ...], ...] = tuple( - tuple(s.shape) for s in self._syndrome_arrays) - - if self._execute_mode == "opt_einsum": - shapes = tuple(t.shape for t in tensors) - self._oe_expr = oe.contract_expression( - self._eq_batch, - *shapes, - optimize=self.path_batch - if self.path_batch not in (None, "auto") else "auto", - ) - self._path_steps = None - else: - self._oe_expr = None - # Flatten the path into ``[(eq, idxs, sorted_desc), ...]``; - # ``sorted_desc`` is precomputed for the unrolled-mode pop - # walk, and labels are remapped to ASCII because - # opt_einsum falls back to unicode past 52 indices and - # torch.einsum rejects those. - shapes = tuple(t.shape for t in tensors) - _, info = oe.contract_path( - self._eq_batch, - *shapes, - shapes=True, - optimize=self.path_batch - if self.path_batch not in (None, "auto") else "auto", - ) - self._path_steps = [(_remap_eq_to_ascii(step[2]), tuple(step[0]), - tuple(sorted(step[0], reverse=True))) - for step in info.contraction_list] - - self._compile_predict() - self._compile_loss() - - def _compile_predict(self) -> None: - """Build ``self._predict_fn`` for the configured execute mode.""" - builders = { - "opt_einsum": self._build_predict_opt_einsum, - "unrolled": self._build_predict_unrolled, - "codegen": self._build_predict_codegen, - } - self._predict_fn = builders[self._execute_mode]() - self._compiled_predict = self._maybe_torch_compile(self._predict_fn, - kind="predict") - - def _build_predict_opt_einsum(self): - """opt_einsum-backed predict: reuse the cached contract expression.""" - static_arrays = self._static_arrays - syndrome_positions = tuple(p for p, _t in self._syndrome_positions) - noise_pos_ordered = self._noise_pos_ordered - n = len(self._tensors_ref) - oe_expr = self._oe_expr - - def _predict(noise_probs: torch.Tensor, - syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: - noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), - dim=-1) - arrays: list[torch.Tensor] = [None] * n # type: ignore - for pos, arr in static_arrays.items(): - arrays[pos] = arr - for pos, arr in zip(syndrome_positions, syndrome_tuple): - arrays[pos] = arr - for k, pos in enumerate(noise_pos_ordered): - arrays[pos] = noise_stacked[k] - # Torch backend is auto-selected from the tensor type; - # avoids the per-call ``backend=`` dispatch. - out = oe_expr(*arrays) - return out / out.sum(dim=1, keepdim=True) - - return _predict - - def _build_predict_unrolled(self): - """Unrolled predict: walk the cached pairwise contraction path.""" - static_arrays = self._static_arrays - syndrome_positions = tuple(p for p, _t in self._syndrome_positions) - noise_pos_ordered = self._noise_pos_ordered - n = len(self._tensors_ref) - path_steps = self._path_steps - - def _predict(noise_probs: torch.Tensor, - syndrome_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor: - noise_stacked = torch.stack((1.0 - noise_probs, noise_probs), - dim=-1) - ops: list[torch.Tensor] = [None] * n # type: ignore - for pos, arr in static_arrays.items(): - ops[pos] = arr - for pos, arr in zip(syndrome_positions, syndrome_tuple): - ops[pos] = arr - for k, pos in enumerate(noise_pos_ordered): - ops[pos] = noise_stacked[k] - for eq_str, idxs, sorted_idxs in path_steps: - picked = [ops[i] for i in idxs] - for i in sorted_idxs: - ops.pop(i) - ops.append(torch.einsum(eq_str, *picked)) - out = ops[0] - return out / out.sum(dim=1, keepdim=True) - - return _predict - - def _build_predict_codegen(self): - """Codegen predict: partial-eval'd flat Python with named locals.""" - static_arrays = self._static_arrays - syndrome_positions = tuple(p for p, _t in self._syndrome_positions) - noise_pos_ordered = self._noise_pos_ordered - n = len(self._tensors_ref) - syndrome_tensors = list(self._syndrome_arrays) - codegen_fn = self._build_codegen_predict( - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - self._path_steps, - syndrome_tensors, - dynamic_syndromes=self._dynamic_syndromes, - ) - self._codegen_fn = codegen_fn - self._codegen_n_folded = getattr(codegen_fn, "_n_folded", 0) - self._codegen_n_runtime = getattr(codegen_fn, "_n_runtime", 0) - - if self._dynamic_syndromes: - return codegen_fn - - # Static mode bakes syndromes into the closure and returns a - # 1-arg callable; wrap to match the public 2-arg signature. - def _predict_static( - noise_probs: torch.Tensor, - syndrome_tuple: tuple[torch.Tensor, ...] = () - ) -> torch.Tensor: - return codegen_fn(noise_probs) - - return _predict_static - - def _maybe_torch_compile(self, fn, *, kind: str): - """Wrap ``fn`` with :func:`torch.compile` if requested. - - On any compile failure, warn and fall back to eager. ``kind`` - is included in the warning to disambiguate predict vs loss. - """ - if not self._use_torch_compile: - return fn - try: - kwargs = self._torch_compile_kwargs() - return torch.compile(fn, **kwargs) - except Exception as exc: # pragma: no cover - warnings.warn( - f"torch.compile {kind} failed ({exc!r}); " - "falling back to eager.", - RuntimeWarning, - stacklevel=2, - ) - return fn - - def _compile_loss(self) -> None: - """Build the ``(input, syndromes) -> scalar_loss`` callables. - - Two variants are produced: one accepting logits (sigmoid applied - inside) and one accepting probabilities directly. - """ - if self._execute_mode == "codegen": - logits_fn, probs_fn = self._build_loss_codegen() - else: - logits_fn, probs_fn = self._build_loss_wrapped() - - self._loss_from_logits_fn = logits_fn - self._loss_from_probs_fn = probs_fn - self._compiled_loss_from_logits = self._maybe_torch_compile(logits_fn, - kind="loss") - self._compiled_loss_from_probs = self._maybe_torch_compile(probs_fn, - kind="loss") - - def _build_loss_codegen(self): - """Codegen loss: fuse the CE reduction into the contraction graph.""" - static_arrays = self._static_arrays - syndrome_positions = tuple(p for p, _t in self._syndrome_positions) - noise_pos_ordered = self._noise_pos_ordered - n = len(self._tensors_ref) - syndrome_tensors = list(self._syndrome_arrays) - - codegen_logits = self._build_codegen_loss( - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - self._path_steps, - syndrome_tensors, - obs_idx_true=self.obs_idx_true, - obs_idx_false=self.obs_idx_false, - dynamic_syndromes=self._dynamic_syndromes, - from_logits=True, - ) - codegen_probs = self._build_codegen_loss( - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - self._path_steps, - syndrome_tensors, - obs_idx_true=self.obs_idx_true, - obs_idx_false=self.obs_idx_false, - dynamic_syndromes=self._dynamic_syndromes, - from_logits=False, - ) - - if self._dynamic_syndromes: - return codegen_logits, codegen_probs - - # Static codegen bakes syndromes into the closure and returns a - # 1-arg callable; wrap to match the public 2-arg signature. - def _loss_from_logits_static( - logits: torch.Tensor, syndrome_tuple: tuple[torch.Tensor, ...] = () - ) -> torch.Tensor: - return codegen_logits(logits) - - def _loss_from_probs_static( - noise_probs: torch.Tensor, - syndrome_tuple: tuple[torch.Tensor, ...] = () - ) -> torch.Tensor: - return codegen_probs(noise_probs) - - return _loss_from_logits_static, _loss_from_probs_static - - def _build_loss_wrapped(self): - """opt_einsum / unrolled loss: wrap CE around ``self._predict_fn``.""" - obs_t = self.obs_idx_true - obs_f = self.obs_idx_false - predict_fn = self._predict_fn - - if self._dynamic_syndromes: - - def _loss_from_probs(noise_probs, syndromes): - p = predict_fn(noise_probs, syndromes) - return (-torch.log(p[obs_t, 1]).sum() - - torch.log(p[obs_f, 0]).sum()) - - def _loss_from_logits(logits, syndromes): - p = predict_fn(torch.sigmoid(logits), syndromes) - return (-torch.log(p[obs_t, 1]).sum() - - torch.log(p[obs_f, 0]).sum()) - else: - - def _loss_from_probs(noise_probs, syndromes=()): - p = predict_fn(noise_probs, ()) - return (-torch.log(p[obs_t, 1]).sum() - - torch.log(p[obs_f, 0]).sum()) - - def _loss_from_logits(logits, syndromes=()): - p = predict_fn(torch.sigmoid(logits), ()) - return (-torch.log(p[obs_t, 1]).sum() - - torch.log(p[obs_f, 0]).sum()) - - return _loss_from_logits, _loss_from_probs - - def _torch_compile_kwargs(self) -> dict[str, Any]: - """Build kwargs for :func:`torch.compile`. - - Defaults to ``mode="reduce-overhead"`` on CUDA so kernel-launch - overhead is amortised via CUDA Graphs; a ``compile_mode=...`` - passed to the constructor overrides this. - """ - kwargs: dict[str, Any] = {"dynamic": False} - if self._torch_compile_mode is not None: - kwargs["mode"] = self._torch_compile_mode - elif self.torch_device.type == "cuda": - kwargs["mode"] = "reduce-overhead" - return kwargs - - @staticmethod - def _codegen_partial_eval(n, static_arrays, syndrome_positions, - noise_pos_ordered, path_steps, syndrome_tensors, - dynamic_syndromes: bool): - """Partial-evaluate ``path_steps``; return the codegen building blocks. - - Steps whose inputs are all static are evaluated eagerly under - ``torch.no_grad`` and become closure constants; the remaining - steps become source lines. - - Returns ``(runtime_lines, closure_vars, used_static, final_state, - n_folded)``: emitted source lines, name -> tensor map for the - function namespace, the subset of names actually referenced, the - single surviving state slot ``(name, is_dynamic, value_or_None)``, - and the count of folded steps. - """ - static_positions = sorted(static_arrays.keys()) - noise_pos_set = set(noise_pos_ordered) - syn_pos_set = set(syndrome_positions) - # O(1) reverse lookup tables (avoid repeated list.index() inside - # the per-step loop below — was O(N^2) for large path lengths). - noise_pos_to_k = {pos: k for k, pos in enumerate(noise_pos_ordered)} - syn_pos_to_sidx = { - pos: sidx for sidx, pos in enumerate(syndrome_positions) - } - static_pos_to_sidx = { - pos: sidx for sidx, pos in enumerate(static_positions) - } - - # state[pos] = (var_name, is_dynamic, concrete_value_or_None) - state: list[tuple[str, bool, torch.Tensor | None]] = [] - for pos in range(n): - if pos in noise_pos_set: - k = noise_pos_to_k[pos] - state.append((f"_n{k}", True, None)) - elif pos in syn_pos_set: - sidx = syn_pos_to_sidx[pos] - if dynamic_syndromes: - state.append((f"_S{sidx}", True, None)) - else: - state.append((f"_S{sidx}", False, syndrome_tensors[sidx])) - else: - sidx = static_pos_to_sidx[pos] - state.append( - (f"_C{sidx}", False, static_arrays[static_positions[sidx]])) - - closure_vars: dict[str, torch.Tensor] = {} - runtime_lines: list[str] = [] - # Names that the emitted source actually references and that must - # be available in the closure namespace. We track this - # structurally as we go instead of re-parsing the source string, - # which is both faster and immune to lexical false positives / - # negatives (e.g. if an einsum equation contained an underscore). - used_static: set[str] = set() - n_folded = 0 - - for step_idx, step in enumerate(path_steps): - eq_str = step[0] - idxs = step[1] - picked = [state[i] for i in idxs] - for i in sorted(idxs, reverse=True): - state.pop(i) - any_dynamic = any(p[1] for p in picked) - out_name = f"_r{step_idx}" - if not any_dynamic: - arrs = [p[2] for p in picked] - with torch.no_grad(): - result = torch.einsum(eq_str, *arrs).contiguous() - static_name = f"_P{step_idx}" - closure_vars[static_name] = result - state.append((static_name, False, result)) - n_folded += 1 - else: - arg_names = [p[0] for p in picked] - # Track which closure names this line will read. ``_n*`` - # names are header-built locals (from noise_probs) so - # they are *not* closure values; everything else is. - for name in arg_names: - if name.startswith(("_C", "_P")): - used_static.add(name) - elif name.startswith("_S") and not dynamic_syndromes: - used_static.add(name) - runtime_lines.append( - f" {out_name} = torch.einsum({eq_str!r}, " - f"{', '.join(arg_names)})") - state.append((out_name, True, None)) - - assert len(state) == 1 - for name in used_static: - if name in closure_vars: - continue - if name.startswith("_C"): - sidx = int(name[2:]) - closure_vars[name] = static_arrays[static_positions[sidx]] - elif name.startswith("_S"): # static-syndromes mode only - sidx = int(name[2:]) - closure_vars[name] = syndrome_tensors[sidx] - - return runtime_lines, closure_vars, used_static, state[0], n_folded - - @staticmethod - def _emit_noise_header(noise_pos_ordered, - transform: str = "identity") -> list[str]: - """Emit source lines materialising ``_n0 .. _n{K-1}``. - - ``transform="identity"`` treats the input as probabilities; - ``"sigmoid"`` treats it as logits and applies ``torch.sigmoid`` - first. A single ``(K, 2)`` stack is built and then sliced, which - keeps the autograd graph compact. - """ - lines: list[str] = [] - if transform == "sigmoid": - lines.append(" _p = torch.sigmoid(noise_probs)") - else: - lines.append(" _p = noise_probs") - lines.append(" _q = 1.0 - _p") - # One stack of shape (K, 2) instead of K stacks of shape (2,). - # ``dim=1`` makes ``_NS[k]`` a contiguous view of length 2. - lines.append(" _NS = torch.stack((_q, _p), dim=1)") - for k in range(len(noise_pos_ordered)): - lines.append(f" _n{k} = _NS[{k}]") - return lines - - @staticmethod - def _emit_syndrome_header(syndrome_positions, - dynamic_syndromes: bool) -> list[str]: - """Emit source lines binding ``_S0 .. _S{S-1}`` to runtime - ``syndromes`` arguments; empty in static mode.""" - if not dynamic_syndromes: - return [] - return [ - f" _S{sidx} = syndromes[{sidx}]" - for sidx in range(len(syndrome_positions)) - ] - - @classmethod - def _build_codegen_predict(cls, - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - path_steps, - syndrome_tensors, - dynamic_syndromes: bool = True): - """Generate ``_predict(noise_probs[, syndromes]) -> (shots, 2)``. - - With ``dynamic_syndromes=False`` syndromes are folded into the - closure, which maximises partial evaluation but forces a rebuild - on every :meth:`update_dataset` call. With ``True`` (default) - syndromes stay runtime arguments and dataset swaps are free. - """ - runtime_lines, closure_vars, _used, final_state, n_folded = ( - cls._codegen_partial_eval( - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - path_steps, - syndrome_tensors, - dynamic_syndromes, - )) - final_name, is_final_dyn, final_value = final_state - fully_static = not is_final_dyn - - body: list[str] = [] - if dynamic_syndromes: - body.append("def _predict(noise_probs, syndromes):") - else: - body.append("def _predict(noise_probs):") - - if fully_static: - # Degenerate case: contraction didn't depend on noise (or any - # other dynamic input). Pre-normalise the constant and - # return it unchanged on every call. - with torch.no_grad(): - normed = final_value / final_value.sum(dim=1, keepdim=True) - closure_vars["_FINAL"] = normed - body.append(" return _FINAL") - runtime_lines = [] - else: - body.extend( - cls._emit_noise_header(noise_pos_ordered, transform="identity")) - body.extend( - cls._emit_syndrome_header(syndrome_positions, - dynamic_syndromes)) - body.extend(runtime_lines) - body.append(f" _out = {final_name}") - body.append(" return _out / _out.sum(dim=1, keepdim=True)") - - return cls._compile_codegen_source(body, closure_vars, n_folded, - len(runtime_lines), "predict") - - @classmethod - def _build_codegen_loss(cls, - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - path_steps, - syndrome_tensors, - obs_idx_true: torch.Tensor, - obs_idx_false: torch.Tensor, - dynamic_syndromes: bool = True, - from_logits: bool = True): - """Generate a fused ``(input, syndromes) -> scalar`` loss callable. - - Pipes the contraction output straight into the cross-entropy - reduction so the whole pipeline (optional sigmoid, contraction, - normalisation, cross-entropy) is a single autograd graph. - - Args: - from_logits: If ``True`` (default), apply ``torch.sigmoid`` to - the input; if ``False``, the input must already be in - ``[0, 1]``. - """ - runtime_lines, closure_vars, _used, final_state, n_folded = ( - cls._codegen_partial_eval( - n, - static_arrays, - syndrome_positions, - noise_pos_ordered, - path_steps, - syndrome_tensors, - dynamic_syndromes, - )) - final_name, is_final_dyn, final_value = final_state - fully_static = not is_final_dyn - - closure_vars["_OBS_T"] = obs_idx_true - closure_vars["_OBS_F"] = obs_idx_false - - body: list[str] = [] - if dynamic_syndromes: - body.append("def _loss(noise_probs, syndromes):") - else: - body.append("def _loss(noise_probs):") - - if fully_static: - # The contraction is a constant; the loss it produces is - # also a constant. We still need the gradient wrt - # noise_probs to be zero (a real-valued zero tensor with a - # graph edge), so emit a no-op multiplication by - # noise_probs.sum() * 0 to keep autograd happy. - with torch.no_grad(): - normed = final_value / final_value.sum(dim=1, keepdim=True) - # Compute the loss eagerly; we can't fold it because - # autograd needs a path back to noise_probs. - ce = (-torch.log(normed[obs_idx_true, 1]).sum() - - torch.log(normed[obs_idx_false, 0]).sum()) - closure_vars["_LOSS"] = ce - body.append(" return _LOSS + 0.0 * noise_probs.sum()") - runtime_lines = [] - else: - transform = "sigmoid" if from_logits else "identity" - body.extend(cls._emit_noise_header(noise_pos_ordered, transform)) - body.extend( - cls._emit_syndrome_header(syndrome_positions, - dynamic_syndromes)) - body.extend(runtime_lines) - # Fused cross-entropy that skips the explicit - # ``_preds = _out / _out.sum(dim=1, keepdim=True)`` step. - # OBS_T and OBS_F partition the batch (every row is in exactly - # one), so: - # -log(p_T[:,1]).sum() - log(p_F[:,0]).sum() - # = log(Z).sum() - log(_out[OBS_T,1]).sum() - # - log(_out[OBS_F,0]).sum() - # where Z = _out[:,0] + _out[:,1]. Saves one (shots, 2) - # division + materialisation and the corresponding backward - # nodes -- ~2-5% per-step on CPU at d=3/r=3. - body.append(f" _out = {final_name}") - body.append(" _z0 = _out[:, 0]") - body.append(" _z1 = _out[:, 1]") - body.append(" return (torch.log(_z0 + _z1).sum() " - "- torch.log(_z1[_OBS_T]).sum() " - "- torch.log(_z0[_OBS_F]).sum())") - - return cls._compile_codegen_source(body, closure_vars, n_folded, - len(runtime_lines), "loss") - - @staticmethod - def _compile_codegen_source(body: list[str], - closure_vars: dict[str, torch.Tensor], - n_folded: int, n_runtime: int, kind: str): - """Compile the assembled function source and return the callable.""" - source = "\n".join(body) - ns: dict[str, Any] = {"torch": torch} - ns.update(closure_vars) - fn_name = "_loss" if kind == "loss" else "_predict" - exec(compile(source, f"", "exec"), ns) - fn = ns[fn_name] - fn._n_folded = n_folded # type: ignore[attr-defined] - fn._n_runtime = n_runtime # type: ignore[attr-defined] - return fn - - def decoder_prediction(self) -> torch.Tensor: - """Run the forward pass; returns ``(shots, 2)`` predictions.""" - return self._compiled_predict(self._noise_probs, self._syndrome_tuple) - - def cross_entropy_loss(self) -> torch.Tensor: - """Cross-entropy loss over the syndrome batch. - - Returns a differentiable scalar; call ``.backward()`` to obtain - gradients w.r.t. :attr:`noise_params`. The fused codegen omits - the ``log`` guard, so a prior at ``0`` or ``1`` yields ``NaN`` - gradients — see :attr:`noise_params` for safe training patterns. - """ - return self._compiled_loss_from_probs(self._noise_probs, - self._syndrome_tuple) - - def current_syndrome_args(self) -> tuple[torch.Tensor, ...]: - """Return the syndrome argument expected by :meth:`loss_fn`. - - Returns ``()`` when syndromes are baked into the closure - (``execute="codegen"`` and ``dynamic_syndromes=False``), else - the current live tuple. Re-fetch each step so an intervening - :meth:`update_dataset` is reflected. - """ - if self._execute_mode == "codegen" and not self._dynamic_syndromes: - return () - return self._syndrome_tuple - - def loss_fn(self, from_logits: bool = True): - """Return a fused ``(input, syndromes) -> scalar`` loss callable. - - Useful when training in logit space (``from_logits=True``, the - default) or when feeding in an externally managed probability - tensor (``from_logits=False``). Compared to - :meth:`cross_entropy_loss`, the parameter is supplied explicitly - per call instead of being read from :attr:`noise_params`. - """ - return (self._compiled_loss_from_logits - if from_logits else self._compiled_loss_from_probs) - - def logical_error_rate(self) -> float: - """Fraction of shots decoded incorrectly. - - Uses a hard argmax threshold; **not** differentiable. - """ - with torch.no_grad(): - predictions = self.decoder_prediction() - pred = predictions[:, 1] > predictions[:, 0] - return float(1 - (pred == self._observable_flips).sum() / - len(self._observable_flips)) - - def _update_data(self, - new_syndrome_arrays: torch.Tensor, - new_observable_flips: npt.NDArray[Any], - enforce_shape: bool = True) -> None: - """In-place dataset swap on already-prepared syndrome tensors. - - ``new_syndrome_arrays`` must be in the internal layout (the - output of :func:`prepare_syndrome_data_batch`, on the right - device, shape ``(syndrome_length, shots, 2)``). Public callers - should use :meth:`update_dataset` instead. - """ - # Patch syndrome tensor data in the quimb TN in place; the - # cotengra path is invalidated below if any shape changed. - for i, tag in enumerate(self._syndrome_tags): - t = self.syndrome_tn.tensors[next( - iter(self.syndrome_tn.tag_map[tag]))] - if enforce_shape: - assert t.data.shape == new_syndrome_arrays[i].shape, ( - f"Shape mismatch for {tag}: " - f"{t.data.shape} vs {new_syndrome_arrays[i].shape}") - t.modify(data=new_syndrome_arrays[i]) - - # Suppress the loss rebuild the ``observable_flips`` setter - # would otherwise trigger; one of the branches below issues it. - self._suspend_loss_rebuild = True - self.observable_flips = new_observable_flips - - torch_dtype = getattr(torch, self._dtype) - dev = self.torch_device - new_shapes: list[tuple[int, ...]] = [] - for k, (pos, _tag) in enumerate(self._syndrome_positions): - data = self._tensors_ref[pos].data - if isinstance(data, torch.Tensor): - arr = data.detach().to(device=dev, dtype=torch_dtype) - else: - arr = torch.as_tensor(np.asarray(data), - dtype=torch_dtype, - device=dev) - self._syndrome_arrays[k] = arr - new_shapes.append(tuple(arr.shape)) - new_shapes_tuple = tuple(new_shapes) - - # Shape change: cached path / codegen / oe expression / compile - # guards are all stale. Drop the path and rebuild from scratch. - # Shapes unchanged: dynamic codegen and the unrolled / - # opt_einsum paths read syndromes per call — refreshing the - # cached tuple is enough. Static codegen baked the old tensors - # into the closure and still needs a full rebuild. - shape_changed = new_shapes_tuple != self._syndrome_shapes - if shape_changed: - self.path_batch = None - self.slicing_batch = tuple() - try: - self._snapshot_arrays_and_eq() - finally: - self._suspend_loss_rebuild = False - return - - self._syndrome_tuple = tuple(self._syndrome_arrays) - if self._execute_mode == "codegen" and not self._dynamic_syndromes: - try: - self._snapshot_arrays_and_eq() - finally: - self._suspend_loss_rebuild = False - else: - # The observable indices may have changed; the loss bakes - # them in, so it still needs a rebuild. - self._suspend_loss_rebuild = False - self._compile_loss() - - def update_dataset(self, - new_syndrome_data: npt.NDArray[Any], - new_observable_flips: npt.NDArray[Any], - enforce_shape: bool = True) -> None: - """Replace the syndrome batch and observable flips. - - Args: - new_syndrome_data: Shape ``(shots, num_checks)``. - new_observable_flips: Shape ``(shots,)``. - enforce_shape: Assert that per-tensor shapes match. A - changing batch size triggers a full rebuild of the - cached contraction path and codegen. - """ - syndrome_arrays = prepare_syndrome_data_batch(new_syndrome_data) - torch_dtype = getattr(torch, self._dtype) - syndrome_arrays = torch.tensor( - syndrome_arrays, - dtype=torch_dtype, - device=self.torch_device, - ).transpose(1, 2) - self._batch_size = int(new_syndrome_data.shape[0]) - self._update_data(syndrome_arrays, new_observable_flips, enforce_shape) - - def optimize_path(self, optimize: Any = None, batch_size: int = -1) -> Any: - """Cache a contraction path via quimb and rebuild the JIT. - - Always routes through :meth:`TensorNetwork.contraction_info` so - the resulting path is compatible with :mod:`opt_einsum` and - manual unrolling -- unlike :meth:`TensorNetworkDecoder.optimize_path`, - which defaults to a cuTensorNet-only path. - - ``batch_size`` is part of the parent ``TensorNetworkDecoder`` - signature (which rebuilds its TN around a fake batch); on the - optimiser the syndrome TN is already batched at construction - and resized in :meth:`update_dataset`, so this argument is - ignored. Kept for Liskov substitution with the parent. - """ - del batch_size - info = self.full_tn.contraction_info( - output_inds=("batch_index", self.logical_obs_inds[0]), - optimize=optimize if optimize is not None else "auto", - ) - self.path_batch = info.path - self.slicing_batch = tuple() - self._snapshot_arrays_and_eq() - return info - - -def make_compiled_step(optimizer: NMOptimizer, logits: torch.Tensor, - torch_optimizer: torch.optim.Optimizer): - """Build a no-arg callable that runs one Adam step and returns the loss. - - The step zeros grads, calls the optimizer's compiled - ``loss_fn(from_logits=True)`` (sigmoid + contraction + cross-entropy - fused), backwards, and steps ``torch_optimizer``. Use this when - training in logit space. - - Args: - optimizer: The :class:`NMOptimizer` providing the fused - inner loss; pass ``compile=True`` at the - :class:`NMOptimizer` constructor for the - ``torch.compile``-d variant. - logits: Trainable 1-D tensor of length ``len(optimizer.error_inds)`` - with ``requires_grad=True``. - torch_optimizer: A ``torch.optim`` instance owning ``logits``. - """ - - # Re-fetch per call so a rebuild from update_dataset or the - # observable_flips setter is picked up; capturing would go stale. - def _step(): - torch_optimizer.zero_grad(set_to_none=True) - loss = optimizer.loss_fn(from_logits=True)( - logits, optimizer.current_syndrome_args()) - loss.backward() - torch_optimizer.step() - return loss - - return _step diff --git a/libs/qec/python/tests/test_tn_noise_models.py b/libs/qec/python/tests/test_nm_optimizer.py similarity index 90% rename from libs/qec/python/tests/test_tn_noise_models.py rename to libs/qec/python/tests/test_nm_optimizer.py index 5ff75a21..54af6aec 100644 --- a/libs/qec/python/tests/test_tn_noise_models.py +++ b/libs/qec/python/tests/test_nm_optimizer.py @@ -16,19 +16,16 @@ import numpy as np import pytest +import cudaq_qec as qec torch = pytest.importorskip( "torch", reason="torch not installed; skipping TN noise-learning tests") -import cudaq_qec as qec # noqa: E402 - if sys.version_info >= (3, 11): - from cudaq_qec.plugins.decoders.tensor_network_utils.noise_models import ( + from cudaq_qec.plugins.decoders.tensor_network_utils.nm_optimizer import ( NMOptimizer, - _PRIOR_EPS_BY_DTYPE, - _remap_eq_to_ascii, - _validate_and_clamp_priors, make_compiled_step, + remap_eq_to_ascii, ) pytestmark = pytest.mark.skipif(sys.version_info < (3, 11), @@ -136,17 +133,6 @@ def test_construction_basic(device): assert np.all((np_probs >= 0.0) & (np_probs <= 1.0)) -def test_public_reexport_from_tensor_network_decoder_module(): - """``NMOptimizer`` is re-exported from the TN decoder plugin module.""" - from cudaq_qec.plugins.decoders import tensor_network_decoder as tnd - from cudaq_qec.plugins.decoders.tensor_network_utils import (noise_models as - nl) - assert tnd.NMOptimizer is nl.NMOptimizer - assert tnd.make_compiled_step is nl.make_compiled_step - assert "NMOptimizer" in tnd.__all__ - assert "make_compiled_step" in tnd.__all__ - - @pytest.mark.parametrize("device", _device_params()) def test_invalid_execute_mode_rejected(device): H, logical, priors = _simple_repetition_code() @@ -165,6 +151,26 @@ def test_invalid_execute_mode_rejected(device): execute="bogus") +@pytest.mark.parametrize("device", _device_params()) +def test_invalid_dtype_rejected(device): + """Unsupported dtypes must be rejected at the constructor boundary, + before any contraction setup runs.""" + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=4, + rng=np.random.default_rng(2)) + with pytest.raises(ValueError, match="Invalid dtype"): + _make_opt(H, + logical, + priors, + syn, + flips, + device=device, + dtype="float16") + + # -- forward pass / gradient ------------------------------------------------- @@ -360,7 +366,9 @@ def test_boundary_priors_clamped_with_warning(device, dtype): logical, [0.1, 0.2, 0.3], num_shots=8, rng=np.random.default_rng(20)) - eps = _PRIOR_EPS_BY_DTYPE[dtype] + # Mirrors the dtype-eps table inside NMOptimizer. Hardcoded so the + # test pins the boundary contract independently of the implementation. + eps = 1e-12 if dtype == "float64" else 1e-6 with pytest.warns(UserWarning, match=r"Clamped \d+/\d+"): opt = _make_opt(H, logical, @@ -403,16 +411,17 @@ def test_in_range_priors_no_warning(): _make_opt(H, logical, priors, syn, flips, device="cpu") -def test_validate_and_clamp_priors_helper(): - """Unit-test the helper directly: shape, unknown-dtype rejection, idempotence.""" +def test_non_1d_noise_model_rejected(): + """A 2-D ``noise_model`` is rejected at the constructor boundary.""" + H, logical, priors = _simple_repetition_code() + syn, flips = _sample_synthetic_dataset(H, + logical, + priors, + num_shots=4, + rng=np.random.default_rng(23)) + bad = np.full((2, 3), 0.5) with pytest.raises(ValueError, match="must be 1-D"): - _validate_and_clamp_priors(np.zeros((2, 3)) + 0.5, "float64") - with pytest.raises(ValueError, match="Unsupported dtype"): - _validate_and_clamp_priors([0.1, 0.5, 0.9], "float128_unknown") - out = _validate_and_clamp_priors([0.1, 0.5, 0.9], "float64") - assert out == [0.1, 0.5, 0.9] - out = _validate_and_clamp_priors(out, "float64") - assert out == [0.1, 0.5, 0.9] + _make_opt(H, logical, bad, syn, flips, device="cpu") # -- current_syndrome_args --------------------------------------------------- @@ -570,6 +579,36 @@ def test_update_dataset_shape_change_rebuilds_and_decodes( assert torch.allclose(loss, ref_loss, atol=1e-8, rtol=1e-8) +@pytest.mark.parametrize("device", _device_params()) +def test_update_dataset_enforce_shape_mismatch_raises(device): + """``enforce_shape=True`` (default) must reject a syndrome batch whose + per-tensor shape differs from the construction-time batch. The + permissive path is already covered by + :func:`test_update_dataset_shape_change_rebuilds_and_decodes`.""" + rng = np.random.default_rng(78) + H, logical = _nondegenerate_code() + init_priors = [0.1, 0.15, 0.25] + syn1, flips1 = _sample_synthetic_dataset(H, + logical, + init_priors, + num_shots=16, + rng=rng) + opt = _make_opt(H, + logical, + init_priors, + syn1, + flips1, + device=device, + dtype="float64") + syn2, flips2 = _sample_synthetic_dataset(H, + logical, + init_priors, + num_shots=33, + rng=rng) + with pytest.raises(AssertionError, match="Shape mismatch"): + opt.update_dataset(syn2, flips2) + + # -- optimize_path ----------------------------------------------------------- @@ -630,12 +669,12 @@ def test_optimize_path_with_cotengra(device): np.testing.assert_allclose(before, after, atol=1e-10, rtol=1e-10) -# -- _remap_eq_to_ascii ------------------------------------------------------- +# -- remap_eq_to_ascii ------------------------------------------------------- def test_remap_eq_to_ascii_simple(): eq = "ab,bc->ac" - out = _remap_eq_to_ascii(eq) + out = remap_eq_to_ascii(eq) # ASCII input is returned unchanged via the ``isascii()`` fast path. assert out == "ab,bc->ac" @@ -643,7 +682,7 @@ def test_remap_eq_to_ascii_simple(): def test_remap_eq_to_ascii_unicode_labels(): """Synthetic equation with non-ASCII labels is remapped to a-zA-Z.""" eq = "\u0391\u0392,\u0392\u0393->\u0391\u0393" # greek letters - out = _remap_eq_to_ascii(eq) + out = remap_eq_to_ascii(eq) assert "\u0391" not in out and "\u0392" not in out and "\u0393" not in out assert "->" in out lhs, rhs = out.split("->") @@ -656,7 +695,7 @@ def test_remap_eq_to_ascii_too_many_labels(): chars = [chr(0x4E00 + i) for i in range(53)] # 53 distinct CJK chars eq = "".join(chars) + "->" + chars[0] with pytest.raises(ValueError, match="more than 52"): - _remap_eq_to_ascii(eq) + remap_eq_to_ascii(eq) # -- logical_error_rate ------------------------------------------------------ From 359ea621abf5ead5c44af303d072b419adcf8964 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Tue, 26 May 2026 08:38:21 -0400 Subject: [PATCH 4/5] chore: retrigger CI Signed-off-by: vedika-saravanan From 062a2897bfcd280731891f64b29204b70e0e2acb Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Tue, 26 May 2026 08:53:17 -0400 Subject: [PATCH 5/5] chore: retrigger CI Signed-off-by: vedika-saravanan