From dd5fba385b469a411cd8e7744e48033b8dfa773d Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sat, 28 Mar 2026 16:59:10 -0700 Subject: [PATCH 01/25] add ftl astar prototype in python -- can decode d11r11p002 superdense color codes in about 5 mins / shot without any beam cutoffs whatsoever using operator cost partitioning approach to improve over detcosts with naive LP implementation. --- src/py/astar_prototype.py | 611 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100644 src/py/astar_prototype.py diff --git a/src/py/astar_prototype.py b/src/py/astar_prototype.py new file mode 100644 index 0000000..9359813 --- /dev/null +++ b/src/py/astar_prototype.py @@ -0,0 +1,611 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with plain detcost or optimal singleton detcost. + +The default heuristic matches the original prototype's plain detector-wise +heuristic. Passing --opt-singleton-detcost switches to the exact optimal +singleton lower bound, solved as a small LP over the currently active +residual detectors. + +Notes: + * The search still uses the precedence-based tree pruning from the + prototype. + * By default, the heuristic ignores precedence-blocked errors in order to + preserve the original prototype's behavior. Use + --respect-blocked-errors-in-heuristic to exclude blocked errors from the + heuristic as well. + * The optimal singleton heuristic requires SciPy (``scipy.optimize.linprog``). +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + heuristic_calls: int + lp_calls: int + elapsed_seconds: float + + +class AStarPrototypeDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + *, + use_opt_singleton_detcost: bool = False, + respect_blocked_errors_in_heuristic: bool = False, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_detectors = int(num_detectors) + self.num_errors = len(self.errors) + self.use_opt_singleton_detcost = use_opt_singleton_detcost + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.verbose_search = verbose_search + + if self.use_opt_singleton_detcost and linprog is None: + raise RuntimeError( + "--opt-singleton-detcost requires scipy. Install scipy and rerun." 
+ ) + + self.ecosts = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.edets: List[np.ndarray] = [ + np.array(err.detectors, dtype=np.int32) for err in self.errors + ] + self.eobs: List[np.ndarray] = [ + np.array(err.observables, dtype=np.int32) for err in self.errors + ] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.edets): + for d in dets: + d2e_lists[int(d)].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.heuristic_calls = 0 + self.lp_calls = 0 + + @property + def heuristic_name(self) -> str: + if self.use_opt_singleton_detcost: + return "opt-singleton-detcost" + return "plain-detcost" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _plain_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + total = 0.0 + for d in np.flatnonzero(dets): + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.ecosts[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF + total += best + return total + + def _opt_singleton_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + active_dets = np.flatnonzero(dets) + if active_dets.size == 0: + return 0.0 + + det_to_var = {int(d): i for i, d in enumerate(active_dets.tolist())} + support_to_weight: Dict[Tuple[int, ...], float] = {} + covered = np.zeros(active_dets.size, dtype=bool) + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + if int(det_counts[ei]) == 0: + continue + support = tuple(det_to_var[int(d)] for d in self.edets[ei] if dets[int(d)]) + if not 
support: + continue + for var in support: + covered[var] = True + weight = float(self.ecosts[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + if not np.all(covered): + return INF + + num_vars = active_dets.size + supports = list(support_to_weight.keys()) + weights = np.array([support_to_weight[s] for s in supports], dtype=np.float64) + + row_indices: List[int] = [] + col_indices: List[int] = [] + data: List[float] = [] + for row, support in enumerate(supports): + row_indices.extend([row] * len(support)) + col_indices.extend(support) + data.extend([1.0] * len(support)) + + + a_ub = csr_matrix( + (data, (row_indices, col_indices)), + shape=(len(supports), num_vars), + dtype=np.float64, + ) + + self.lp_calls += 1 + result = linprog( + c=-np.ones(num_vars, dtype=np.float64), + A_ub=a_ub, + b_ub=weights, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + if result.status == 0: + return max(0.0, float(-result.fun)) + if result.status in {2, 3}: # infeasible or unbounded + return INF + raise RuntimeError(f"linprog failed with status={result.status}: {result.message}") + + def heuristic_cost( + self, + errs: np.ndarray, + blocked_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + if self.use_opt_singleton_detcost: + return self._opt_singleton_detcost_heuristic(available, dets, det_counts) + return self._plain_detcost_heuristic(available, dets, det_counts) + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.heuristic_calls = 0 + self.lp_calls = 0 + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in 
self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + h0 = self.heuristic_cost(errs0, blocked0, dets0, det_counts0) + if math.isinf(h0): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + heuristic_calls=self.heuristic_calls, + lp_calls=self.lp_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + next_node_id = 1 + heap: List[Tuple[float, int, int]] = [(h0, int(dets0.sum()), 0)] + node_data: Dict[int, SearchState] = { + 0: SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + ) + } + + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + nodes_popped += 1 + + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + errs = state.errs + blocked_errs = state.blocked_errs + dets = state.dets + det_counts = state.det_counts + g_cost = state.g_cost + + if self.verbose_search: + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={g_cost:.6f}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=errs, + residual_dets=dets, + cost=g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + lp_calls=self.lp_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + min_det = int(np.flatnonzero(dets)[0]) + prefix_blocked_errs = blocked_errs.copy() + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked_errs[ei] = True + + if errs[ei] or blocked_errs[ei]: + continue + + child_errs = errs.copy() + child_errs[ei] = True + child_blocked_errs = prefix_blocked_errs.copy() + 
child_dets = dets.copy() + child_det_counts = det_counts.copy() + + for d in self.edets[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + continue + + child_g = g_cost + float(self.ecosts[ei]) + child_h = self.heuristic_cost( + child_errs, + child_blocked_errs, + child_dets, + child_det_counts, + ) + if math.isinf(child_h): + continue + + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, child_id)) + nodes_pushed += 1 + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + lp_calls=self.lp_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.ecosts[errs].sum()) + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.eobs[int(ei)]: + obs = int(obs) + parity[obs] = not parity.get(obs, False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.edets[int(ei)]: + dets[int(d)] ^= True + return dets + + +def merged_errors_from_dem(dem) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + + for 
error in dem.flattened(): + if error.type != "error": + continue + + probability = float(error.args_copy()[0]) + if probability <= 0: + continue + if probability > 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5], got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in error.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected target type: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + key = (tuple(sorted(detectors)), tuple(sorted(observables))) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = probability + else: + # Two independent identical symptoms combine by XORing their parity. + p_new = p_old * (1.0 - probability) + (1.0 - p_old) * probability + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def sample_detections_and_observables(circuit, num_shots: int, seed: int) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=circuit.num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=circuit.num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) 
-> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder using the plain detector-wise heuristic or the " + "optimal singleton detector heuristic." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Shot index to decode after sampling --sample-num-shots shots (default: 0).", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot (default: 100).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the number of residual detections; use 'inf' for none (default: inf).", + ) + parser.add_argument( + "--opt-singleton-detcost", + action="store_true", + help=( + "Use the exact optimal singleton detector-cost lower bound instead of the " + "plain detector-wise lower bound. Requires scipy." + ), + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "Exclude precedence-blocked errors from the heuristic. By default the script " + "preserves the original prototype's behavior and only excludes already-activated errors." 
+ ), + ) + parser.add_argument( + "--show-detections", + action="store_true", + help="Print the selected shot's detection events before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action="store_true", + help="Print the decoded merged-error indices.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print one line per expanded node during A* search.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + try: + import stim + except ImportError as exc: # pragma: no cover - depends on runtime environment. + raise SystemExit("This script requires the 'stim' package to be installed.") from exc + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) + + dets_unpacked, obs_unpacked = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + ) + shot_dets = dets_unpacked[args.shot] + shot_obs = obs_unpacked[args.shot] + + if args.show_detections: + active_dets = np.flatnonzero(shot_dets) + print("detections:", " ".join(f"D{d}" for d in active_dets)) + + decoder = AStarPrototypeDecoder( + errors, + dem.num_detectors, + use_opt_singleton_detcost=args.opt_singleton_detcost, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + ) + result = decoder.decode(shot_dets, det_beam=args.det_beam) + + print(f"heuristic: {decoder.heuristic_name}") + print(f"shot: {args.shot} / {args.sample_num_shots}") + print(f"success: {result.success}") + print(f"nodes_pushed: 
{result.nodes_pushed}") + print(f"nodes_popped: {result.nodes_popped}") + print(f"heuristic_calls: {result.heuristic_calls}") + print(f"lp_calls: {result.lp_calls}") + print(f"elapsed_seconds: {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + decoded_err_indices = np.flatnonzero(result.errs) + if args.show_error_indices: + print("decoded_error_indices:", " ".join(map(str, decoded_err_indices.tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + reproduced_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + actual_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors: {int(result.errs.sum())}") + print(f"decoded_cost: {reproduced_cost:.12f}") + print("predicted_observables:", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables:", " ".join(f"L{o}" for o in actual_obs.tolist())) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From e0b44514c983f7d76dd2b45f8b65ecc1fdec7628 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 30 Mar 2026 09:35:08 -0700 Subject: [PATCH 02/25] first ftl c++ prototype with custom simplex solver --- src/BUILD | 28 + src/py/{ => astar}/astar_prototype.py | 0 .../astar_prototype_incremental_greedy.py | 969 ++++++++++ src/py/astar/astar_prototype_local_blast.py | 1203 ++++++++++++ src/py/astar/astar_prototype_projected.py | 882 +++++++++ ...r_prototype_singleton_greedy_heuristics.py | 751 ++++++++ ...totype_singleton_greedy_heuristics_lazy.py | 1107 +++++++++++ ...tar_prototype_singleton_restricted_lazy.py | 1675 +++++++++++++++++ .../astar/astar_prototype_subset_detcost.py | 1071 +++++++++++ .../astar_prototype_subset_detcost_lazy.py | 1314 +++++++++++++ src/py/astar/astar_singleton_lp_probe.py | 1276 +++++++++++++ 
src/tesseract.h | 3 +- src/tesseract.pybind.h | 4 + src/tesseract_ftl.cc | 1170 ++++++++++++ src/tesseract_ftl.h | 203 ++ src/tesseract_ftl_main.cc | 495 +++++ src/tesseract_main.cc | 6 + 17 files changed, 12156 insertions(+), 1 deletion(-) rename src/py/{ => astar}/astar_prototype.py (100%) create mode 100644 src/py/astar/astar_prototype_incremental_greedy.py create mode 100644 src/py/astar/astar_prototype_local_blast.py create mode 100644 src/py/astar/astar_prototype_projected.py create mode 100644 src/py/astar/astar_prototype_singleton_greedy_heuristics.py create mode 100644 src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py create mode 100644 src/py/astar/astar_prototype_singleton_restricted_lazy.py create mode 100644 src/py/astar/astar_prototype_subset_detcost.py create mode 100644 src/py/astar/astar_prototype_subset_detcost_lazy.py create mode 100644 src/py/astar/astar_singleton_lp_probe.py create mode 100644 src/tesseract_ftl.cc create mode 100644 src/tesseract_ftl.h create mode 100644 src/tesseract_ftl_main.cc diff --git a/src/BUILD b/src/BUILD index ebac6a5..95a6794 100644 --- a/src/BUILD +++ b/src/BUILD @@ -140,6 +140,20 @@ cc_library( ], ) + +cc_library( + name = "libtesseract_ftl", + srcs = ["tesseract_ftl.cc"], + hdrs = ["tesseract_ftl.h"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libtesseract", + "@boost//:dynamic_bitset", + "@highs", + ], +) + cc_binary( name = "tesseract", srcs = ["tesseract_main.cc"], @@ -153,6 +167,20 @@ cc_binary( ], ) + +cc_binary( + name = "tesseract_ftl", + srcs = ["tesseract_ftl_main.cc"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libtesseract_ftl", + "@argparse", + "@nlohmann_json//:json", + "@stim//:stim_lib", + ], +) + cc_test( name = "tesseract_tests", timeout = "eternal", diff --git a/src/py/astar_prototype.py b/src/py/astar/astar_prototype.py similarity index 100% rename from src/py/astar_prototype.py rename to src/py/astar/astar_prototype.py diff --git 
a/src/py/astar/astar_prototype_incremental_greedy.py b/src/py/astar/astar_prototype_incremental_greedy.py new file mode 100644 index 0000000..4dc6a2d --- /dev/null +++ b/src/py/astar/astar_prototype_incremental_greedy.py @@ -0,0 +1,969 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with incremental greedy singleton heuristics. + +Heuristic modes: + --heuristic plain exact plain detcost via incremental support updates + --heuristic asc-deg exact ascending-degree saturation heuristic + --heuristic plain-sweep exact plain+one-sweep saturation heuristic + --heuristic best-of-two max(asc-deg, plain-sweep) + +All four heuristics are maintained incrementally: + * the deduplicated active-support dictionary W(T) is updated from parent to + child using only errors touching flipped detectors; + * heuristic values are recomputed only on the union of touched connected + components of the active-support hypergraph; + * untouched components inherit their detector prices exactly. + +This stays inside the singleton-family lower-bound framework, but avoids any LP +solves while still being much tighter than basic detcost in practice. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple + +import numpy as np +import stim + +INF = math.inf + +SupportKey = Tuple[int, ...] + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[np.ndarray] + error_costs: np.ndarray + error_detectors: List[np.ndarray] + error_observables: List[np.ndarray] + + +@dataclass +class SupportState: + support_to_errors: Dict[SupportKey, FrozenSet[int]] + support_to_weight: Dict[SupportKey, float] + detector_to_supports: Dict[int, FrozenSet[SupportKey]] + + +@dataclass +class HeuristicCache: + support_state: SupportState + h_value: float + y_plain: Optional[np.ndarray] = None + y_asc: Optional[np.ndarray] = None + y_sweep: Optional[np.ndarray] = None + + +@dataclass +class SearchState: + activated_errors: Tuple[int, ...] + errs: np.ndarray + blocked_errors: np.ndarray + active_detectors: np.ndarray + path_cost: float + heuristic_cache: HeuristicCache + + +@dataclass +class DecodeStats: + num_pq_pushed: int + num_nodes_popped: int + max_queue_size: int + heuristic_evaluations: int + support_build_calls: int + support_build_seconds: float + support_update_calls: int + support_update_seconds: float + component_recompute_calls: int + component_recompute_seconds: float + incremental_children: int + changed_supports_total: int + touched_detectors_total: int + elapsed_seconds: float + heuristic_name: str + + +@dataclass +class DecodeResult: + activated_errors: Tuple[int, ...] 
+ path_cost: float + stats: DecodeStats + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class IncrementalGreedyHeuristic: + def __init__( + self, + data: DecoderData, + *, + mode: str, + ) -> None: + valid_modes = {"plain", "asc-deg", "plain-sweep", "best-of-two"} + if mode not in valid_modes: + raise ValueError(f"Unknown heuristic mode: {mode!r}") + self.data = data + self.mode = mode + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_evaluations = 0 + self.support_build_calls = 0 + self.support_build_seconds = 0.0 + self.support_update_calls = 0 + self.support_update_seconds = 0.0 + self.component_recompute_calls = 0 + self.component_recompute_seconds = 0.0 + self.incremental_children = 0 + self.changed_supports_total = 0 + self.touched_detectors_total = 0 + + @property + def heuristic_name(self) -> str: + return f"{self.mode}-incremental" + + def _active_support(self, active_detectors: np.ndarray, error_index: int) -> Optional[SupportKey]: + support = tuple(int(d) for d in self.data.error_detectors[error_index] if active_detectors[int(d)]) + return support if support else None + + def _build_support_state_from_scratch( + self, + errs: np.ndarray, + blocked_errors: np.ndarray, + active_detectors: np.ndarray, + ) -> SupportState: + t0 = time.perf_counter() + self.support_build_calls += 1 + + support_to_errors_mut: Dict[SupportKey, Set[int]] = {} + for error_index in range(len(self.data.errors)): + if errs[error_index] or blocked_errors[error_index]: 
+ continue + support = self._active_support(active_detectors, error_index) + if support is None: + continue + bucket = support_to_errors_mut.setdefault(support, set()) + bucket.add(error_index) + + support_to_errors: Dict[SupportKey, FrozenSet[int]] = {} + support_to_weight: Dict[SupportKey, float] = {} + detector_to_supports_mut: Dict[int, Set[SupportKey]] = defaultdict(set) + for support, bucket in support_to_errors_mut.items(): + frozen = frozenset(bucket) + support_to_errors[support] = frozen + support_to_weight[support] = float(min(self.data.error_costs[ei] for ei in frozen)) + for detector in support: + detector_to_supports_mut[detector].add(support) + detector_to_supports = { + detector: frozenset(supports) + for detector, supports in detector_to_supports_mut.items() + if supports + } + + self.support_build_seconds += time.perf_counter() - t0 + return SupportState( + support_to_errors=support_to_errors, + support_to_weight=support_to_weight, + detector_to_supports=detector_to_supports, + ) + + def _update_support_state_incremental( + self, + parent_support_state: SupportState, + parent_errs: np.ndarray, + child_errs: np.ndarray, + parent_blocked: np.ndarray, + child_blocked: np.ndarray, + parent_active_detectors: np.ndarray, + child_active_detectors: np.ndarray, + flipped_detectors: np.ndarray, + ) -> Tuple[SupportState, Set[SupportKey], Set[int]]: + t0 = time.perf_counter() + self.support_update_calls += 1 + + affected_errors: Set[int] = set() + for detector in flipped_detectors: + for error_index in self.data.detector_to_errors[int(detector)]: + affected_errors.add(int(error_index)) + + child_support_to_errors = dict(parent_support_state.support_to_errors) + child_support_to_weight = dict(parent_support_state.support_to_weight) + touched_buckets: Dict[SupportKey, Set[int]] = {} + + def get_touched_bucket(support: SupportKey) -> Set[int]: + bucket = touched_buckets.get(support) + if bucket is None: + bucket = 
set(parent_support_state.support_to_errors.get(support, frozenset())) + touched_buckets[support] = bucket + return bucket + + for error_index in affected_errors: + old_available = (not parent_errs[error_index]) and (not parent_blocked[error_index]) + new_available = (not child_errs[error_index]) and (not child_blocked[error_index]) + old_support = self._active_support(parent_active_detectors, error_index) if old_available else None + new_support = self._active_support(child_active_detectors, error_index) if new_available else None + if old_support == new_support: + continue + if old_support is not None: + get_touched_bucket(old_support).discard(error_index) + if new_support is not None: + get_touched_bucket(new_support).add(error_index) + + changed_supports: Set[SupportKey] = set() + touched_detectors: Set[int] = set() + + child_detector_to_supports = dict(parent_support_state.detector_to_supports) + touched_detector_sets: Dict[int, Set[SupportKey]] = {} + + for support, bucket in touched_buckets.items(): + old_bucket = parent_support_state.support_to_errors.get(support, frozenset()) + old_present = support in parent_support_state.support_to_weight + new_present = bool(bucket) + + if new_present: + frozen_bucket = frozenset(bucket) + child_support_to_errors[support] = frozen_bucket + new_weight = float(min(self.data.error_costs[ei] for ei in frozen_bucket)) + child_support_to_weight[support] = new_weight + if (not old_present) or frozen_bucket != old_bucket or abs(new_weight - parent_support_state.support_to_weight.get(support, 0.0)) > 1e-12: + changed_supports.add(support) + else: + child_support_to_errors.pop(support, None) + if old_present: + child_support_to_weight.pop(support, None) + changed_supports.add(support) + + if old_present != new_present: + for detector in support: + detector_bucket = touched_detector_sets.get(detector) + if detector_bucket is None: + detector_bucket = set(parent_support_state.detector_to_supports.get(detector, frozenset())) + 
touched_detector_sets[detector] = detector_bucket + if new_present: + detector_bucket.add(support) + else: + detector_bucket.discard(support) + + for support in changed_supports: + touched_detectors.update(support) + + for detector, supports in touched_detector_sets.items(): + if supports: + child_detector_to_supports[detector] = frozenset(supports) + else: + child_detector_to_supports.pop(detector, None) + + self.incremental_children += 1 + self.changed_supports_total += len(changed_supports) + self.touched_detectors_total += len(touched_detectors) + self.support_update_seconds += time.perf_counter() - t0 + + return ( + SupportState( + support_to_errors=child_support_to_errors, + support_to_weight=child_support_to_weight, + detector_to_supports=child_detector_to_supports, + ), + changed_supports, + touched_detectors, + ) + + def _component_from_seed_detectors( + self, + support_state: SupportState, + seed_detectors: Iterable[int], + active_detectors: np.ndarray, + ) -> Tuple[Set[int], Set[SupportKey]]: + seen_detectors: Set[int] = set() + seen_supports: Set[SupportKey] = set() + stack = [int(d) for d in seed_detectors if active_detectors[int(d)] and int(d) in support_state.detector_to_supports] + + while stack: + detector = stack.pop() + if detector in seen_detectors: + continue + seen_detectors.add(detector) + for support in support_state.detector_to_supports.get(detector, frozenset()): + if support in seen_supports: + continue + seen_supports.add(support) + for other_detector in support: + if active_detectors[other_detector] and other_detector not in seen_detectors: + stack.append(other_detector) + return seen_detectors, seen_supports + + def _all_components(self, support_state: SupportState) -> List[Tuple[Set[int], Set[SupportKey]]]: + components: List[Tuple[Set[int], Set[SupportKey]]] = [] + seen_detectors: Set[int] = set() + for detector in sorted(support_state.detector_to_supports): + if detector in seen_detectors: + continue + dets, supports = 
self._component_from_seed_detectors( + support_state=support_state, + seed_detectors=[detector], + active_detectors=np.ones(self.data.num_detectors, dtype=bool), + ) + seen_detectors.update(dets) + components.append((dets, supports)) + return components + + def _component_incidence( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, List[SupportKey]]: + component_supports_set = set(component_supports) + incidence: Dict[int, List[SupportKey]] = {} + for detector in component_detectors: + local_supports = [ + support + for support in support_state.detector_to_supports.get(detector, frozenset()) + if support in component_supports_set + ] + incidence[detector] = local_supports + return incidence + + def _compute_plain_component( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, float]: + incidence = self._component_incidence(component_detectors, component_supports, support_state) + y: Dict[int, float] = {} + for detector in component_detectors: + best = INF + for support in incidence[detector]: + candidate = support_state.support_to_weight[support] / len(support) + if candidate < best: + best = candidate + if math.isinf(best): + raise RuntimeError("Detector in active support component has no incident support.") + y[detector] = best + return y + + def _compute_asc_component( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, float]: + incidence = self._component_incidence(component_detectors, component_supports, support_state) + order = sorted(component_detectors, key=lambda d: (len(incidence[d]), d)) + slacks = {support: float(support_state.support_to_weight[support]) for support in component_supports} + y: Dict[int, float] = {} + for detector in order: + value = min(slacks[support] for support in incidence[detector]) + y[detector] = value + 
for support in incidence[detector]: + slacks[support] -= value + return y + + def _compute_plain_sweep_component( + self, + component_detectors: Set[int], + component_supports: Set[SupportKey], + support_state: SupportState, + ) -> Dict[int, float]: + incidence = self._component_incidence(component_detectors, component_supports, support_state) + y = self._compute_plain_component(component_detectors, component_supports, support_state) + slacks = { + support: float(support_state.support_to_weight[support]) - sum(y[detector] for detector in support) + for support in component_supports + } + order = sorted(component_detectors, key=lambda d: (-y[d], d)) + for detector in order: + delta = min(slacks[support] for support in incidence[detector]) + y[detector] += delta + for support in incidence[detector]: + slacks[support] -= delta + return y + + def _build_cache_from_support_state(self, support_state: SupportState) -> HeuristicCache: + t0 = time.perf_counter() + self.heuristic_evaluations += 1 + self.component_recompute_calls += 1 + + y_plain = np.zeros(self.data.num_detectors, dtype=np.float64) if self.mode == "plain" else None + y_asc = np.zeros(self.data.num_detectors, dtype=np.float64) if self.mode in {"asc-deg", "best-of-two"} else None + y_sweep = np.zeros(self.data.num_detectors, dtype=np.float64) if self.mode in {"plain-sweep", "best-of-two"} else None + + for component_detectors, component_supports in self._all_components(support_state): + if self.mode == "plain": + comp = self._compute_plain_component(component_detectors, component_supports, support_state) + for detector, value in comp.items(): + y_plain[detector] = value + elif self.mode == "asc-deg": + comp = self._compute_asc_component(component_detectors, component_supports, support_state) + for detector, value in comp.items(): + y_asc[detector] = value + elif self.mode == "plain-sweep": + comp = self._compute_plain_sweep_component(component_detectors, component_supports, support_state) + for detector, value 
in comp.items(): + y_sweep[detector] = value + elif self.mode == "best-of-two": + comp_asc = self._compute_asc_component(component_detectors, component_supports, support_state) + comp_sweep = self._compute_plain_sweep_component(component_detectors, component_supports, support_state) + for detector, value in comp_asc.items(): + y_asc[detector] = value + for detector, value in comp_sweep.items(): + y_sweep[detector] = value + else: + raise AssertionError("unreachable") + + if self.mode == "plain": + h_value = float(y_plain.sum()) + elif self.mode == "asc-deg": + h_value = float(y_asc.sum()) + elif self.mode == "plain-sweep": + h_value = float(y_sweep.sum()) + else: + h_value = float(max(y_asc.sum(), y_sweep.sum())) + + self.component_recompute_seconds += time.perf_counter() - t0 + return HeuristicCache( + support_state=support_state, + h_value=h_value, + y_plain=y_plain, + y_asc=y_asc, + y_sweep=y_sweep, + ) + + def _incremental_child_cache( + self, + parent_cache: HeuristicCache, + child_support_state: SupportState, + touched_detectors: Set[int], + child_active_detectors: np.ndarray, + flipped_detectors: np.ndarray, + ) -> HeuristicCache: + t0 = time.perf_counter() + self.heuristic_evaluations += 1 + self.component_recompute_calls += 1 + + touched_component_detectors, touched_component_supports = self._component_from_seed_detectors( + support_state=child_support_state, + seed_detectors=touched_detectors, + active_detectors=child_active_detectors, + ) + + y_plain = None if parent_cache.y_plain is None else parent_cache.y_plain.copy() + y_asc = None if parent_cache.y_asc is None else parent_cache.y_asc.copy() + y_sweep = None if parent_cache.y_sweep is None else parent_cache.y_sweep.copy() + + for detector in flipped_detectors: + detector = int(detector) + if not child_active_detectors[detector]: + if y_plain is not None: + y_plain[detector] = 0.0 + if y_asc is not None: + y_asc[detector] = 0.0 + if y_sweep is not None: + y_sweep[detector] = 0.0 + + for detector in 
touched_component_detectors: + if y_plain is not None: + y_plain[detector] = 0.0 + if y_asc is not None: + y_asc[detector] = 0.0 + if y_sweep is not None: + y_sweep[detector] = 0.0 + + if touched_component_detectors: + if self.mode == "plain": + comp = self._compute_plain_component(touched_component_detectors, touched_component_supports, child_support_state) + for detector, value in comp.items(): + y_plain[detector] = value + elif self.mode == "asc-deg": + comp = self._compute_asc_component(touched_component_detectors, touched_component_supports, child_support_state) + for detector, value in comp.items(): + y_asc[detector] = value + elif self.mode == "plain-sweep": + comp = self._compute_plain_sweep_component(touched_component_detectors, touched_component_supports, child_support_state) + for detector, value in comp.items(): + y_sweep[detector] = value + elif self.mode == "best-of-two": + comp_asc = self._compute_asc_component(touched_component_detectors, touched_component_supports, child_support_state) + comp_sweep = self._compute_plain_sweep_component(touched_component_detectors, touched_component_supports, child_support_state) + for detector, value in comp_asc.items(): + y_asc[detector] = value + for detector, value in comp_sweep.items(): + y_sweep[detector] = value + else: + raise AssertionError("unreachable") + + if self.mode == "plain": + h_value = float(y_plain.sum()) + elif self.mode == "asc-deg": + h_value = float(y_asc.sum()) + elif self.mode == "plain-sweep": + h_value = float(y_sweep.sum()) + else: + h_value = float(max(y_asc.sum(), y_sweep.sum())) + + self.component_recompute_seconds += time.perf_counter() - t0 + return HeuristicCache( + support_state=child_support_state, + h_value=h_value, + y_plain=y_plain, + y_asc=y_asc, + y_sweep=y_sweep, + ) + + def build_root_cache( + self, + errs: np.ndarray, + blocked_errors: np.ndarray, + active_detectors: np.ndarray, + ) -> HeuristicCache: + support_state = self._build_support_state_from_scratch(errs, 
blocked_errors, active_detectors) + return self._build_cache_from_support_state(support_state) + + def build_child_cache( + self, + parent_state: SearchState, + child_errs: np.ndarray, + child_blocked_errors: np.ndarray, + child_active_detectors: np.ndarray, + flipped_detectors: np.ndarray, + ) -> HeuristicCache: + child_support_state, _changed_supports, touched_detectors = self._update_support_state_incremental( + parent_support_state=parent_state.heuristic_cache.support_state, + parent_errs=parent_state.errs, + child_errs=child_errs, + parent_blocked=parent_state.blocked_errors, + child_blocked=child_blocked_errors, + parent_active_detectors=parent_state.active_detectors, + child_active_detectors=child_active_detectors, + flipped_detectors=flipped_detectors, + ) + return self._incremental_child_cache( + parent_cache=parent_state.heuristic_cache, + child_support_state=child_support_state, + touched_detectors=touched_detectors, + child_active_detectors=child_active_detectors, + flipped_detectors=flipped_detectors, + ) + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1.0 - p1) + (1.0 - p0) * p1 + + +def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("This prototype assumes DEM probabilities in (0, 0.5).") + detectors: Set[int] = set() + observables: Set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + yield MergedError( + 
probability=probability, + likelihood_cost=float(-math.log(probability / (1.0 - probability))), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors(dem): + key = (error.detectors, error.observables) + previous = errors_by_symptom.get(key) + if previous is None: + errors_by_symptom[key] = error.probability + else: + errors_by_symptom[key] = xor_probability(previous, error.probability) + + merged: List[MergedError] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("Merged error has probability >= 0.5.") + merged.append( + MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1.0 - probability))), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def build_decoder_data(dem: stim.DetectorErrorModel, *, merge_errors_in_dem: bool = True) -> DecoderData: + errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem)) + detector_to_errors_lists: List[List[int]] = [[] for _ in range(dem.num_detectors)] + for error_index, error in enumerate(errors): + for detector in error.detectors: + detector_to_errors_lists[detector].append(error_index) + return DecoderData( + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + errors=errors, + detector_to_errors=[np.asarray(v, dtype=np.int32) for v in detector_to_errors_lists], + error_costs=np.asarray([err.likelihood_cost for err in errors], dtype=np.float64), + error_detectors=[np.asarray(err.detectors, dtype=np.int32) for err in errors], + error_observables=[np.asarray(err.observables, dtype=np.int32) for err in errors], + ) + + +def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray: + return 
np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False) + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in activated_errors: + for detector in data.error_detectors[error_index]: + detectors[int(detector)] ^= True + return detectors + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in data.error_observables[error_index]: + observables[int(observable)] ^= True + return observables + + +def parse_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "+inf", "infinity", "+infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'") + return float(value) + + +def format_indices(indices: Iterable[int], prefix: str) -> str: + items = list(indices) + if not items: + return "(none)" + return " ".join(f"{prefix}{i}" for i in items) + + +def decode( + data: DecoderData, + detections: np.ndarray, + *, + det_beam: float, + heuristic: IncrementalGreedyHeuristic, + verbose_search: bool = False, +) -> DecodeResult: + start_time = time.perf_counter() + heuristic.reset_stats() + + root_dets = np.asarray(detections, dtype=bool).copy() + root_errs = np.zeros(len(data.errors), dtype=bool) + root_blocked = np.zeros(len(data.errors), dtype=bool) + root_cache = heuristic.build_root_cache(root_errs, root_blocked, root_dets) + root_state = SearchState( + activated_errors=(), + errs=root_errs, + blocked_errors=root_blocked, + active_detectors=root_dets, + path_cost=0.0, + heuristic_cache=root_cache, + ) + + heap: List[Tuple[float, int, int]] = [(root_state.path_cost + root_state.heuristic_cache.h_value, int(root_dets.sum()), 0)] + node_data: Dict[int, SearchState] = 
{0: root_state} + next_node_id = 1 + + num_pq_pushed = 1 + num_nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(root_dets.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + num_nodes_popped += 1 + + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + if verbose_search: + print( + f"len(heap)={len(heap)} nodes_pushed={num_pq_pushed} nodes_popped={num_nodes_popped} " + f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} " + f"f={f_cost:.12g} g={state.path_cost:.12g} h={state.heuristic_cache.h_value:.12g}" + ) + + if num_dets == 0: + elapsed = time.perf_counter() - start_time + return DecodeResult( + activated_errors=state.activated_errors, + path_cost=state.path_cost, + stats=DecodeStats( + num_pq_pushed=num_pq_pushed, + num_nodes_popped=num_nodes_popped, + max_queue_size=max_queue_size, + heuristic_evaluations=heuristic.heuristic_evaluations, + support_build_calls=heuristic.support_build_calls, + support_build_seconds=heuristic.support_build_seconds, + support_update_calls=heuristic.support_update_calls, + support_update_seconds=heuristic.support_update_seconds, + component_recompute_calls=heuristic.component_recompute_calls, + component_recompute_seconds=heuristic.component_recompute_seconds, + incremental_children=heuristic.incremental_children, + changed_supports_total=heuristic.changed_supports_total, + touched_detectors_total=heuristic.touched_detectors_total, + elapsed_seconds=elapsed, + heuristic_name=heuristic.heuristic_name, + ), + ) + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + + children_generated = 0 + 
children_beam_pruned = 0 + for error_index in data.detector_to_errors[min_detector]: + error_index = int(error_index) + blocked_prefix[error_index] = True + if state.errs[error_index] or state.blocked_errors[error_index]: + continue + + child_errs = state.errs.copy() + child_errs[error_index] = True + child_blocked = blocked_prefix.copy() + child_active_detectors = state.active_detectors.copy() + flipped_detectors = data.error_detectors[error_index] + for detector in flipped_detectors: + child_active_detectors[int(detector)] = ~child_active_detectors[int(detector)] + + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_cache = heuristic.build_child_cache( + parent_state=state, + child_errs=child_errs, + child_blocked_errors=child_blocked, + child_active_detectors=child_active_detectors, + flipped_detectors=flipped_detectors, + ) + child_state = SearchState( + activated_errors=state.activated_errors + (error_index,), + errs=child_errs, + blocked_errors=child_blocked, + active_detectors=child_active_detectors, + path_cost=state.path_cost + float(data.error_costs[error_index]), + heuristic_cache=child_cache, + ) + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = child_state + heapq.heappush(heap, (child_state.path_cost + child_cache.h_value, child_num_dets, child_id)) + num_pq_pushed += 1 + children_generated += 1 + + if verbose_search: + print( + f" expanded node={node_id} children_generated={children_generated} " + f"beam_pruned={children_beam_pruned} support_updates={heuristic.support_update_calls}" + ) + + raise RuntimeError("Decoding failed to find any completion.") + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Stim-compatible A* prototype with incrementally maintained greedy singleton-family lower bounds." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.") + parser.add_argument("--shot", type=int, default=0, help="Zero-based sampled shot index to decode.") + parser.add_argument("--sample-num-shots", type=int, default=100, help="Number of shots to sample before selecting --shot.") + parser.add_argument("--seed", type=int, default=27123839530, help="Seed passed to stim.compile_detector_sampler(...).sample(...).") + parser.add_argument("--det-beam", type=parse_beam, default=INF, help="Beam cutoff on the residual detector count. Use an integer or 'inf'.") + parser.add_argument( + "--heuristic", + choices=["plain", "asc-deg", "plain-sweep", "best-of-two"], + default="plain-sweep", + help="Incremental singleton-family heuristic to use.", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the sampled shot's active detector IDs before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the activated error indices in the final decoding.", + ) + parser.add_argument("--verbose-search", action="store_true", help="Print per-node search diagnostics.") + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + + dets_packed, obs_packed = 
circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.sample_num_shots, + separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + + if args.shot >= detections.shape[0]: + parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.") + + shot_detections = detections[args.shot] + shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool) + + heuristic = IncrementalGreedyHeuristic(data, mode=args.heuristic) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {heuristic.heuristic_name}") + print(f"shot = {args.shot}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"num_detectors = {data.num_detectors}") + print(f"num_observables = {data.num_observables}") + print(f"num_errors = {len(data.errors)}") + print(f"beam = {args.det_beam}") + if args.show_shot_detectors: + print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}") + + result = decode( + data=data, + detections=shot_detections, + det_beam=args.det_beam, + heuristic=heuristic, + verbose_search=args.verbose_search, + ) + + predicted_observables = observables_from_solution(data, result.activated_errors) + reproduced_detectors = detectors_from_solution(data, result.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError("Decoded error set does not reproduce the shot's syndrome.") + + print(f"solution_size = {len(result.activated_errors)}") + print(f"solution_cost = {result.path_cost:.12g}") + if args.show_error_indices: + print(f"activated_errors = {format_indices(result.activated_errors, 'E')}") + print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}") + print(f"sample_observables = {format_indices(np.flatnonzero(shot_observables), 'L')}") + 
print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}") + print(f"num_pq_pushed = {result.stats.num_pq_pushed}") + print(f"num_nodes_popped = {result.stats.num_nodes_popped}") + print(f"max_queue_size = {result.stats.max_queue_size}") + print(f"heuristic_evaluations = {result.stats.heuristic_evaluations}") + print(f"support_build_calls = {result.stats.support_build_calls}") + print(f"support_build_seconds = {result.stats.support_build_seconds:.6f}") + print(f"support_update_calls = {result.stats.support_update_calls}") + print(f"support_update_seconds = {result.stats.support_update_seconds:.6f}") + print(f"component_recompute_calls = {result.stats.component_recompute_calls}") + print(f"component_recompute_seconds = {result.stats.component_recompute_seconds:.6f}") + print(f"incremental_children = {result.stats.incremental_children}") + mean_changed_supports = ( + result.stats.changed_supports_total / result.stats.incremental_children + if result.stats.incremental_children else 0.0 + ) + mean_touched_detectors = ( + result.stats.touched_detectors_total / result.stats.incremental_children + if result.stats.incremental_children else 0.0 + ) + print(f"mean_changed_supports = {mean_changed_supports:.6f}") + print(f"mean_touched_detectors = {mean_touched_detectors:.6f}") + print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_local_blast.py b/src/py/astar/astar_prototype_local_blast.py new file mode 100644 index 0000000..3524e83 --- /dev/null +++ b/src/py/astar/astar_prototype_local_blast.py @@ -0,0 +1,1203 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with projected singleton-LP refinement. + +The default heuristic matches the original prototype's plain detector-wise +heuristic. 
Passing --opt-singleton-detcost enables a lazy version of the exact +optimal singleton detector lower bound: + + * a node is first inserted with a cheap lower bound; + * when the node is popped, an LP-based refinement is optionally run; + * if the refined LP value raises the node's key, the node is reinserted; + * expanded nodes project their refined LP solution onto each child. + +By default the refinement is a full singleton LP solve on pop. Two experimental +modes are also available: + + * --local-lp-component: only reoptimize the active support component(s) + touched by the flipped error used to create the node. + * --local-lp-radius R: only reoptimize detector prices within an R-hop + neighborhood of the changed region, freezing all other prices at their + projected values. + +Both restricted modes remain admissible because they optimize over a subset of +variables while keeping the rest fixed at a feasible point. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 +FEAS_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class SupportSystem: + active_dets: np.ndarray + support_to_weight: Dict[Tuple[int, ...], float] + covered_all: bool + + +@dataclass +class OptSingletonLPResult: + value: float + y_full: np.ndarray + num_active_dets: int + num_supports: int + num_free_vars: int + seed_frontier_size: int + mode: str + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + ready_to_expand: bool + lp_y: Optional[np.ndarray] = None + seed_detectors: Optional[np.ndarray] = None + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + refinement_calls: int + lp_calls: int + full_lp_calls: int + component_lp_calls: int + radius_lp_calls: int + lp_reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_lp_refinement_gain: float + max_lp_refinement_gain: float + elapsed_seconds: float + + +class AStarPrototypeDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + *, + use_opt_singleton_detcost: bool = False, + respect_blocked_errors_in_heuristic: bool = False, + verbose_search: bool = False, + local_lp_radius: Optional[int] = None, + local_lp_component: bool = False, + ) -> None: + self.errors = list(errors) + self.num_detectors = int(num_detectors) + self.num_errors = len(self.errors) + self.use_opt_singleton_detcost = use_opt_singleton_detcost + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.verbose_search = verbose_search + self.local_lp_radius = local_lp_radius + self.local_lp_component = local_lp_component + + if self.local_lp_radius is not None and self.local_lp_radius < 0: + raise ValueError("local_lp_radius must be 
non-negative or None") + if self.local_lp_radius is not None and self.local_lp_component: + raise ValueError("Choose at most one of local_lp_radius and local_lp_component") + + self.ecosts = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.edets: List[np.ndarray] = [ + np.array(err.detectors, dtype=np.int32) for err in self.errors + ] + self.eobs: List[np.ndarray] = [ + np.array(err.observables, dtype=np.int32) for err in self.errors + ] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.edets): + for d in dets: + d2e_lists[int(d)].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.refinement_calls = 0 + self.lp_calls = 0 + self.full_lp_calls = 0 + self.component_lp_calls = 0 + self.radius_lp_calls = 0 + self.lp_reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_lp_refinement_gain = 0.0 + self.max_lp_refinement_gain = 0.0 + + @property + def heuristic_name(self) -> str: + if not self.use_opt_singleton_detcost: + return "plain-detcost" + if self.local_lp_component: + return "opt-singleton-detcost-lazy-projection-component" + if self.local_lp_radius is not None: + return f"opt-singleton-detcost-lazy-projection-radius{self.local_lp_radius}" + return "opt-singleton-detcost-lazy-projection-full" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _plain_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + + total = 0.0 + for d in np.flatnonzero(dets): + 
best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.ecosts[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF + total += best + return total + + def _build_support_system( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> SupportSystem: + active_dets = np.flatnonzero(dets) + if active_dets.size == 0: + return SupportSystem( + active_dets=active_dets, + support_to_weight={}, + covered_all=True, + ) + + covered = np.zeros(self.num_detectors, dtype=bool) + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + if int(det_counts[ei]) == 0: + continue + support = tuple(int(d) for d in self.edets[ei] if dets[int(d)]) + if not support: + continue + for d in support: + covered[d] = True + weight = float(self.ecosts[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + covered_all = bool(np.all(covered[active_dets])) + return SupportSystem( + active_dets=active_dets, + support_to_weight=support_to_weight, + covered_all=covered_all, + ) + + def _build_active_neighbors( + self, + supports: Iterable[Tuple[int, ...]], + active_dets: np.ndarray, + ) -> Dict[int, Set[int]]: + neighbors: Dict[int, Set[int]] = {int(d): set() for d in active_dets.tolist()} + for support in supports: + if len(support) <= 1: + continue + support_list = list(support) + for i, d in enumerate(support_list): + nbrs = neighbors[int(d)] + for od in support_list[:i]: + nbrs.add(int(od)) + for od in support_list[i + 1 :]: + nbrs.add(int(od)) + return neighbors + + def _seed_frontier_from_flipped_detectors( + self, + flipped_detectors: np.ndarray, + dets: np.ndarray, + ) -> np.ndarray: + if flipped_detectors is None or len(flipped_detectors) == 0: + return np.zeros(0, dtype=np.int32) + + frontier: Set[int] = 
set() + seen_errors: Set[int] = set() + for fd in flipped_detectors: + for ei in self.d2e[int(fd)]: + ei = int(ei) + if ei in seen_errors: + continue + seen_errors.add(ei) + for d in self.edets[ei]: + d = int(d) + if dets[d]: + frontier.add(d) + if not frontier: + return np.zeros(0, dtype=np.int32) + return np.array(sorted(frontier), dtype=np.int32) + + def _free_dets_from_scope( + self, + support_system: SupportSystem, + dets: np.ndarray, + flipped_detectors: Optional[np.ndarray], + ) -> Tuple[np.ndarray, int]: + active_dets = support_system.active_dets + if active_dets.size == 0: + return np.zeros(0, dtype=np.int32), 0 + if flipped_detectors is None or len(flipped_detectors) == 0: + return np.array(active_dets, copy=True), int(active_dets.size) + + seed_frontier = self._seed_frontier_from_flipped_detectors(flipped_detectors, dets) + seed_frontier_size = int(seed_frontier.size) + if seed_frontier_size == 0: + return np.zeros(0, dtype=np.int32), 0 + + neighbors = self._build_active_neighbors(support_system.support_to_weight.keys(), active_dets) + if self.local_lp_component: + radius_limit: Optional[int] = None + elif self.local_lp_radius is not None: + radius_limit = int(self.local_lp_radius) + else: + return np.array(active_dets, copy=True), seed_frontier_size + + visited: Set[int] = set(int(d) for d in seed_frontier.tolist()) + frontier: List[int] = [int(d) for d in seed_frontier.tolist()] + depth = 0 + + while frontier and (radius_limit is None or depth < radius_limit): + next_frontier: List[int] = [] + for d in frontier: + for od in neighbors.get(d, ()): # detector adjacency in active support graph + if od in visited: + continue + visited.add(od) + next_frontier.append(od) + frontier = next_frontier + depth += 1 + + if not visited: + return np.zeros(0, dtype=np.int32), seed_frontier_size + return np.array(sorted(visited), dtype=np.int32), seed_frontier_size + + def _solve_lp_with_fixed_outside( + self, + support_system: SupportSystem, + base_y: np.ndarray, + 
free_dets: np.ndarray, + *, + mode: str, + seed_frontier_size: int, + ) -> OptSingletonLPResult: + self.heuristic_calls += 1 + self.refinement_calls += 1 + self.lp_calls += 1 + if mode == "full": + self.full_lp_calls += 1 + elif mode == "component": + self.component_lp_calls += 1 + elif mode.startswith("radius"): + self.radius_lp_calls += 1 + else: + raise ValueError(f"Unknown LP mode: {mode}") + + active_dets = support_system.active_dets + num_active_dets = int(active_dets.size) + num_supports = int(len(support_system.support_to_weight)) + + if num_active_dets == 0: + return OptSingletonLPResult( + value=0.0, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=0, + num_supports=0, + num_free_vars=0, + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + if not support_system.covered_all: + return OptSingletonLPResult( + value=INF, + y_full=np.array(base_y, copy=True), + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + y_full = np.array(base_y, copy=True) + free_dets = np.array(sorted(set(int(d) for d in free_dets.tolist() if bool(base_y.shape[0] > d))), dtype=np.int32) + if free_dets.size == 0: + return OptSingletonLPResult( + value=float(y_full[active_dets].sum()), + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=0, + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + free_set = set(int(d) for d in free_dets.tolist()) + det_to_var = {int(d): i for i, d in enumerate(free_dets.tolist())} + + row_indices: List[int] = [] + col_indices: List[int] = [] + data: List[float] = [] + rhs: List[float] = [] + row = 0 + + for support, weight in support_system.support_to_weight.items(): + fixed_sum = 0.0 + free_support_vars: List[int] = [] + for d in support: + d = int(d) + if d in free_set: + free_support_vars.append(det_to_var[d]) + else: + fixed_sum += float(y_full[d]) + 
remaining = float(weight) - fixed_sum + if remaining < -FEAS_EPS: + raise AssertionError( + f"Base y is infeasible in restricted LP: remaining={remaining} mode={mode}" + ) + remaining = max(0.0, remaining) + if not free_support_vars: + continue + rhs.append(remaining) + row_indices.extend([row] * len(free_support_vars)) + col_indices.extend(free_support_vars) + data.extend([1.0] * len(free_support_vars)) + row += 1 + + if row == 0: + return OptSingletonLPResult( + value=float(y_full[active_dets].sum()), + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + + a_ub = csr_matrix( + (data, (row_indices, col_indices)), + shape=(row, int(free_dets.size)), + dtype=np.float64, + ) + + result = linprog( + c=-np.ones(int(free_dets.size), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, dtype=np.float64), + bounds=[(0.0, None)] * int(free_dets.size), + method="highs", + ) + if result.status == 0: + y_full[free_dets] = np.asarray(result.x, dtype=np.float64) + return OptSingletonLPResult( + value=float(y_full[active_dets].sum()), + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + if result.status in {2, 3}: # infeasible or unbounded + return OptSingletonLPResult( + value=INF, + y_full=y_full, + num_active_dets=num_active_dets, + num_supports=num_supports, + num_free_vars=int(free_dets.size), + seed_frontier_size=seed_frontier_size, + mode=mode, + ) + raise RuntimeError(f"linprog failed with status={result.status}: {result.message}") + + def _plain_heuristic_from_state(self, state: SearchState) -> float: + available = self._available_errors(state.errs, state.blocked_errs) + return self._plain_detcost_heuristic(available, state.dets, state.det_counts) + + def _project_child_solution_and_heuristic( + self, + parent_state: 
SearchState, + flipped_detectors: np.ndarray, + ) -> Tuple[np.ndarray, float]: + if parent_state.lp_y is None: + raise AssertionError("Expected parent LP solution before projecting to children.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + + child_y = np.array(parent_state.lp_y, copy=True) + for d in flipped_detectors: + d = int(d) + if parent_state.dets[d]: + child_y[d] = 0.0 + value = float(parent_state.h_cost) + for d in flipped_detectors: + d = int(d) + if parent_state.dets[d]: + value -= float(parent_state.lp_y[d]) + if value < -HEURISTIC_EPS: + raise AssertionError(f"Projected heuristic became negative: {value}") + return child_y, max(0.0, value) + + def _refine_scope_name(self) -> str: + if self.local_lp_component: + return "component" + if self.local_lp_radius is not None: + return f"radius{self.local_lp_radius}" + return "full" + + def _refine_node_lp( + self, + state: SearchState, + ) -> OptSingletonLPResult: + available = self._available_errors(state.errs, state.blocked_errs) + support_system = self._build_support_system(available, state.dets, state.det_counts) + + # Root or any node without a projected parent solution falls back to a full LP. 
+ if state.lp_y is None or state.seed_detectors is None: + base_y = np.zeros(self.num_detectors, dtype=np.float64) + return self._solve_lp_with_fixed_outside( + support_system, + base_y, + support_system.active_dets, + mode="full", + seed_frontier_size=0, + ) + + if self.local_lp_component or self.local_lp_radius is not None: + free_dets, seed_frontier_size = self._free_dets_from_scope( + support_system, + state.dets, + state.seed_detectors, + ) + mode = self._refine_scope_name() + return self._solve_lp_with_fixed_outside( + support_system, + state.lp_y, + free_dets, + mode=mode, + seed_frontier_size=seed_frontier_size, + ) + + return self._solve_lp_with_fixed_outside( + support_system, + np.zeros(self.num_detectors, dtype=np.float64), + support_system.active_dets, + mode="full", + seed_frontier_size=0, + ) + + def _maybe_refine_node_with_lp( + self, + node_id: int, + state: SearchState, + num_dets: int, + ) -> Tuple[SearchState, Optional[Tuple[float, int]], Optional[Dict[str, float | str]]]: + if not self.use_opt_singleton_detcost or state.ready_to_expand: + return state, None, None + + prev_h = state.h_cost + prev_source = state.h_source + lp_result = self._refine_node_lp(state) + refined_h = lp_result.value + + if math.isinf(refined_h): + refine_info = { + "approx_h": prev_h, + "exact_h": refined_h, + "delta": INF, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "num_free_vars": float(lp_result.num_free_vars), + "seed_frontier_size": float(lp_result.seed_frontier_size), + "reinserted": 0.0, + "discarded": 1.0, + "mode": lp_result.mode, + } + if prev_source == "projected": + self.projected_nodes_refined += 1 + return state, None, refine_info + + if refined_h + 1e-7 < prev_h: + raise AssertionError( + f"Refined LP lower bound {refined_h} is below stored {prev_source} lower bound {prev_h}." 
+ ) + + delta = refined_h - prev_h + if prev_source == "projected": + self.projected_nodes_refined += 1 + self.total_lp_refinement_gain += delta + self.max_lp_refinement_gain = max(self.max_lp_refinement_gain, delta) + + state.h_cost = refined_h + state.h_source = lp_result.mode + state.ready_to_expand = True + state.lp_y = lp_result.y_full + + should_reinsert = delta > HEURISTIC_EPS + reinsert_entry = (state.g_cost + refined_h, num_dets) if should_reinsert else None + if should_reinsert: + self.lp_reinserts += 1 + + refine_info = { + "approx_h": prev_h, + "exact_h": refined_h, + "delta": delta, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "num_free_vars": float(lp_result.num_free_vars), + "seed_frontier_size": float(lp_result.seed_frontier_size), + "reinserted": 1.0 if should_reinsert else 0.0, + "discarded": 0.0, + "mode": lp_result.mode, + } + return state, reinsert_entry, refine_info + + def _log_pop( + self, + *, + heap_len: int, + nodes_pushed: int, + nodes_popped: int, + num_dets: int, + max_num_dets: float, + f_cost: float, + state: SearchState, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={heap_len} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} full_lp_calls={self.full_lp_calls} " + f"component_lp_calls={self.component_lp_calls} radius_lp_calls={self.radius_lp_calls} " + f"lp_reinserts={self.lp_reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} ready_to_expand={state.ready_to_expand}" + ) + + def _log_refine(self, node_id: int, info: Dict[str, float | str]) -> None: + if not self.verbose_search: + return + 
exact_h = float(info["exact_h"]) + exact_text = "INF" if math.isinf(exact_h) else f"{exact_h:.6f}" + delta = float(info["delta"]) + delta_text = "INF" if math.isinf(delta) else f"{delta:.6f}" + print( + f" lp_refine node={node_id} mode={info['mode']} approx_h={float(info['approx_h']):.6f} " + f"refined_h={exact_text} delta={delta_text} vars={int(float(info['num_vars']))} " + f"supports={int(float(info['num_supports']))} free_vars={int(float(info['num_free_vars']))} " + f"seed_frontier={int(float(info['seed_frontier_size']))} " + f"reinserted={bool(info['reinserted'])} discarded={bool(info['discarded'])}" + ) + + def _log_expand( + self, + *, + node_id: int, + children_generated: int, + children_projected: int, + children_beam_pruned: int, + children_infeasible: int, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded node={node_id} children_generated={children_generated} " + f"children_projected={children_projected} beam_pruned={children_beam_pruned} " + f"infeasible={children_infeasible} lp_calls={self.lp_calls} full_lp_calls={self.full_lp_calls} " + f"component_lp_calls={self.component_lp_calls} radius_lp_calls={self.radius_lp_calls} " + f"proj_unrefined_so_far={projected_unrefined}" + ) + + def _result( + self, + *, + success: bool, + errs: np.ndarray, + residual_dets: np.ndarray, + cost: float, + nodes_pushed: int, + nodes_popped: int, + start_time: float, + ) -> DecodeResult: + return DecodeResult( + success=success, + errs=errs, + residual_dets=residual_dets, + cost=cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + full_lp_calls=self.full_lp_calls, + component_lp_calls=self.component_lp_calls, + 
radius_lp_calls=self.radius_lp_calls, + lp_reinserts=self.lp_reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=( + self.projected_nodes_generated - self.projected_nodes_refined + ), + total_lp_refinement_gain=self.total_lp_refinement_gain, + max_lp_refinement_gain=self.max_lp_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=0.0, + h_source="plain", + ready_to_expand=not self.use_opt_singleton_detcost, + lp_y=None, + seed_detectors=None, + ) + root_state.h_cost = self._plain_heuristic_from_state(root_state) + if math.isinf(root_state.h_cost): + return self._result( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + start_time=start_time, + ) + + next_node_id = 1 + heap: List[Tuple[float, int, int]] = [ + (root_state.g_cost + root_state.h_cost, int(dets0.sum()), 0) + ] + node_data: Dict[int, SearchState] = {0: root_state} + + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + nodes_popped += 1 + + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = 
min_num_dets + det_beam + + self._log_pop( + heap_len=len(heap), + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + num_dets=num_dets, + max_num_dets=max_num_dets, + f_cost=f_cost, + state=state, + ) + + if num_dets == 0: + return self._result( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + state, reinsert_entry, refine_info = self._maybe_refine_node_with_lp( + node_id=node_id, + state=state, + num_dets=num_dets, + ) + if refine_info is not None: + self._log_refine(node_id, refine_info) + if bool(refine_info["discarded"]): + continue + if reinsert_entry is not None: + node_data[node_id] = state + heapq.heappush(heap, (reinsert_entry[0], reinsert_entry[1], node_id)) + continue + + if self.use_opt_singleton_detcost and not state.ready_to_expand: + raise AssertionError("Opt-singleton mode should only expand refined nodes.") + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked_errs = state.blocked_errs.copy() + + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked_errs[ei] = True + + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked_errs = prefix_blocked_errs.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + + for d in self.edets[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.ecosts[ei]) + + if self.use_opt_singleton_detcost: + if state.lp_y is None: + raise 
AssertionError("Expected a refined parent LP solution before projection.") + child_lp_y, child_h = self._project_child_solution_and_heuristic( + state, + self.edets[ei], + ) + child_h_source = "projected" + child_ready_to_expand = False + child_seed_detectors = np.array(self.edets[ei], copy=True) + self.projected_nodes_generated += 1 + children_projected += 1 + else: + child_tmp_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=0.0, + h_source="plain", + ready_to_expand=True, + lp_y=None, + seed_detectors=None, + ) + child_h = self._plain_heuristic_from_state(child_tmp_state) + child_h_source = "plain" + child_ready_to_expand = True + child_lp_y = None + child_seed_detectors = None + if math.isinf(child_h): + children_infeasible += 1 + continue + + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_h_source, + ready_to_expand=child_ready_to_expand, + lp_y=child_lp_y, + seed_detectors=child_seed_detectors, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, child_id)) + nodes_pushed += 1 + children_generated += 1 + + self._log_expand( + node_id=node_id, + children_generated=children_generated, + children_projected=children_projected, + children_beam_pruned=children_beam_pruned, + children_infeasible=children_infeasible, + ) + + return self._result( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.ecosts[errs].sum()) + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in 
np.flatnonzero(errs): + for obs in self.eobs[int(ei)]: + obs = int(obs) + parity[obs] = not parity.get(obs, False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.edets[int(ei)]: + dets[int(d)] ^= True + return dets + + +def merged_errors_from_dem(dem) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + + for error in dem.flattened(): + if error.type != "error": + continue + + probability = float(error.args_copy()[0]) + if probability <= 0: + continue + if probability > 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5], got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in error.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected target type: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + key = (tuple(sorted(detectors)), tuple(sorted(observables))) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = probability + else: + p_new = p_old * (1.0 - probability) + (1.0 - p_old) * probability + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def sample_detections_and_observables(circuit, num_shots: int, seed: int) -> Tuple[np.ndarray, np.ndarray]: + 
sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=circuit.num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=circuit.num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def parse_nonnegative_int(text: str) -> int: + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("Expected a non-negative integer.") + return value + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder using the plain detector-wise heuristic or a lazy " + "projected version of the optimal singleton detector heuristic." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Shot index to decode after sampling --sample-num-shots shots (default: 0).", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot (default: 100).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the number of residual detections; use 'inf' for none (default: inf).", + ) + parser.add_argument( + "--opt-singleton-detcost", + action="store_true", + help=( + "Use lazy refinement of the optimal singleton detector-cost lower bound. " + "Nodes are seeded with projected LP prices from their parent and refined " + "when popped." + ), + ) + parser.add_argument( + "--local-lp-component", + action="store_true", + help=( + "Instead of a full LP on pop, only reoptimize the active support component(s) " + "touched by the flipped error that created the node." + ), + ) + parser.add_argument( + "--local-lp-radius", + type=parse_nonnegative_int, + default=None, + help=( + "Instead of a full LP on pop, only reoptimize detector prices within this " + "radius in the active support graph around the changed region." + ), + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "Exclude precedence-blocked errors from the heuristic. By default the script " + "preserves the original prototype's behavior and only excludes already-activated errors." 
+ ), + ) + parser.add_argument( + "--show-detections", + action="store_true", + help="Print the selected shot's detection events before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action="store_true", + help="Print the decoded merged-error indices.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print detailed search, LP-refinement, projection, and locality statistics during A* search.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + if args.local_lp_component and args.local_lp_radius is not None: + parser.error("Choose at most one of --local-lp-component and --local-lp-radius.") + if (args.local_lp_component or args.local_lp_radius is not None) and not args.opt_singleton_detcost: + parser.error("Local LP refinement flags require --opt-singleton-detcost.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) + + dets_unpacked, obs_unpacked = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + ) + shot_dets = dets_unpacked[args.shot] + shot_obs = obs_unpacked[args.shot] + + if args.show_detections: + active_dets = np.flatnonzero(shot_dets) + print("detections:", " ".join(f"D{d}" for d in active_dets)) + + decoder = AStarPrototypeDecoder( + errors, + dem.num_detectors, + use_opt_singleton_detcost=args.opt_singleton_detcost, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + local_lp_radius=args.local_lp_radius, + 
local_lp_component=args.local_lp_component, + ) + result = decoder.decode(shot_dets, det_beam=args.det_beam) + + print(f"heuristic: {decoder.heuristic_name}") + print(f"shot: {args.shot} / {args.sample_num_shots}") + print(f"success: {result.success}") + print(f"nodes_pushed: {result.nodes_pushed}") + print(f"nodes_popped: {result.nodes_popped}") + print(f"heuristic_calls: {result.heuristic_calls}") + print(f"plain_heuristic_calls: {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls: {result.projection_heuristic_calls}") + print(f"refinement_calls: {result.refinement_calls}") + print(f"lp_calls: {result.lp_calls}") + print(f"full_lp_calls: {result.full_lp_calls}") + print(f"component_lp_calls: {result.component_lp_calls}") + print(f"radius_lp_calls: {result.radius_lp_calls}") + print(f"lp_reinserts: {result.lp_reinserts}") + print(f"projected_nodes_generated: {result.projected_nodes_generated}") + print(f"projected_nodes_refined: {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish: {result.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain: {result.total_lp_refinement_gain:.6f}") + print(f"max_lp_refinement_gain: {result.max_lp_refinement_gain:.6f}") + print(f"elapsed_seconds: {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + decoded_err_indices = np.flatnonzero(result.errs) + if args.show_error_indices: + print("decoded_error_indices:", " ".join(map(str, decoded_err_indices.tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + reproduced_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + actual_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors: {int(result.errs.sum())}") + print(f"decoded_cost: {reproduced_cost:.12f}") 
+ print("predicted_observables:", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables:", " ".join(f"L{o}" for o in actual_obs.tolist())) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_projected.py b/src/py/astar/astar_prototype_projected.py new file mode 100644 index 0000000..5843c3a --- /dev/null +++ b/src/py/astar/astar_prototype_projected.py @@ -0,0 +1,882 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with lazy singleton-LP refinement. + +The default heuristic matches the original prototype's plain detector-wise +heuristic. Passing --opt-singleton-detcost enables a lazy version of the exact +optimal singleton detector lower bound: + + * a node is first inserted with a cheap lower bound; + * when the node is popped, the exact singleton LP is solved; + * if the exact LP value raises the node's key, the node is reinserted; + * expanded nodes project their exact LP solution onto each child to seed a + much tighter cheap first-pass lower bound than plain detcost. + +This keeps the prototype pedagogical while making the expensive LP solves much +more selective. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
+ + +@dataclass +class OptSingletonLPResult: + value: float + y_full: np.ndarray + num_active_dets: int + num_supports: int + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + exact_refined: bool + lp_y: Optional[np.ndarray] = None + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + exact_refinement_calls: int + lp_calls: int + lp_reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_lp_refinement_gain: float + max_lp_refinement_gain: float + elapsed_seconds: float + + +class AStarPrototypeDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + *, + use_opt_singleton_detcost: bool = False, + respect_blocked_errors_in_heuristic: bool = False, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_detectors = int(num_detectors) + self.num_errors = len(self.errors) + self.use_opt_singleton_detcost = use_opt_singleton_detcost + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.verbose_search = verbose_search + + self.ecosts = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.edets: List[np.ndarray] = [ + np.array(err.detectors, dtype=np.int32) for err in self.errors + ] + self.eobs: List[np.ndarray] = [ + np.array(err.observables, dtype=np.int32) for err in self.errors + ] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.edets): + for d in dets: + d2e_lists[int(d)].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) 
-> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.exact_refinement_calls = 0 + self.lp_calls = 0 + self.lp_reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_lp_refinement_gain = 0.0 + self.max_lp_refinement_gain = 0.0 + + @property + def heuristic_name(self) -> str: + if self.use_opt_singleton_detcost: + return "opt-singleton-detcost-lazy-projection" + return "plain-detcost" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _plain_detcost_heuristic( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> float: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + + total = 0.0 + for d in np.flatnonzero(dets): + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.ecosts[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF + total += best + return total + + def _solve_opt_singleton_lp( + self, + available_errs: np.ndarray, + dets: np.ndarray, + det_counts: np.ndarray, + ) -> OptSingletonLPResult: + self.heuristic_calls += 1 + self.exact_refinement_calls += 1 + + active_dets = np.flatnonzero(dets) + if active_dets.size == 0: + return OptSingletonLPResult( + value=0.0, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=0, + num_supports=0, + ) + + det_to_var = {int(d): i for i, d in enumerate(active_dets.tolist())} + support_to_weight: Dict[Tuple[int, ...], float] = {} + covered = np.zeros(active_dets.size, dtype=bool) + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + if int(det_counts[ei]) == 0: + continue + support = tuple(det_to_var[int(d)] for d in self.edets[ei] if 
dets[int(d)]) + if not support: + continue + for var in support: + covered[var] = True + weight = float(self.ecosts[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + if not np.all(covered): + return OptSingletonLPResult( + value=INF, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=int(active_dets.size), + num_supports=len(support_to_weight), + ) + + supports = list(support_to_weight.keys()) + weights = np.array([support_to_weight[s] for s in supports], dtype=np.float64) + num_vars = int(active_dets.size) + + row_indices: List[int] = [] + col_indices: List[int] = [] + data: List[float] = [] + for row, support in enumerate(supports): + row_indices.extend([row] * len(support)) + col_indices.extend(support) + data.extend([1.0] * len(support)) + + a_ub = csr_matrix( + (data, (row_indices, col_indices)), + shape=(len(supports), num_vars), + dtype=np.float64, + ) + + self.lp_calls += 1 + result = linprog( + c=-np.ones(num_vars, dtype=np.float64), + A_ub=a_ub, + b_ub=weights, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + if result.status == 0: + y_full = np.zeros(self.num_detectors, dtype=np.float64) + y_full[active_dets] = np.asarray(result.x, dtype=np.float64) + return OptSingletonLPResult( + value=max(0.0, float(-result.fun)), + y_full=y_full, + num_active_dets=num_vars, + num_supports=len(supports), + ) + if result.status in {2, 3}: # infeasible or unbounded + return OptSingletonLPResult( + value=INF, + y_full=np.zeros(self.num_detectors, dtype=np.float64), + num_active_dets=num_vars, + num_supports=len(supports), + ) + raise RuntimeError(f"linprog failed with status={result.status}: {result.message}") + + def _plain_heuristic_from_state(self, state: SearchState) -> float: + available = self._available_errors(state.errs, state.blocked_errs) + return self._plain_detcost_heuristic(available, state.dets, state.det_counts) + + def _project_child_heuristic( + self, 
+ parent_state: SearchState, + flipped_detectors: np.ndarray, + ) -> float: + if parent_state.lp_y is None: + raise AssertionError("Expected parent exact LP solution before projecting to children.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + + value = parent_state.h_cost + for d in flipped_detectors: + d = int(d) + if parent_state.dets[d]: + value -= float(parent_state.lp_y[d]) + if value < -HEURISTIC_EPS: + raise AssertionError(f"Projected heuristic became negative: {value}") + return max(0.0, value) + + def _maybe_refine_node_with_exact_lp( + self, + node_id: int, + state: SearchState, + num_dets: int, + ) -> Tuple[SearchState, Optional[Tuple[float, int]], Optional[Dict[str, float]]]: + if not self.use_opt_singleton_detcost or state.exact_refined: + return state, None, None + + prev_h = state.h_cost + prev_source = state.h_source + available = self._available_errors(state.errs, state.blocked_errs) + lp_result = self._solve_opt_singleton_lp(available, state.dets, state.det_counts) + exact_h = lp_result.value + + if math.isinf(exact_h): + refine_info = { + "approx_h": prev_h, + "exact_h": exact_h, + "delta": INF, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "reinserted": 0.0, + "discarded": 1.0, + } + if prev_source == "projected": + self.projected_nodes_refined += 1 + return state, None, refine_info + + if exact_h + 1e-7 < prev_h: + raise AssertionError( + f"Exact LP lower bound {exact_h} is below stored {prev_source} lower bound {prev_h}." 
+ ) + + delta = exact_h - prev_h + if prev_source == "projected": + self.projected_nodes_refined += 1 + self.total_lp_refinement_gain += delta + self.max_lp_refinement_gain = max(self.max_lp_refinement_gain, delta) + + state.h_cost = exact_h + state.h_source = "exact" + state.exact_refined = True + state.lp_y = lp_result.y_full + + should_reinsert = delta > HEURISTIC_EPS + reinsert_entry = (state.g_cost + exact_h, num_dets) if should_reinsert else None + if should_reinsert: + self.lp_reinserts += 1 + + refine_info = { + "approx_h": prev_h, + "exact_h": exact_h, + "delta": delta, + "num_vars": float(lp_result.num_active_dets), + "num_supports": float(lp_result.num_supports), + "reinserted": 1.0 if should_reinsert else 0.0, + "discarded": 0.0, + } + return state, reinsert_entry, refine_info + + def _log_pop( + self, + *, + heap_len: int, + nodes_pushed: int, + nodes_popped: int, + num_dets: int, + max_num_dets: float, + f_cost: float, + state: SearchState, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={heap_len} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} lp_reinserts={self.lp_reinserts} " + f"proj_generated={self.projected_nodes_generated} proj_refined={self.projected_nodes_refined} " + f"proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} exact_refined={state.exact_refined}" + ) + + def _log_refine(self, node_id: int, info: Dict[str, float]) -> None: + if not self.verbose_search: + return + exact_h = info["exact_h"] + exact_text = "INF" if math.isinf(exact_h) else f"{exact_h:.6f}" + delta = info["delta"] + delta_text = "INF" if math.isinf(delta) else f"{delta:.6f}" + print( + f" lp_refine node={node_id} approx_h={info['approx_h']:.6f} exact_h={exact_text} " + 
f"delta={delta_text} vars={int(info['num_vars'])} supports={int(info['num_supports'])} " + f"reinserted={bool(info['reinserted'])} discarded={bool(info['discarded'])}" + ) + + def _log_expand( + self, + *, + node_id: int, + children_generated: int, + children_projected: int, + children_beam_pruned: int, + children_infeasible: int, + ) -> None: + if not self.verbose_search: + return + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded node={node_id} children_generated={children_generated} " + f"children_projected={children_projected} beam_pruned={children_beam_pruned} " + f"infeasible={children_infeasible} lp_calls={self.lp_calls} " + f"proj_unrefined_so_far={projected_unrefined}" + ) + + def _result( + self, + *, + success: bool, + errs: np.ndarray, + residual_dets: np.ndarray, + cost: float, + nodes_pushed: int, + nodes_popped: int, + start_time: float, + ) -> DecodeResult: + return DecodeResult( + success=success, + errs=errs, + residual_dets=residual_dets, + cost=cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + exact_refinement_calls=self.exact_refinement_calls, + lp_calls=self.lp_calls, + lp_reinserts=self.lp_reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=( + self.projected_nodes_generated - self.projected_nodes_refined + ), + total_lp_refinement_gain=self.total_lp_refinement_gain, + max_lp_refinement_gain=self.max_lp_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, 
dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=0.0, + h_source="plain", + exact_refined=not self.use_opt_singleton_detcost, + lp_y=None, + ) + root_state.h_cost = self._plain_heuristic_from_state(root_state) + if math.isinf(root_state.h_cost): + return self._result( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + start_time=start_time, + ) + + next_node_id = 1 + heap: List[Tuple[float, int, int]] = [ + (root_state.g_cost + root_state.h_cost, int(dets0.sum()), 0) + ] + node_data: Dict[int, SearchState] = {0: root_state} + + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, node_id = heapq.heappop(heap) + state = node_data.pop(node_id, None) + if state is None: + continue + nodes_popped += 1 + + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + self._log_pop( + heap_len=len(heap), + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + num_dets=num_dets, + max_num_dets=max_num_dets, + f_cost=f_cost, + state=state, + ) + + if num_dets == 0: + return self._result( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + state, reinsert_entry, refine_info = self._maybe_refine_node_with_exact_lp( + node_id=node_id, + state=state, + num_dets=num_dets, + ) + if refine_info is not None: + self._log_refine(node_id, refine_info) + if bool(refine_info["discarded"]): + continue + if reinsert_entry is not None: + 
node_data[node_id] = state + heapq.heappush(heap, (reinsert_entry[0], reinsert_entry[1], node_id)) + continue + + if self.use_opt_singleton_detcost and not state.exact_refined: + raise AssertionError("Opt-singleton mode should only expand exact-refined nodes.") + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked_errs = state.blocked_errs.copy() + + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked_errs[ei] = True + + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked_errs = prefix_blocked_errs.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + + for d in self.edets[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.ecosts[ei]) + + if self.use_opt_singleton_detcost: + child_h = self._project_child_heuristic(state, self.edets[ei]) + child_h_source = "projected" + child_exact_refined = False + child_lp_y = None + self.projected_nodes_generated += 1 + children_projected += 1 + else: + child_tmp_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=0.0, + h_source="plain", + exact_refined=False, + lp_y=None, + ) + child_h = self._plain_heuristic_from_state(child_tmp_state) + child_h_source = "plain" + child_exact_refined = True + child_lp_y = None + if math.isinf(child_h): + children_infeasible += 1 + continue + + child_id = next_node_id + next_node_id += 1 + node_data[child_id] = SearchState( + errs=child_errs, + 
blocked_errs=child_blocked_errs, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_h_source, + exact_refined=child_exact_refined, + lp_y=child_lp_y, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, child_id)) + nodes_pushed += 1 + children_generated += 1 + + self._log_expand( + node_id=node_id, + children_generated=children_generated, + children_projected=children_projected, + children_beam_pruned=children_beam_pruned, + children_infeasible=children_infeasible, + ) + + return self._result( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + start_time=start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.ecosts[errs].sum()) + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.eobs[int(ei)]: + obs = int(obs) + parity[obs] = not parity.get(obs, False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.edets[int(ei)]: + dets[int(d)] ^= True + return dets + + +def merged_errors_from_dem(dem) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + + for error in dem.flattened(): + if error.type != "error": + continue + + probability = float(error.args_copy()[0]) + if probability <= 0: + continue + if probability > 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5], got {probability}." 
+ ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in error.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected target type: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + key = (tuple(sorted(detectors)), tuple(sorted(observables))) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = probability + else: + p_new = p_old * (1.0 - probability) + (1.0 - p_old) * probability + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +def sample_detections_and_observables(circuit, num_shots: int, seed: int) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=circuit.num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=circuit.num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + 
description=( + "Prototype A* decoder using the plain detector-wise heuristic or a lazy " + "projected version of the optimal singleton detector heuristic." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Shot index to decode after sampling --sample-num-shots shots (default: 0).", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot (default: 100).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the number of residual detections; use 'inf' for none (default: inf).", + ) + parser.add_argument( + "--opt-singleton-detcost", + action="store_true", + help=( + "Use lazy refinement of the exact optimal singleton detector-cost lower bound. " + "Nodes are seeded with projected LP prices from their parent and only solved " + "exactly when popped." + ), + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "Exclude precedence-blocked errors from the heuristic. By default the script " + "preserves the original prototype's behavior and only excludes already-activated errors." 
+ ), + ) + parser.add_argument( + "--show-detections", + action="store_true", + help="Print the selected shot's detection events before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action="store_true", + help="Print the decoded merged-error indices.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print detailed search, LP-refinement, and projection statistics during A* search.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) + + dets_unpacked, obs_unpacked = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + ) + shot_dets = dets_unpacked[args.shot] + shot_obs = obs_unpacked[args.shot] + + if args.show_detections: + active_dets = np.flatnonzero(shot_dets) + print("detections:", " ".join(f"D{d}" for d in active_dets)) + + decoder = AStarPrototypeDecoder( + errors, + dem.num_detectors, + use_opt_singleton_detcost=args.opt_singleton_detcost, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + ) + result = decoder.decode(shot_dets, det_beam=args.det_beam) + + print(f"heuristic: {decoder.heuristic_name}") + print(f"shot: {args.shot} / {args.sample_num_shots}") + print(f"success: {result.success}") + print(f"nodes_pushed: {result.nodes_pushed}") + print(f"nodes_popped: {result.nodes_popped}") + print(f"heuristic_calls: {result.heuristic_calls}") + print(f"plain_heuristic_calls: 
{result.plain_heuristic_calls}") + print(f"projection_heuristic_calls: {result.projection_heuristic_calls}") + print(f"exact_refinement_calls: {result.exact_refinement_calls}") + print(f"lp_calls: {result.lp_calls}") + print(f"lp_reinserts: {result.lp_reinserts}") + print(f"projected_nodes_generated: {result.projected_nodes_generated}") + print(f"projected_nodes_refined: {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish: {result.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain: {result.total_lp_refinement_gain:.6f}") + print(f"max_lp_refinement_gain: {result.max_lp_refinement_gain:.6f}") + print(f"elapsed_seconds: {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + decoded_err_indices = np.flatnonzero(result.errs) + if args.show_error_indices: + print("decoded_error_indices:", " ".join(map(str, decoded_err_indices.tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + reproduced_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + actual_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors: {int(result.errs.sum())}") + print(f"decoded_cost: {reproduced_cost:.12f}") + print("predicted_observables:", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables:", " ".join(f"L{o}" for o in actual_obs.tolist())) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_singleton_greedy_heuristics.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics.py new file mode 100644 index 0000000..42e65e5 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics.py @@ -0,0 +1,751 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for experimenting with 
fast singleton-budget heuristics. + +This version mirrors the earlier Stim-based prototypes: + * load a .stim circuit, + * extract its detector error model with decompose_errors=False, + * optionally merge indistinguishable errors, + * sample detector shots from Stim, + * run precedence-pruned A* with a selectable singleton lower-bound heuristic. + +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + +The greedy heuristics are derived from feasible points of the singleton LP + + max sum_d y_d + s.t. sum_{d in T} y_d <= W(T) + y_d >= 0, + +where W(T) is the cheapest available error whose active support is T. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
@dataclass
class SupportData:
    """Active-detector view of the currently available errors.

    Attributes:
        active_detectors: Sorted global ids of detectors still unresolved.
        supports: List of ``(support, weight)`` pairs, where ``support`` is the
            tuple of active detectors flipped by the cheapest available error
            with that exact active footprint and ``weight`` its likelihood cost.
        incident: For each active detector, indices into ``supports`` of every
            support that contains it.
    """

    active_detectors: List[int]
    supports: List[Tuple[Tuple[int, ...], float]]
    incident: Dict[int, List[int]]


@dataclass
class SearchState:
    """One A* node: chosen errors, precedence-blocked errors, residual dets, g-cost."""

    errs: np.ndarray
    blocked_errs: np.ndarray
    dets: np.ndarray
    g_cost: float


@dataclass
class DecodeResult:
    """Outcome and bookkeeping counters of a single decode() call."""

    success: bool
    errs: np.ndarray
    residual_dets: np.ndarray
    cost: float
    nodes_pushed: int
    nodes_popped: int
    heuristic_calls: int
    elapsed_seconds: float


class UnionFind:
    """Disjoint-set forest with path halving and union by rank."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative of ``x``, compressing the path as it walks."""
        while self.parent[x] != x:
            # Path halving: point x at its grandparent before stepping up.
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        """Merge the sets containing ``a`` and ``b`` (no-op if already joined)."""
        ra = self.find(a)
        rb = self.find(b)
        if ra == rb:
            return
        if self.rank[ra] < self.rank[rb]:
            self.parent[ra] = rb
        elif self.rank[ra] > self.rank[rb]:
            self.parent[rb] = ra
        else:
            self.parent[rb] = ra
            self.rank[ra] += 1


def xor_probability(p0: float, p1: float) -> float:
    """Return the probability that exactly one of two independent events occurs."""
    return p0 * (1.0 - p1) + (1.0 - p0) * p1


def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]:
    """Yield one ErrorRecord per ``error`` instruction of the flattened DEM.

    Repeated detector/observable targets within one instruction cancel in
    pairs (XOR semantics). Zero-probability errors are dropped; probabilities
    at or above 0.5 are rejected because the likelihood cost
    ``-log(p / (1 - p))`` would be non-positive.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Expected flattened error probabilities in (0, 0.5), got {probability}."
            )

        detectors: set[int] = set()
        observables: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                if target.val in observables:
                    observables.remove(target.val)
                else:
                    observables.add(target.val)
            else:
                if not target.is_relative_detector_id():
                    raise ValueError(f"Unexpected DEM target: {target!r}")
                if target.val in detectors:
                    detectors.remove(target.val)
                else:
                    detectors.add(target.val)

        yield ErrorRecord(
            probability=probability,
            likelihood_cost=-math.log(probability / (1.0 - probability)),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]:
    """Merge DEM errors with identical (detectors, observables) symptoms.

    Probabilities of indistinguishable errors combine under XOR (an odd number
    of them must fire to produce the symptom). Merged probabilities that reach
    0.5 are rejected since they would not admit a positive likelihood cost.
    """
    errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors_from_dem(dem):
        key = (error.detectors, error.observables)
        p_old = errors_by_symptom.get(key)
        if p_old is None:
            p_new = error.probability
        else:
            p_new = xor_probability(p_old, error.probability)
        errors_by_symptom[key] = p_new

    merged: List[ErrorRecord] = []
    for (detectors, observables), probability in errors_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost."
            )
        merged.append(
            ErrorRecord(
                probability=probability,
                likelihood_cost=-math.log(probability / (1.0 - probability)),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged


class GreedySingletonHeuristicDecoder:
    """Precedence-pruned A* decoder with selectable singleton-budget heuristics.

    Each heuristic computes a lower bound on the remaining cost from a feasible
    point of the singleton LP: maximize sum_d y_d subject to
    sum_{d in T} y_d <= W(T) for every available support T, y_d >= 0.
    """

    def __init__(
        self,
        errors: Sequence[ErrorRecord],
        num_detectors: int,
        num_observables: int,
        *,
        heuristic: str = "best_of_two",
        respect_blocked_errors_in_heuristic: bool = True,
        verbose_search: bool = False,
    ) -> None:
        self.errors = list(errors)
        self.num_errors = len(self.errors)
        self.num_detectors = int(num_detectors)
        self.num_observables = int(num_observables)
        self.heuristic_name = heuristic
        self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic
        self.verbose_search = verbose_search

        self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64)
        self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64)
        self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors]
        self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors]

        # Detector -> indices of errors that flip it (inverse of error_detectors).
        d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)]
        for ei, dets in enumerate(self.error_detectors):
            for d in dets:
                d2e_lists[d].append(ei)
        self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists]

        self.heuristic_calls = 0

    def reset_stats(self) -> None:
        """Zero the per-decode statistics counters."""
        self.heuristic_calls = 0

    def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData:
        """Collect, per active-detector footprint, the cheapest available error.

        Args:
            active_dets: Boolean mask over detectors still unresolved.
            available_errs: Boolean mask over errors eligible for selection.

        Returns:
            SupportData with deduplicated supports (min weight per footprint)
            and per-detector incidence lists.
        """
        active_list = sorted(map(int, np.flatnonzero(active_dets)))
        incident: Dict[int, List[int]] = {d: [] for d in active_list}
        support_to_weight: Dict[Tuple[int, ...], float] = {}

        for ei in np.flatnonzero(available_errs):
            ei = int(ei)
            # Restrict the error's detector footprint to currently active detectors.
            support = tuple(d for d in self.error_detectors[ei] if active_dets[d])
            if not support:
                continue
            weight = float(self.weights[ei])
            old = support_to_weight.get(support)
            if old is None or weight < old:
                support_to_weight[support] = weight

        supports = list(support_to_weight.items())
        for i, (support, _weight) in enumerate(supports):
            for d in support:
                if d in incident:
                    incident[d].append(i)

        return SupportData(active_detectors=active_list, supports=supports, incident=incident)

    def _check_coverage(self, support_data: SupportData) -> bool:
        """True iff every active detector is covered by at least one support."""
        return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors)

    def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]:
        """Original detector-wise feasible point: y_d = min over incident supports of W(T)/|T|.

        Returns ``(math.inf, None)`` when some active detector has no covering
        support (the residual syndrome is infeasible).
        """
        if not support_data.active_detectors:
            return 0.0, np.zeros(self.num_detectors, dtype=np.float64)
        if not self._check_coverage(support_data):
            return math.inf, None
        y = np.zeros(self.num_detectors, dtype=np.float64)
        for d in support_data.active_detectors:
            best = math.inf
            for i in support_data.incident[d]:
                support, weight = support_data.supports[i]
                best = min(best, weight / len(support))
            y[d] = best
        return float(y[support_data.active_detectors].sum()), y

    def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]:
        """Zero-start greedy saturation: raise each y_d to its remaining slack.

        ``order_kind`` selects the saturation order: ``asc_deg`` visits
        detectors by ascending incidence degree; ``desc_plain`` by descending
        plain-heuristic value.
        """
        if not support_data.active_detectors:
            return 0.0, np.zeros(self.num_detectors, dtype=np.float64)
        if not self._check_coverage(support_data):
            return math.inf, None

        slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64)
        y = np.zeros(self.num_detectors, dtype=np.float64)

        if order_kind == "asc_deg":
            order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d))
        elif order_kind == "desc_plain":
            _plain_value, y_plain = self.heuristic_plain(support_data)
            if y_plain is None:
                return math.inf, None
            order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), reverse=True)
        else:
            raise ValueError(f"Unknown order_kind={order_kind!r}")

        for d in order:
            value = min(slack[i] for i in support_data.incident[d])
            if value < 0:
                # Defensive clamp; slacks should stay non-negative by construction.
                value = 0.0
            y[d] = value
            for i in support_data.incident[d]:
                slack[i] -= value
        return float(y[support_data.active_detectors].sum()), y

    def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]:
        """Plain feasible point followed by one descending saturation sweep.

        Each detector (in descending y order) is raised to the largest value
        keeping every constraint it appears in feasible.
        """
        plain_value, y = self.heuristic_plain(support_data)
        if y is None:
            return math.inf, None
        order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True)
        for d in order:
            # Only supports containing d constrain y[d]; use the precomputed
            # incidence lists instead of scanning every support.
            max_feasible = min(
                weight - sum(y[dd] for dd in support if dd != d)
                for support, weight in (support_data.supports[i] for i in support_data.incident[d])
            )
            if max_feasible > y[d]:
                y[d] = max_feasible
        return float(y[support_data.active_detectors].sum()), y

    def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]:
        """Exact optimum of the singleton LP, solved per connected component.

        Detectors connected through shared supports are grouped with a
        union-find so each independent component gets its own (smaller) LP.
        """
        active = support_data.active_detectors
        if not active:
            return 0.0, np.zeros(self.num_detectors, dtype=np.float64)
        if not self._check_coverage(support_data):
            return math.inf, None

        detector_index = {d: i for i, d in enumerate(active)}
        uf = UnionFind(len(active))
        for support, _weight in support_data.supports:
            if len(support) > 1:
                a = detector_index[support[0]]
                for d in support[1:]:
                    uf.union(a, detector_index[d])

        components: Dict[int, List[int]] = defaultdict(list)
        for d in active:
            components[uf.find(detector_index[d])].append(d)

        y = np.zeros(self.num_detectors, dtype=np.float64)
        total = 0.0
        for component in components.values():
            component_set = set(component)
            local = {d: i for i, d in enumerate(sorted(component))}
            component_supports: List[Tuple[Tuple[int, ...], float]] = []
            for support, weight in support_data.supports:
                # A support lies entirely inside one component, so testing its
                # first detector suffices.
                if support[0] in component_set:
                    component_supports.append((tuple(local[d] for d in support), weight))

            rows: List[int] = []
            cols: List[int] = []
            data: List[float] = []
            rhs: List[float] = []
            for r, (support, weight) in enumerate(component_supports):
                rhs.append(weight)
                for c in support:
                    rows.append(r)
                    cols.append(c)
                    data.append(1.0)

            a_ub = csr_matrix(
                (data, (rows, cols)),
                shape=(len(component_supports), len(component)),
                dtype=np.float64,
            )
            # Maximize sum(y) == minimize -sum(y).
            result = linprog(
                c=-np.ones(len(component), dtype=np.float64),
                A_ub=a_ub,
                b_ub=np.array(rhs, dtype=np.float64),
                bounds=[(0.0, None)] * len(component),
                method="highs",
            )
            if not result.success:
                return math.inf, None
            total += -float(result.fun)
            for d, value in zip(sorted(component), result.x):
                y[d] = float(value)
        return float(total), y

    def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]:
        """Dispatch to the heuristic selected by ``name``; see module docstring."""
        if name == "plain":
            return self.heuristic_plain(support_data)
        if name == "asc_deg":
            return self.heuristic_saturation_zero(support_data, order_kind="asc_deg")
        if name == "desc_plain":
            return self.heuristic_saturation_zero(support_data, order_kind="desc_plain")
        if name == "plain_sweep":
            return self.heuristic_plain_sweep(support_data)
        if name == "best_of_two":
            v1, y1 = self.heuristic_plain_sweep(support_data)
            v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg")
            if v1 >= v2:
                return v1, y1
            return v2, y2
        if name == "best_of_three":
            candidates = [
                self.heuristic_plain_sweep(support_data),
                self.heuristic_saturation_zero(support_data, order_kind="asc_deg"),
                self.heuristic_saturation_zero(support_data, order_kind="desc_plain"),
            ]
            return max(candidates, key=lambda t: t[0])
        if name == "exact_lp":
            return self.heuristic_exact_lp(support_data)
        raise ValueError(f"Unknown heuristic {name!r}")

    def compute_heuristic(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> float:
        """Evaluate the configured heuristic for a search state.

        Already-activated errors are always excluded from the support set;
        precedence-blocked errors are excluded only when
        ``respect_blocked_errors_in_heuristic`` is set.
        """
        self.heuristic_calls += 1
        available = ~errs
        if self.respect_blocked_errors_in_heuristic:
            available &= ~blocked_errs
        support_data = self.build_support_data(dets, available)
        value, _ = self.evaluate_named_heuristic(support_data, self.heuristic_name)
        return value

def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + support_data = self.build_support_data(dets, available) + names = ["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"] + out: List[Tuple[str, float]] = [] + for name in names: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + return out + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + h0 = self.compute_heuristic(dets0, errs0, blocked0) + if math.isinf(h0): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + heuristic_calls=self.heuristic_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + root_state = SearchState(errs=errs0, blocked_errs=blocked0, dets=dets0, g_cost=0.0) + heapq.heappush(heap, (h0, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + min_num_dets = int(dets0.sum()) + + while heap: + f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + 
errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + for d in self.error_detectors[ei]: + child_dets[d] ^= True + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + child_g = state.g_cost + float(self.weights[ei]) + child_h = self.compute_heuristic(child_dets, child_errs, child_blocked) + if math.isinf(child_h): + children_infeasible += 1 + continue + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + g_cost=child_g, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + print( + f" expanded children_generated={children_generated} beam_pruned={children_beam_pruned} " + f"infeasible={children_infeasible}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + heuristic_calls=self.heuristic_calls, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in 
np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding 
succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + + if args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs 
= np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>12s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + print("decoded_error_indices =", " ".join(map(str, np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git 
a/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py new file mode 100644 index 0000000..1051961 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py @@ -0,0 +1,1107 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics. + +This version keeps the same Stim-facing API as the earlier greedy prototype but +adds lazy reinsertion / parent-y projection, in the same spirit as the lazy +optimal-singleton prototype: + + * nodes are seeded with a cheap feasible lower bound; + * when a node is popped, the selected heuristic is evaluated on that node; + * if the refined heuristic raises the node key, the node is reinserted; + * expanded nodes project their current feasible y-prices onto children; + * optionally, the projected child bound is maxed with plain detcost. + +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + +When --lazy-reinsert-heuristics is enabled (the default), the root is seeded by +plain detcost and only popped nodes are refined with the selected heuristic. +This works for all of the above heuristics because each returns a feasible +singleton-budget vector y, and projecting that y to a child by keeping prices +on detectors that remain active and zeroing newly active detectors is still a +feasible child singleton-budget point. 
+""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class SupportData: + active_detectors: List[int] + supports: List[Tuple[Tuple[int, ...], float]] + incident: Dict[int, List[int]] + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + refined: bool + y_prices: Optional[np.ndarray] + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + refinement_calls: int + lp_calls: int + reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_refinement_gain: float + max_refinement_gain: float + elapsed_seconds: float + + +class UnionFind: + def __init__(self, n: int) -> None: + self.parent = list(range(n)) + self.rank = [0] * n + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + 
def xor_probability(p0: float, p1: float) -> float:
    """Probability that exactly one of two independent events fires."""
    only_first = p0 * (1.0 - p1)
    only_second = (1.0 - p0) * p1
    return only_first + only_second


def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]:
    """Yield one ErrorRecord per ``error`` instruction of the flattened DEM.

    Repeated detector/observable targets inside one instruction cancel in
    pairs (XOR semantics). Zero-probability errors are skipped; probabilities
    at or above 0.5 are rejected since they cannot carry a positive cost.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Expected flattened error probabilities in (0, 0.5), got {probability}."
            )

        detectors: set[int] = set()
        observables: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # Symmetric difference toggles membership: pairs cancel.
                observables ^= {target.val}
            elif target.is_relative_detector_id():
                detectors ^= {target.val}
            else:
                raise ValueError(f"Unexpected DEM target: {target!r}")

        yield ErrorRecord(
            probability=probability,
            likelihood_cost=-math.log(probability / (1.0 - probability)),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]:
    """Collapse DEM errors that share the same (detectors, observables) key.

    Probabilities of indistinguishable mechanisms combine under XOR; the
    merged probability must stay below 0.5 so a positive likelihood cost
    can be assigned.
    """
    errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors_from_dem(dem):
        key = (error.detectors, error.observables)
        if key in errors_by_symptom:
            errors_by_symptom[key] = xor_probability(errors_by_symptom[key], error.probability)
        else:
            errors_by_symptom[key] = error.probability

    merged: List[ErrorRecord] = []
    for (detectors, observables), probability in errors_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost."
            )
        merged.append(
            ErrorRecord(
                probability=probability,
                likelihood_cost=-math.log(probability / (1.0 - probability)),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged


class GreedySingletonHeuristicDecoder:
    """A* decoder whose lower bounds come from greedy singleton budgets."""

    def __init__(
        self,
        errors: Sequence[ErrorRecord],
        num_detectors: int,
        num_observables: int,
        *,
        heuristic: str = "best_of_two",
        respect_blocked_errors_in_heuristic: bool = True,
        lazy_reinsert_heuristics: bool = True,
        projection_combine_max_plain: bool = True,
        verbose_search: bool = False,
    ) -> None:
        self.errors = list(errors)
        self.num_errors = len(self.errors)
        self.num_detectors = int(num_detectors)
        self.num_observables = int(num_observables)

        # Search configuration.
        self.heuristic_name = heuristic
        self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic
        self.lazy_reinsert_heuristics = lazy_reinsert_heuristics
        self.projection_combine_max_plain = projection_combine_max_plain
        self.verbose_search = verbose_search

        # Flat per-error views of the model.
        self.probabilities = np.array([rec.probability for rec in self.errors], dtype=np.float64)
        self.weights = np.array([rec.likelihood_cost for rec in self.errors], dtype=np.float64)
        self.error_detectors: List[Tuple[int, ...]] = [tuple(rec.detectors) for rec in self.errors]
        self.error_observables: List[Tuple[int, ...]] = [tuple(rec.observables) for rec in self.errors]

        # detector id -> array of incident error indices.
        incidence: List[List[int]] = [[] for _ in range(self.num_detectors)]
        for error_index, detector_ids in enumerate(self.error_detectors):
            for detector in detector_ids:
                incidence[detector].append(error_index)
        self.d2e: List[np.ndarray] = [
            np.array(error_ids, dtype=np.int32) for error_ids in incidence
        ]

        self.reset_stats()

    def reset_stats(self) -> None:
        """Zero out all per-decode instrumentation counters."""
        (
            self.heuristic_calls,
            self.plain_heuristic_calls,
            self.projection_heuristic_calls,
            self.refinement_calls,
            self.lp_calls,
            self.reinserts,
            self.projected_nodes_generated,
            self.projected_nodes_refined,
        ) = (0,) * 8
        self.total_refinement_gain = 0.0
        self.max_refinement_gain = 0.0

@property + def mode_name(self) -> str: + if self.heuristic_name == "plain": + return "plain" + if self.lazy_reinsert_heuristics: + suffix = "-lazy-projection" + if self.projection_combine_max_plain: + suffix += "-maxplain" + return f"{self.heuristic_name}{suffix}" + return self.heuristic_name + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _has_cover_for_all_active_detectors(self, dets: np.ndarray, available_errs: np.ndarray) -> bool: + for d in np.flatnonzero(dets): + found = False + for ei in self.d2e[int(d)]: + if available_errs[int(ei)]: + found = True + break + if not found: + return False + return True + + def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData: + active_list = sorted(map(int, np.flatnonzero(active_dets))) + incident: Dict[int, List[int]] = {d: [] for d in active_list} + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + support = tuple(d for d in self.error_detectors[ei] if active_dets[d]) + if not support: + continue + weight = float(self.weights[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + supports = list(support_to_weight.items()) + for i, (support, _weight) in enumerate(supports): + for d in support: + if d in incident: + incident[d].append(i) + + return SupportData(active_detectors=active_list, supports=supports, incident=incident) + + def _check_coverage(self, support_data: SupportData) -> bool: + return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors) + + def plain_detcost_from_counts( + self, + dets: np.ndarray, + available_errs: np.ndarray, + det_counts: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + 
self.plain_heuristic_calls += 1 + active = np.flatnonzero(dets) + if active.size == 0: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for d in active: + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.weights[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF, None + y[int(d)] = best + total += best + return total, y + + def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + y = np.zeros(self.num_detectors, dtype=np.float64) + for d in support_data.active_detectors: + best = INF + for i in support_data.incident[d]: + support, weight = support_data.supports[i] + best = min(best, weight / len(support)) + y[d] = best + return float(y[support_data.active_detectors].sum()), y + + def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64) + y = np.zeros(self.num_detectors, dtype=np.float64) + + if order_kind == "asc_deg": + order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d)) + elif order_kind == "desc_plain": + _plain_value, y_plain = self.heuristic_plain(support_data) + if y_plain is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), reverse=True) + else: + raise ValueError(f"Unknown order_kind={order_kind!r}") + + for d in order: + 
value = min(slack[i] for i in support_data.incident[d]) + if value < 0: + value = 0.0 + y[d] = value + for i in support_data.incident[d]: + slack[i] -= value + return float(y[support_data.active_detectors].sum()), y + + def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + plain_value, y = self.heuristic_plain(support_data) + if y is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True) + for d in order: + max_feasible = min( + weight - sum(y[dd] for dd in support if dd != d) + for support, weight in support_data.supports + if d in support + ) + if max_feasible > y[d]: + y[d] = max_feasible + return float(y[support_data.active_detectors].sum()), y + + def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + active = support_data.active_detectors + if not active: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + detector_index = {d: i for i, d in enumerate(active)} + uf = UnionFind(len(active)) + for support, _weight in support_data.supports: + if len(support) > 1: + a = detector_index[support[0]] + for d in support[1:]: + uf.union(a, detector_index[d]) + + components: Dict[int, List[int]] = defaultdict(list) + for d in active: + components[uf.find(detector_index[d])].append(d) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for component in components.values(): + component_set = set(component) + local = {d: i for i, d in enumerate(sorted(component))} + component_supports: List[Tuple[Tuple[int, ...], float]] = [] + for support, weight in support_data.supports: + if support[0] in component_set: + component_supports.append((tuple(local[d] for d in support), weight)) + + rows: List[int] = [] + cols: List[int] = [] + data: List[float] = [] + rhs: List[float] = [] + for r, (support, weight) in enumerate(component_supports): + 
rhs.append(weight) + for c in support: + rows.append(r) + cols.append(c) + data.append(1.0) + + a_ub = csr_matrix( + (data, (rows, cols)), + shape=(len(component_supports), len(component)), + dtype=np.float64, + ) + self.lp_calls += 1 + result = linprog( + c=-np.ones(len(component), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, dtype=np.float64), + bounds=[(0.0, None)] * len(component), + method="highs", + ) + if not result.success: + return INF, None + total += -float(result.fun) + for d, value in zip(sorted(component), result.x): + y[d] = float(value) + return float(total), y + + def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]: + if name == "plain": + return self.heuristic_plain(support_data) + if name == "asc_deg": + return self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if name == "desc_plain": + return self.heuristic_saturation_zero(support_data, order_kind="desc_plain") + if name == "plain_sweep": + return self.heuristic_plain_sweep(support_data) + if name == "best_of_two": + v1, y1 = self.heuristic_plain_sweep(support_data) + v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if v1 >= v2: + return v1, y1 + return v2, y2 + if name == "best_of_three": + candidates = [ + self.heuristic_plain_sweep(support_data), + self.heuristic_saturation_zero(support_data, order_kind="asc_deg"), + self.heuristic_saturation_zero(support_data, order_kind="desc_plain"), + ] + return max(candidates, key=lambda t: t[0]) + if name == "exact_lp": + return self.heuristic_exact_lp(support_data) + raise ValueError(f"Unknown heuristic {name!r}") + + def compute_support_based_heuristic( + self, + dets: np.ndarray, + errs: np.ndarray, + blocked_errs: np.ndarray, + *, + name: Optional[str] = None, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, 
available) + return self.evaluate_named_heuristic(support_data, name or self.heuristic_name) + + def project_child_y( + self, + parent_state: SearchState, + child_dets: np.ndarray, + child_errs: np.ndarray, + child_blocked_errs: np.ndarray, + child_det_counts: np.ndarray, + flipped_detectors: Sequence[int], + ) -> Tuple[float, Optional[np.ndarray], str]: + if parent_state.y_prices is None: + raise AssertionError("Expected a stored feasible y vector before projecting to a child.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + available = self._available_errors(child_errs, child_blocked_errs) + if not self._has_cover_for_all_active_detectors(child_dets, available): + return INF, None, "projected" + + y_projected = np.zeros(self.num_detectors, dtype=np.float64) + keep = parent_state.dets & child_dets + y_projected[keep] = parent_state.y_prices[keep] + projected_value = float(y_projected[np.flatnonzero(child_dets)].sum()) + best_value = projected_value + best_y = y_projected + best_source = "projected" + + if self.projection_combine_max_plain: + plain_value, plain_y = self.plain_detcost_from_counts(child_dets, available, child_det_counts) + if plain_y is None: + return INF, None, "plain" + if plain_value > best_value + HEURISTIC_EPS: + best_value = plain_value + best_y = plain_y + best_source = "plain" + + return best_value, best_y, best_source + + def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, available) + names = ["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"] + out: List[Tuple[str, float]] = [] + saved_lp_calls = self.lp_calls + for name in names: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + self.lp_calls = saved_lp_calls + return out + + def _maybe_refine_node(self, 
state: SearchState) -> Tuple[SearchState, bool]: + if state.refined or self.heuristic_name == "plain" or not self.lazy_reinsert_heuristics: + return state, False + + previous_source = state.h_source + self.refinement_calls += 1 + new_value, new_y = self.compute_support_based_heuristic( + state.dets, + state.errs, + state.blocked_errs, + name=self.heuristic_name, + ) + if new_y is None: + if previous_source == "projected": + self.projected_nodes_refined += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost:.6f} new_h=INF delta=INF reinserted=False discarded=True" + ) + state.refined = True + return state, True + + delta = new_value - state.h_cost + self.total_refinement_gain += max(0.0, delta) + self.max_refinement_gain = max(self.max_refinement_gain, max(0.0, delta)) + + if self.heuristic_name == "exact_lp" and new_value + 1e-7 < state.h_cost: + raise AssertionError( + f"Exact LP value {new_value} below stored projected value {state.h_cost}." + ) + + if new_value > state.h_cost + HEURISTIC_EPS: + if previous_source == "projected": + self.projected_nodes_refined += 1 + state.h_cost = new_value + state.h_source = "refined" + state.y_prices = new_y + state.refined = True + self.reinserts += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost - delta:.6f} new_h={new_value:.6f} delta={delta:.6f} reinserted=True discarded=False" + ) + return state, True + + # Non-improving greedy recomputation: keep the existing projected feasible point. 
+ if previous_source == "projected": + self.projected_nodes_refined += 1 + if abs(new_value - state.h_cost) <= HEURISTIC_EPS: + state.y_prices = new_y + state.refined = True + if self.verbose_search: + new_text = "INF" if math.isinf(new_value) else f"{new_value:.6f}" + print( + f" refine approx_h={state.h_cost:.6f} new_h={new_text} delta={delta:.6f} reinserted=False discarded=False" + ) + return state, False + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_h, root_y = self.plain_detcost_from_counts(dets0, self._available_errors(errs0, blocked0), det_counts0) + if root_y is None or math.isinf(root_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + root_refined = (self.heuristic_name == "plain") or (not self.lazy_reinsert_heuristics) + if root_refined and self.heuristic_name != "plain": + # Eager mode: use the selected heuristic immediately. 
+ eager_h, eager_y = self.compute_support_based_heuristic(dets0, errs0, blocked0, name=self.heuristic_name) + if eager_y is None or math.isinf(eager_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + root_h, root_y = eager_h, eager_y + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=root_h, + h_source="plain" if not root_refined else ("plain" if self.heuristic_name == "plain" else "refined"), + refined=root_refined, + y_prices=root_y, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + heapq.heappush(heap, (root_state.g_cost + root_state.h_cost, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(dets0.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + 
f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} reinserts={self.reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} refined={state.refined}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + state, should_reinsert = self._maybe_refine_node(state) + if should_reinsert: + if state.y_prices is None or math.isinf(state.h_cost): + if state.h_source == "projected": + self.projected_nodes_refined += 1 + continue + if state.h_source != "plain": + heapq.heappush(heap, (state.g_cost + state.h_cost, num_dets, counter, state)) + counter += 1 + continue + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + 
child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + for d in self.error_detectors[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.weights[ei]) + if self.heuristic_name == "plain" or (not self.lazy_reinsert_heuristics): + child_h, child_y = self.compute_support_based_heuristic( + child_dets, child_errs, child_blocked, name=self.heuristic_name + ) + child_source = "plain" if self.heuristic_name == "plain" else "refined" + child_refined = True + else: + if state.y_prices is None: + raise AssertionError("Expected parent feasible y-prices before projecting to child.") + child_h, child_y, child_source = self.project_child_y( + state, + child_dets, + child_errs, + child_blocked, + child_det_counts, + self.error_detectors[ei], + ) + self.projected_nodes_generated += 1 + children_projected += 1 + child_refined = False + + if child_y is None or math.isinf(child_h): + children_infeasible += 1 + continue + + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_source, + refined=child_refined, + y_prices=child_y, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} 
infeasible={children_infeasible} " + f"lp_calls={self.lp_calls} proj_unrefined_so_far={projected_unrefined}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + 
dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp"], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--lazy-reinsert-heuristics", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For non-plain heuristics, seed nodes with plain detcost, refine on pop, and reinsert when the selected " + "heuristic improves the key (default: enabled)." 
+ ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="When projecting parent y-prices to a child, take max(projected, plain detcost) (default: enabled).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = 
merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + lazy_reinsert_heuristics=args.lazy_reinsert_heuristics, + projection_combine_max_plain=args.projection_combine_max_plain, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"mode = {decoder.mode_name}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + print(f"lazy_reinsert_heuristics = {args.lazy_reinsert_heuristics}") + print(f"projection_combine_max_plain = {args.projection_combine_max_plain}") + + if args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs = np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + 
ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>12s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"max_queue_size = {result.max_queue_size}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"plain_heuristic_calls = {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.projection_heuristic_calls}") + print(f"refinement_calls = {result.refinement_calls}") + print(f"lp_calls = {result.lp_calls}") + print(f"reinserts = {result.reinserts}") + print(f"projected_nodes_generated = {result.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.projected_nodes_unrefined_at_finish}") + print(f"total_refinement_gain = {result.total_refinement_gain:.6f}") + print(f"max_refinement_gain = {result.max_refinement_gain:.6f}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + print("decoded_error_indices =", " ".join(map(str, np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = 
{decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_singleton_restricted_lazy.py b/src/py/astar/astar_prototype_singleton_restricted_lazy.py new file mode 100644 index 0000000..3545498 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_restricted_lazy.py @@ -0,0 +1,1675 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with lazy optimal-singleton refinement. + +This script is intentionally packaged similarly to astar_prototype_subset_detcost_lazy.py, +but specialized to the singleton LP. It offers three modes: + + --opt-singleton-detcost-mode plain + Use plain detcost only. + + --opt-singleton-detcost-mode full + Lazy exact singleton LP on pop, with projected child lower bounds. + + --opt-singleton-detcost-mode restricted + Lazy exact singleton LP on pop, but solved by a restricted-master / + separation loop seeded from the parent tight set. + +Two "outside the box" ideas are built in: + + 1) Parent-primal projection. + If y_parent is feasible for the parent singleton LP, then setting the child + detector prices to y_parent on detectors that remain active and 0 on newly + active detectors is automatically feasible for the child singleton LP. + That gives a cheap admissible child lower bound. + + 2) Local residual projection LP. + On top of the projected parent prices, we can re-optimize a tiny local LP + on either the newly active detectors or the neighborhood touched by the + changed detector set, while keeping the outside detector prices fixed. + This is still admissible because it is a feasible child primal solution. 
+""" + +from __future__ import annotations + +import argparse +import heapq +import json +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[List[int]] + error_costs: np.ndarray + error_detectors: List[Tuple[int, ...]] + error_observables: List[Tuple[int, ...]] + + +@dataclass +class SearchState: + activated_errors: Tuple[int, ...] + blocked_errors: np.ndarray + active_detectors: np.ndarray + active_detector_counts: np.ndarray + path_cost: float + heuristic_cost: float + heuristic_source: str + exact_refined: bool + lp_solution: Optional["SingletonLPSolution"] = None + warm_start_solution: Optional["SingletonLPSolution"] = None + changed_detectors_from_parent: Tuple[int, ...] 
= () + + +@dataclass +class DecodeStats: + num_pq_pushed: int + num_nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + exact_refinement_calls: int + lp_calls: int + lp_reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_lp_refinement_gain: float + max_lp_refinement_gain: float + lp_total_seconds: float + projection_local_lp_calls: int + projection_local_lp_seconds: float + restricted_total_rounds: int + restricted_total_added_supports: int + restricted_total_fallbacks: int + full_check_calls: int + full_check_max_abs_delta: float + elapsed_seconds: float + heuristic_name: str + + +@dataclass +class DecodeResult: + activated_errors: Tuple[int, ...] + path_cost: float + stats: DecodeStats + + +@dataclass(frozen=True) +class RestrictedMasterConfig: + add_policy: str = "topk" # one | topk | all + add_top_k: int = 3 + violation_tol: float = 1e-9 + tight_tol: float = 1e-8 + prune_slack: bool = True + prune_tol: float = 1e-8 + seed_normalized_global_k: int = 0 + seed_normalized_touching_changed_k: int = 2 + max_rounds: int = 50 + fallback_full: bool = True + full_check_every: int = 0 + + +@dataclass +class SingletonLPSolution: + value: float + active_detectors: Tuple[int, ...] + y_by_detector: Dict[int, float] + tight_supports: Tuple[Tuple[int, ...], ...] 
+ num_components: int + num_variables: int + num_constraints: int + num_selected_constraints: int + num_rounds: int + solve_mode: str + + +@dataclass +class SingletonLPSolverStats: + lp_calls: int = 0 + lp_total_seconds: float = 0.0 + projection_local_lp_calls: int = 0 + projection_local_lp_seconds: float = 0.0 + restricted_total_rounds: int = 0 + restricted_total_added_supports: int = 0 + restricted_total_fallbacks: int = 0 + full_check_calls: int = 0 + full_check_max_abs_delta: float = 0.0 + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class LPLogger: + def __init__(self, path: Path, *, every: int = 1, top_k: int = 12) -> None: + self.path = path + self.every = max(1, every) + self.top_k = max(1, top_k) + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_text("") + + def maybe_log(self, *, call_index: int, payload: Dict[str, Any]) -> None: + if call_index % self.every != 0: + return + with self.path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, sort_keys=True) + "\n") + + +class SingletonLPHeuristic: + def __init__( + self, + data: DecoderData, + *, + exact_mode: str, + projection_mode: str, + projection_combine_max_plain: bool, + restricted_config: RestrictedMasterConfig, + logger: Optional[LPLogger] = None, + ) -> None: + if exact_mode not in {"full", "restricted"}: + raise ValueError(f"Unsupported exact_mode: {exact_mode}") + if projection_mode not in {"plain", "parent_y", "new_only", "changed_neighborhood"}: + raise 
ValueError(f"Unsupported projection_mode: {projection_mode}") + self.data = data + self.exact_mode = exact_mode + self.projection_mode = projection_mode + self.projection_combine_max_plain = projection_combine_max_plain + self.restricted_config = restricted_config + self.logger = logger + self.stats = SingletonLPSolverStats() + self.exact_solve_calls = 0 + + def reset_stats(self) -> None: + self.stats = SingletonLPSolverStats() + self.exact_solve_calls = 0 + + def solve_exact( + self, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + *, + warm_start_solution: Optional[SingletonLPSolution], + changed_detectors: Tuple[int, ...], + ) -> Tuple[SingletonLPSolution, Dict[str, Any]]: + self.exact_solve_calls += 1 + t0 = time.perf_counter() + active_detector_ids = tuple(int(d) for d in np.flatnonzero(active_detectors)) + support_costs = build_active_support_costs( + data=self.data, + active_detectors=active_detectors, + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + + if not active_detector_ids: + elapsed = time.perf_counter() - t0 + payload = { + "solve_mode": self.exact_mode, + "objective": 0.0, + "num_components": 0, + "num_variables": 0, + "num_constraints": 0, + "num_selected_constraints": 0, + "num_rounds": 0, + "num_supports_total": 0, + "solve_seconds": elapsed, + "structurally_infeasible": False, + } + return ( + SingletonLPSolution( + value=0.0, + active_detectors=(), + y_by_detector={}, + tight_supports=(), + num_components=0, + num_variables=0, + num_constraints=0, + num_selected_constraints=0, + num_rounds=0, + solve_mode=self.exact_mode, + ), + payload, + ) + + missing_cover = [ + detector + for detector in active_detector_ids + if not any(detector in support for support in support_costs) + ] + if missing_cover: + elapsed = time.perf_counter() - t0 + payload = { + "solve_mode": self.exact_mode, + "objective": INF, + "num_components": 0, + "num_variables": 
len(active_detector_ids), + "num_constraints": len(support_costs), + "num_selected_constraints": 0, + "num_rounds": 0, + "num_supports_total": len(support_costs), + "solve_seconds": elapsed, + "structurally_infeasible": True, + "missing_cover_detectors": missing_cover, + } + return ( + SingletonLPSolution( + value=INF, + active_detectors=active_detector_ids, + y_by_detector={}, + tight_supports=(), + num_components=0, + num_variables=len(active_detector_ids), + num_constraints=len(support_costs), + num_selected_constraints=0, + num_rounds=0, + solve_mode=self.exact_mode, + ), + payload, + ) + + if self.exact_mode == "full": + solution, full_payload = self._solve_full_support_lp( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + solve_mode="full", + ) + elapsed = time.perf_counter() - t0 + payload = dict(full_payload) + payload.update( + { + "solve_mode": "full", + "num_supports_total": len(support_costs), + "solve_seconds": elapsed, + "structurally_infeasible": False, + } + ) + return solution, payload + + solution, payload = self._solve_restricted_exact( + active_detector_ids=active_detector_ids, + support_costs=support_costs, + warm_start_solution=warm_start_solution, + changed_detectors=changed_detectors, + ) + payload.update( + { + "solve_mode": "restricted", + "num_supports_total": len(support_costs), + "structurally_infeasible": False, + } + ) + return solution, payload + + def project_to_child( + self, + parent_solution: SingletonLPSolution, + child_active_detectors: np.ndarray, + child_blocked_errors: np.ndarray, + child_active_detector_counts: np.ndarray, + *, + changed_detectors: Tuple[int, ...], + ) -> float: + if self.projection_mode == "plain": + projected = plain_detcost_heuristic( + data=self.data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked_errors, + active_detector_counts=child_active_detector_counts, + ) + return projected + + parent_active_set = set(parent_solution.active_detectors) + 
child_active_ids = tuple(int(d) for d in np.flatnonzero(child_active_detectors)) + parent_y = parent_solution.y_by_detector + + # Fixed outside prices inherited from the parent exact primal solution. + fixed_outside_y: Dict[int, float] = {} + region_detectors: set[int] = set() + if self.projection_mode == "parent_y": + region_detectors = set() + elif self.projection_mode == "new_only": + region_detectors = {d for d in child_active_ids if d not in parent_active_set} + elif self.projection_mode == "changed_neighborhood": + changed_set = set(changed_detectors) + region_detectors = {d for d in child_active_ids if d not in parent_active_set} + for detector in changed_set: + for error_index in self.data.detector_to_errors[detector]: + if child_blocked_errors[error_index]: + continue + if child_active_detector_counts[error_index] <= 0: + continue + for other_detector in self.data.error_detectors[error_index]: + if child_active_detectors[other_detector]: + region_detectors.add(other_detector) + else: + raise AssertionError("unreachable projection mode") + + for detector in child_active_ids: + if detector in region_detectors: + continue + if detector in parent_active_set: + fixed_outside_y[detector] = parent_y.get(detector, 0.0) + else: + fixed_outside_y[detector] = 0.0 + + projected = sum(fixed_outside_y.values()) + + if region_detectors: + local_gain = self._solve_local_region_projection_lp( + child_active_detectors=child_active_detectors, + child_blocked_errors=child_blocked_errors, + child_active_detector_counts=child_active_detector_counts, + region_detectors=region_detectors, + fixed_outside_y=fixed_outside_y, + ) + if local_gain == INF: + projected = INF + else: + projected += local_gain + + if self.projection_combine_max_plain: + plain = plain_detcost_heuristic( + data=self.data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked_errors, + active_detector_counts=child_active_detector_counts, + ) + projected = max(projected, plain) + return 
projected + + def _solve_local_region_projection_lp( + self, + *, + child_active_detectors: np.ndarray, + child_blocked_errors: np.ndarray, + child_active_detector_counts: np.ndarray, + region_detectors: set[int], + fixed_outside_y: Dict[int, float], + ) -> float: + if not region_detectors: + return 0.0 + t0 = time.perf_counter() + region_support_costs: Dict[Tuple[int, ...], float] = {} + for error_index, error_detectors in enumerate(self.data.error_detectors): + if child_blocked_errors[error_index]: + continue + count = int(child_active_detector_counts[error_index]) + if count <= 0: + continue + full_support = tuple(d for d in error_detectors if child_active_detectors[d]) + assert len(full_support) == count + local_support = tuple(d for d in full_support if d in region_detectors) + if not local_support: + continue + fixed = sum(fixed_outside_y.get(d, 0.0) for d in full_support if d not in region_detectors) + residual = float(self.data.error_costs[error_index]) - fixed + if residual < -1e-8: + raise AssertionError( + f"Projected parent y is infeasible for child: residual={residual} on error {error_index}." + ) + residual = max(0.0, residual) + previous = region_support_costs.get(local_support) + if previous is None or residual < previous: + region_support_costs[local_support] = residual + + region_detector_ids = tuple(sorted(region_detectors)) + if any(not any(detector in support for support in region_support_costs) for detector in region_detector_ids): + # No admissible gain on uncovered region detectors; keep them at zero. 
            # --- continuation of a projection-local LP helper whose `def` is
            # above this chunk.  Early-exit branch: the local region needs no
            # LP, so record timing/count stats and contribute a zero bound.
            elapsed = time.perf_counter() - t0
            self.stats.projection_local_lp_calls += 1
            self.stats.projection_local_lp_seconds += elapsed
            return 0.0

        # Solve the small LP restricted to the local region around the changed
        # detectors; only the objective value feeds the projected child bound.
        objective, _, _, _, _ = solve_primal_lp_on_supports(
            detector_ids=region_detector_ids,
            support_costs=region_support_costs,
            record_stats=self.stats,
            count_as_main_lp=False,
        )
        elapsed = time.perf_counter() - t0
        self.stats.projection_local_lp_calls += 1
        self.stats.projection_local_lp_seconds += elapsed
        return objective

    def _solve_full_support_lp(
        self,
        *,
        active_detector_ids: Tuple[int, ...],
        support_costs: Dict[Tuple[int, ...], float],
        solve_mode: str,
    ) -> Tuple[SingletonLPSolution, Dict[str, Any]]:
        """Solve the exact singleton LP with every support constraint present.

        The active detectors are split into connected components (two
        detectors are connected when some support touches both) so each LP
        stays small; component objectives, dual prices, and tight supports
        are then summed/merged into one aggregate solution.

        Returns:
            (solution, payload) where ``payload`` is a JSON-friendly summary
            intended for the optional LP logger.
        """
        components = split_support_costs_into_components(
            active_detector_ids=active_detector_ids,
            support_costs=support_costs,
        )
        total_value = 0.0
        total_num_variables = 0
        total_num_constraints = 0
        y_by_detector: Dict[int, float] = {}
        tight_supports: List[Tuple[int, ...]] = []
        for detector_ids, component_support_costs in components:
            value, component_y, component_tight, num_vars, num_constraints = solve_primal_lp_on_supports(
                detector_ids=detector_ids,
                support_costs=component_support_costs,
                record_stats=self.stats,
                count_as_main_lp=True,
            )
            total_value += value
            total_num_variables += num_vars
            total_num_constraints += num_constraints
            y_by_detector.update(component_y)
            tight_supports.extend(component_tight)

        solution = SingletonLPSolution(
            value=total_value,
            active_detectors=active_detector_ids,
            y_by_detector=y_by_detector,
            tight_supports=tuple(sorted(set(tight_supports))),
            num_components=len(components),
            num_variables=total_num_variables,
            num_constraints=total_num_constraints,
            # The full solve includes every constraint, so selected == total.
            num_selected_constraints=total_num_constraints,
            num_rounds=1,
            solve_mode=solve_mode,
        )
        payload = {
            "objective": total_value,
            "num_components": len(components),
            "num_variables": total_num_variables,
            "num_constraints": total_num_constraints,
            "num_selected_constraints": total_num_constraints,
            "num_rounds": 1,
            "tight_support_count": len(solution.tight_supports),
            # Cap the logged supports at the logger's top_k (12 if no logger).
            "top_tight_supports": [
                {"support": list(support), "cost": float(support_costs[support])}
                for support in sorted(solution.tight_supports, key=lambda s: (len(s), s))[: self.logger.top_k if self.logger else 12]
            ],
        }
        return solution, payload

    def _solve_restricted_exact(
        self,
        *,
        active_detector_ids: Tuple[int, ...],
        support_costs: Dict[Tuple[int, ...], float],
        warm_start_solution: Optional[SingletonLPSolution],
        changed_detectors: Tuple[int, ...],
    ) -> Tuple[SingletonLPSolution, Dict[str, Any]]:
        """Solve the singleton LP via a restricted-master / row-generation loop.

        Each connected component is solved by ``_solve_restricted_component``,
        seeded with the parent solution's surviving tight supports.  When
        ``full_check_every`` is positive, every k-th exact solve is
        cross-checked against the full LP and a mismatch beyond 1e-7 raises.

        Returns:
            (solution, payload) in the same shape as ``_solve_full_support_lp``.
        """
        t0 = time.perf_counter()
        components = split_support_costs_into_components(
            active_detector_ids=active_detector_ids,
            support_costs=support_costs,
        )
        total_value = 0.0
        total_num_variables = 0
        total_num_constraints = 0
        total_num_selected_constraints = 0
        total_rounds = 0
        y_by_detector: Dict[int, float] = {}
        tight_supports: List[Tuple[int, ...]] = []
        component_payloads: List[Dict[str, Any]] = []
        fallbacks_used = 0

        # Parent tight supports seed each component's restricted pool; the set
        # of detectors toggled since the parent guides extra seeding below.
        parent_tight_supports = set() if warm_start_solution is None else set(warm_start_solution.tight_supports)
        changed_set = set(changed_detectors)

        for detector_ids, component_support_costs in components:
            component_result, component_payload = self._solve_restricted_component(
                detector_ids=detector_ids,
                support_costs=component_support_costs,
                parent_tight_supports=parent_tight_supports,
                changed_set=changed_set,
            )
            total_value += component_result["value"]
            total_num_variables += len(detector_ids)
            total_num_constraints += len(component_support_costs)
            total_num_selected_constraints += component_result["num_selected_constraints"]
            total_rounds += component_result["num_rounds"]
            y_by_detector.update(component_result["y_by_detector"])
            tight_supports.extend(component_result["tight_supports"])
            component_payloads.append(component_payload)
            if component_result["used_full_fallback"]:
                fallbacks_used += 1

        self.stats.restricted_total_rounds += total_rounds
        self.stats.restricted_total_fallbacks += fallbacks_used

        solution = SingletonLPSolution(
            value=total_value,
            active_detectors=active_detector_ids,
            y_by_detector=y_by_detector,
            tight_supports=tuple(sorted(set(tight_supports))),
            num_components=len(components),
            num_variables=total_num_variables,
            num_constraints=total_num_constraints,
            num_selected_constraints=total_num_selected_constraints,
            num_rounds=total_rounds,
            solve_mode="restricted",
        )

        # Optional correctness audit: periodically re-solve with the full LP
        # and require agreement with the restricted answer.
        # NOTE(review): assumes `self.exact_solve_calls` is maintained by the
        # caller (not incremented here) -- confirm against solve_exact().
        if self.restricted_config.full_check_every > 0 and self.exact_solve_calls % self.restricted_config.full_check_every == 0:
            self.stats.full_check_calls += 1
            full_solution, _ = self._solve_full_support_lp(
                active_detector_ids=active_detector_ids,
                support_costs=support_costs,
                solve_mode="full_check",
            )
            delta = abs(full_solution.value - solution.value)
            self.stats.full_check_max_abs_delta = max(self.stats.full_check_max_abs_delta, delta)
            if delta > 1e-7:
                raise AssertionError(
                    f"Restricted exact solver mismatch: restricted={solution.value} full={full_solution.value} delta={delta}"
                )

        payload = {
            "objective": total_value,
            "num_components": len(components),
            "num_variables": total_num_variables,
            "num_constraints": total_num_constraints,
            "num_selected_constraints": total_num_selected_constraints,
            "num_rounds": total_rounds,
            "tight_support_count": len(solution.tight_supports),
            "used_full_fallbacks": fallbacks_used,
            "components": component_payloads,
            "solve_seconds": time.perf_counter() - t0,
        }
        return solution, payload

    def _solve_restricted_component(
        self,
        *,
        detector_ids: Tuple[int, ...],
        support_costs: Dict[Tuple[int, ...], float],
        parent_tight_supports: set[Tuple[int, ...]],
        changed_set: set[int],
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Row-generation solve of one connected component's singleton LP.

        Seeds a restricted constraint pool (cheapest cover per detector,
        surviving parent tight supports, optional normalized-cost seeds),
        solves the restricted LP, then separates violated supports against
        the resulting dual prices until none remain or the round limit hits
        (with an optional fallback to the full LP).

        Returns:
            (component_result, component_payload): the numeric result dict
            consumed by ``_solve_restricted_exact`` and a logging payload.
        """
        # Greedy cover: remember each detector's cheapest support (ties broken
        # by size, then lexicographically) so the restricted master always
        # covers every detector -- otherwise the dual LP would be unbounded.
        cover_support_for_detector: Dict[int, Tuple[int, ...]] = {}
        supports_touching_changed: List[Tuple[int, ...]] = []
        for support, cost in support_costs.items():
            for detector in support:
                previous = cover_support_for_detector.get(detector)
                if previous is None:
                    cover_support_for_detector[detector] = support
                else:
                    prev_key = (support_costs[previous], len(previous), previous)
                    cur_key = (cost, len(support), support)
                    if cur_key < prev_key:
                        cover_support_for_detector[detector] = support
            if changed_set and any(detector in changed_set for detector in support):
                supports_touching_changed.append(support)

        selected_supports: set[Tuple[int, ...]] = set(cover_support_for_detector.values())
        cover_supports = set(selected_supports)
        # Parent tight supports that still exist after the move remain seeded.
        surviving_parent_tight = {support for support in parent_tight_supports if support in support_costs}
        selected_supports |= surviving_parent_tight

        # Optional seeding by normalized cost (cost per covered detector).
        if self.restricted_config.seed_normalized_global_k > 0:
            cheapest_norm = sorted(
                support_costs,
                key=lambda support: (support_costs[support] / len(support), support_costs[support], len(support), support),
            )[: self.restricted_config.seed_normalized_global_k]
            selected_supports.update(cheapest_norm)

        if self.restricted_config.seed_normalized_touching_changed_k > 0 and supports_touching_changed:
            touching = sorted(
                supports_touching_changed,
                key=lambda support: (support_costs[support] / len(support), support_costs[support], len(support), support),
            )[: self.restricted_config.seed_normalized_touching_changed_k]
            selected_supports.update(touching)

        rounds = 0
        total_added_supports = 0
        used_full_fallback = False
        payload_rounds: List[Dict[str, Any]] = []

        while True:
            rounds += 1
            # Re-assert coverage each round (pruning below keeps cover
            # supports anyway, but this is the invariant that matters).
            selected_supports |= cover_supports
            restricted_support_costs = {support: support_costs[support] for support in selected_supports}
            value, y_by_detector, selected_tight_supports, num_vars, num_selected_constraints = solve_primal_lp_on_supports(
                detector_ids=detector_ids,
                support_costs=restricted_support_costs,
                record_stats=self.stats,
                count_as_main_lp=True,
            )
            # Separation: price every support against the restricted solution.
            # Negative slack means a violated (missing) constraint; near-zero
            # slack marks a tight support in the full problem.
            slacks: Dict[Tuple[int, ...], float] = {}
            violations: List[Tuple[float, Tuple[int, ...]]] = []
            full_tight_supports: List[Tuple[int, ...]] = []
            for support, cost in support_costs.items():
                lhs = sum(y_by_detector.get(detector, 0.0) for detector in support)
                slack = cost - lhs
                slacks[support] = slack
                if slack < -self.restricted_config.violation_tol:
                    violations.append((-slack, support))
                if abs(slack) <= self.restricted_config.tight_tol:
                    full_tight_supports.append(support)

            payload_rounds.append(
                {
                    "round": rounds,
                    "selected_constraints": len(selected_supports),
                    "restricted_tight_count": len(selected_tight_supports),
                    "full_tight_count": len(full_tight_supports),
                    "max_violation": 0.0 if not violations else float(max(v for v, _ in violations)),
                }
            )

            # No violated supports: the restricted optimum equals the full
            # optimum and we are done.
            if not violations:
                self.stats.restricted_total_added_supports += total_added_supports
                component_result = {
                    "value": value,
                    "y_by_detector": y_by_detector,
                    "tight_supports": tuple(sorted(full_tight_supports)),
                    "num_selected_constraints": len(selected_supports),
                    "num_rounds": rounds,
                    "used_full_fallback": False,
                }
                component_payload = {
                    "detectors": list(detector_ids),
                    "supports_total": len(support_costs),
                    "initial_seed_count": len(cover_supports | surviving_parent_tight),
                    "final_selected_constraints": len(selected_supports),
                    "rounds": rounds,
                    "used_full_fallback": False,
                    "parent_tight_survivors": len(surviving_parent_tight),
                    "cover_supports": len(cover_supports),
                    "round_summaries": payload_rounds,
                }
                return component_result, component_payload

            # Round budget exhausted: either fail loudly or fall back to one
            # full LP solve over every support.
            if rounds >= self.restricted_config.max_rounds:
                if not self.restricted_config.fallback_full:
                    raise RuntimeError(
                        f"Restricted singleton LP exceeded max rounds={self.restricted_config.max_rounds} without fallback."
                    )
                used_full_fallback = True
                self.stats.restricted_total_added_supports += total_added_supports
                full_value, full_y, full_tight, _, _ = solve_primal_lp_on_supports(
                    detector_ids=detector_ids,
                    support_costs=support_costs,
                    record_stats=self.stats,
                    count_as_main_lp=True,
                )
                component_result = {
                    "value": full_value,
                    "y_by_detector": full_y,
                    "tight_supports": tuple(sorted(full_tight)),
                    "num_selected_constraints": len(support_costs),
                    "num_rounds": rounds,
                    "used_full_fallback": True,
                }
                component_payload = {
                    "detectors": list(detector_ids),
                    "supports_total": len(support_costs),
                    "initial_seed_count": len(cover_supports | surviving_parent_tight),
                    "final_selected_constraints": len(support_costs),
                    "rounds": rounds,
                    "used_full_fallback": True,
                    "parent_tight_survivors": len(surviving_parent_tight),
                    "cover_supports": len(cover_supports),
                    "round_summaries": payload_rounds,
                }
                return component_result, component_payload

            # Drop slack constraints between rounds (cover supports are kept
            # unconditionally to preserve feasibility of the dual).
            if self.restricted_config.prune_slack:
                selected_supports = {
                    support
                    for support in selected_supports
                    if slacks.get(support, INF) <= self.restricted_config.prune_tol or support in cover_supports
                }

            # Add the most violated supports per the configured policy.
            violations.sort(key=lambda item: (-item[0], support_costs[item[1]], len(item[1]), item[1]))
            if self.restricted_config.add_policy == "one":
                to_add = [violations[0][1]]
            elif self.restricted_config.add_policy == "topk":
                to_add = [support for _, support in violations[: self.restricted_config.add_top_k]]
            elif self.restricted_config.add_policy == "all":
                to_add = [support for _, support in violations]
            else:
                raise ValueError(f"Unsupported add policy: {self.restricted_config.add_policy}")
            new_supports = [support for support in to_add if support not in selected_supports]
            total_added_supports += len(new_supports)
            selected_supports.update(new_supports)


def xor_probability(p0: float, p1: float) -> float:
    """Return the probability that exactly one of two independent events fires."""
    return p0 * (1 - p1) + (1 - p0) * p1


def parse_beam(text: str) -> float:
    """argparse type for --det-beam: a non-negative integer or 'inf'."""
    lowered = text.strip().lower()
    if lowered in {"inf", "+inf", "infinity", "+infinity"}:
        return INF
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'")
    return float(value)
def parse_optional_int(text: str) -> Optional[int]:
    """argparse type: a non-negative int, or None for 'none'/'inf'/'infinity'."""
    lowered = text.strip().lower()
    if lowered in {"none", "inf", "infinity"}:
        return None
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError("value must be non-negative or 'none'")
    return value


def format_indices(indices: Iterable[int], prefix: str) -> str:
    """Render indices as e.g. 'D3 D7 D9', or '(none)' when empty."""
    items = list(indices)
    if not items:
        return "(none)"
    return " ".join(f"{prefix}{i}" for i in items)


def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]:
    """Yield every positive-probability error of the flattened DEM.

    Repeated detector/observable targets inside a single instruction cancel
    in pairs (XOR semantics).  Probabilities must lie in (0, 0.5) so the
    likelihood cost -log(p / (1 - p)) is strictly positive.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError("This prototype assumes DEM probabilities are in (0, 0.5).")
        detectors: set[int] = set()
        observables: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # Toggle membership: a repeated target cancels itself.
                if target.val in observables:
                    observables.remove(target.val)
                else:
                    observables.add(target.val)
            else:
                if not target.is_relative_detector_id():
                    raise ValueError(f"Unexpected DEM target: {target!r}")
                if target.val in detectors:
                    detectors.remove(target.val)
                else:
                    detectors.add(target.val)
        yield MergedError(
            probability=probability,
            likelihood_cost=float(-math.log(probability / (1 - probability))),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]:
    """Merge DEM errors that share the same (detectors, observables) symptom.

    Indistinguishable errors are combined with XOR probability; merged
    probabilities must stay below 0.5 to keep costs positive.
    """
    errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors(dem):
        key = (error.detectors, error.observables)
        previous = errors_by_symptom.get(key)
        if previous is None:
            errors_by_symptom[key] = error.probability
        else:
            errors_by_symptom[key] = xor_probability(previous, error.probability)

    merged: List[MergedError] = []
    for (detectors, observables), probability in errors_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError("Merged error has probability >= 0.5, giving a non-positive cost.")
        merged.append(
            MergedError(
                probability=probability,
                likelihood_cost=float(-math.log(probability / (1 - probability))),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged


def build_decoder_data(
    dem: stim.DetectorErrorModel,
    *,
    merge_errors_in_dem: bool = True,
) -> DecoderData:
    """Precompute the static error/detector tables the decoder needs.

    Builds the detector -> error-index adjacency plus dense arrays of error
    costs, detector tuples, and observable tuples.
    """
    errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem))
    detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)]
    for error_index, error in enumerate(errors):
        for detector in error.detectors:
            detector_to_errors[detector].append(error_index)
    return DecoderData(
        num_detectors=dem.num_detectors,
        num_observables=dem.num_observables,
        errors=errors,
        detector_to_errors=detector_to_errors,
        error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64),
        error_detectors=[e.detectors for e in errors],
        error_observables=[e.observables for e in errors],
    )


def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray:
    """Unpack little-endian bit-packed rows into a boolean array of width `count`."""
    return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False)


def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray:
    """For each error, count how many of its detectors are currently active."""
    counts = np.zeros(len(data.errors), dtype=np.int32)
    for detector in np.flatnonzero(active_detectors):
        for error_index in data.detector_to_errors[int(detector)]:
            counts[error_index] += 1
    return counts


def apply_error(
    data: DecoderData,
    active_detectors: np.ndarray,
    active_detector_counts: np.ndarray,
    error_index: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (detectors, counts) after XOR-applying one error.

    Inputs are copied, not mutated: toggling each of the error's detectors
    also adjusts the active-detector count of every error touching it.
    """
    next_detectors = active_detectors.copy()
    next_counts = active_detector_counts.copy()
    for detector in data.error_detectors[error_index]:
        if next_detectors[detector]:
            next_detectors[detector] = False
            delta = -1
        else:
            next_detectors[detector] = True
            delta = 1
        for other_error_index in data.detector_to_errors[detector]:
            next_counts[other_error_index] += delta
    return next_detectors, next_counts


def plain_detcost_for_detector(
    data: DecoderData,
    detector: int,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> float:
    """Cheapest cost-per-active-detector over unblocked errors touching `detector`.

    Returns INF when every error touching the detector is blocked.
    """
    best = INF
    for error_index in data.detector_to_errors[detector]:
        if blocked_errors[error_index]:
            continue
        count = int(active_detector_counts[error_index])
        # Any error adjacent to an active detector must itself touch at least
        # one active detector.
        assert count > 0
        candidate = float(data.error_costs[error_index]) / count
        if candidate < best:
            best = candidate
    return best


def plain_detcost_heuristic(
    data: DecoderData,
    active_detectors: np.ndarray,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> float:
    """Sum of per-detector detcosts; INF when some detector is uncoverable."""
    total = 0.0
    for detector in np.flatnonzero(active_detectors):
        det_cost = plain_detcost_for_detector(
            data=data,
            detector=int(detector),
            blocked_errors=blocked_errors,
            active_detector_counts=active_detector_counts,
        )
        if det_cost == INF:
            return INF
        total += det_cost
    return total


def build_active_support_costs(
    data: DecoderData,
    active_detectors: np.ndarray,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> Dict[Tuple[int, ...], float]:
    """Map each active-detector support to the cheapest unblocked error covering it.

    A support is the tuple of an error's detectors restricted to the active
    set; errors touching no active detector are skipped.
    """
    support_costs: Dict[Tuple[int, ...], float] = {}
    for error_index, error_detectors in enumerate(data.error_detectors):
        if blocked_errors[error_index]:
            continue
        count = int(active_detector_counts[error_index])
        if count <= 0:
            continue
        support = tuple(detector for detector in error_detectors if active_detectors[detector])
        # Cross-check the incremental counts against a direct recomputation.
        assert len(support) == count
        cost = float(data.error_costs[error_index])
        previous = support_costs.get(support)
        if previous is None or cost < previous:
            support_costs[support] = cost
    return support_costs


def split_support_costs_into_components(
    *,
    active_detector_ids: Tuple[int, ...],
    support_costs: Dict[Tuple[int, ...], float],
) -> List[Tuple[Tuple[int, ...], Dict[Tuple[int, ...], float]]]:
    """Partition detectors and supports into connected components.

    Two detectors are connected when some support contains both.  Returns a
    deterministic, sorted list of (detector_ids, support_costs) pairs so each
    component can be solved as an independent LP.
    """
    detector_to_local = {detector: i for i, detector in enumerate(active_detector_ids)}
    uf = UnionFind(len(active_detector_ids))
    for support in support_costs:
        if len(support) <= 1:
            continue
        first = detector_to_local[support[0]]
        for detector in support[1:]:
            uf.union(first, detector_to_local[detector])

    detectors_by_root: Dict[int, List[int]] = defaultdict(list)
    for detector in active_detector_ids:
        detectors_by_root[uf.find(detector_to_local[detector])].append(detector)
    supports_by_root: Dict[int, Dict[Tuple[int, ...], float]] = defaultdict(dict)
    for support, cost in support_costs.items():
        # Any detector of the support identifies its component; use the first.
        root = uf.find(detector_to_local[support[0]])
        supports_by_root[root][support] = cost
    components: List[Tuple[Tuple[int, ...], Dict[Tuple[int, ...], float]]] = []
    for root, detectors in detectors_by_root.items():
        components.append((tuple(sorted(detectors)), supports_by_root[root]))
    components.sort(key=lambda item: (len(item[0]), item[0]))
    return components


def solve_primal_lp_on_supports(
    *,
    detector_ids: Tuple[int, ...],
    support_costs: Dict[Tuple[int, ...], float],
    record_stats: SingletonLPSolverStats,
    count_as_main_lp: bool,
) -> Tuple[float, Dict[int, float], List[Tuple[int, ...]], int, int]:
    """Solve max sum(y) s.t. sum_{d in S} y_d <= cost(S) for every support S, y >= 0.

    This is the optimal-singleton lower bound LP.  Returns
    (objective, y_by_detector, tight_supports, num_variables, num_constraints)
    where y_by_detector only carries entries above 1e-12 and tight supports
    are those whose constraint holds with equality within 1e-8.
    """
    detector_to_var = {detector: i for i, detector in enumerate(detector_ids)}
    # A detector appearing in no support would make its y unbounded; treat
    # that as a caller bug (the restricted master must preserve coverage).
    if any(not any(detector in support for support in support_costs) for detector in detector_ids):
        raise RuntimeError("LP component has an uncovered detector; restricted master lost coverage.")

    row_indices: List[int] = []
    col_indices: List[int] = []
    values: List[float] = []
    rhs = np.empty(len(support_costs), dtype=np.float64)
    supports = sorted(support_costs, key=lambda s: (len(s), s))
    for row, support in enumerate(supports):
        rhs[row] = float(support_costs[support])
        for detector in support:
            row_indices.append(row)
            col_indices.append(detector_to_var[detector])
            values.append(1.0)

    # NOTE(review): relies on `sparse` being in scope (e.g. `from scipy
    # import sparse`); an earlier revision imported only `csr_matrix`
    # directly -- confirm the file-top imports.
    a_ub = sparse.csr_matrix(
        (values, (row_indices, col_indices)),
        shape=(len(supports), len(detector_ids)),
        dtype=np.float64,
    )
    record_stats.lp_calls += 1 if count_as_main_lp else 0
    t0 = time.perf_counter()
    # Maximization expressed as minimizing -1^T y with the HiGHS backend.
    result = linprog(
        c=-np.ones(len(detector_ids), dtype=np.float64),
        A_ub=a_ub,
        b_ub=rhs,
        bounds=[(0.0, None)] * len(detector_ids),
        method="highs",
    )
    elapsed = time.perf_counter() - t0
    if count_as_main_lp:
        record_stats.lp_total_seconds += elapsed
    if not result.success:
        raise RuntimeError(
            f"singleton LP solve failed: status={result.status} message={result.message}"
        )

    solution = np.asarray(result.x, dtype=np.float64)
    y_by_detector = {
        detector_ids[var_index]: float(solution[var_index])
        for var_index in range(len(detector_ids))
        if solution[var_index] > 1e-12
    }
    tight_supports: List[Tuple[int, ...]] = []
    for row, support in enumerate(supports):
        lhs = float(sum(solution[detector_to_var[detector]] for detector in support))
        if abs(float(rhs[row]) - lhs) <= 1e-8:
            tight_supports.append(support)
    # Negate: linprog minimized the negated objective.
    return float(-result.fun), y_by_detector, tight_supports, len(detector_ids), len(supports)


def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray:
    """XOR together the detector footprints of the chosen errors."""
    detectors = np.zeros(data.num_detectors, dtype=bool)
    for error_index in activated_errors:
        for detector in data.error_detectors[error_index]:
            detectors[detector] ^= True
    return detectors


def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray:
    """XOR together the observable flips of the chosen errors."""
    observables = np.zeros(data.num_observables, dtype=bool)
    for error_index in activated_errors:
        for observable in data.error_observables[error_index]:
            observables[observable] ^= True
    return observables
def decode(
    data: DecoderData,
    detections: np.ndarray,
    *,
    det_beam: float = INF,
    singleton_solver: Optional[SingletonLPHeuristic] = None,
    verbose_search: bool = False,
) -> DecodeResult:
    """Run the A* search over error subsets that clear the given syndrome.

    States order errors by the minimum active detector (precedence-based tree
    pruning): expanding a node tries each unblocked error touching that
    detector, blocking earlier siblings in the children so each error subset
    is generated at most once.  With ``singleton_solver`` set, children get a
    cheap projected lower bound first and are exactly refined lazily on pop.

    Args:
        data: Precomputed tables from ``build_decoder_data``.
        detections: Boolean (or 0/1) detector hit vector for one shot.
        det_beam: Beam on residual detector count relative to the best seen;
            INF disables the beam.
        singleton_solver: Optional optimal-singleton LP heuristic; None means
            the plain detcost heuristic.
        verbose_search: Print per-node diagnostics.

    Returns:
        DecodeResult with the activated error tuple, its path cost, and stats.

    Raises:
        RuntimeError: when the initial syndrome is infeasible or the search
            exhausts the queue (e.g. everything beam-pruned).
    """
    start_time = time.perf_counter()
    if singleton_solver is not None:
        singleton_solver.reset_stats()

    # Diagnostic counters reported in DecodeStats.
    heuristic_calls = 0
    plain_heuristic_calls = 0
    projection_heuristic_calls = 0
    exact_refinement_calls = 0
    lp_reinserts = 0
    projected_nodes_generated = 0
    projected_nodes_refined = 0
    total_lp_refinement_gain = 0.0
    max_lp_refinement_gain = 0.0

    initial_active_detectors = np.asarray(detections, dtype=bool).copy()
    initial_counts = initial_detector_counts(data, initial_active_detectors)
    initial_blocked = np.zeros(len(data.errors), dtype=bool)
    heuristic_calls += 1
    plain_heuristic_calls += 1
    initial_heuristic = plain_detcost_heuristic(
        data=data,
        active_detectors=initial_active_detectors,
        blocked_errors=initial_blocked,
        active_detector_counts=initial_counts,
    )
    if initial_heuristic == INF:
        raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.")

    initial_state = SearchState(
        activated_errors=(),
        blocked_errors=initial_blocked,
        active_detectors=initial_active_detectors,
        active_detector_counts=initial_counts,
        path_cost=0.0,
        heuristic_cost=initial_heuristic,
        heuristic_source="plain",
        # With no LP solver there is nothing to refine later.
        exact_refined=(singleton_solver is None),
        lp_solution=None,
        warm_start_solution=None,
        changed_detectors_from_parent=(),
    )

    # Queue entries: (f = g + h, residual detector count, push counter for
    # deterministic tie-breaking, state).
    priority_queue: List[Tuple[float, int, int, SearchState]] = []
    push_counter = 0
    initial_num_dets = int(initial_active_detectors.sum())
    heapq.heappush(
        priority_queue,
        (initial_state.path_cost + initial_state.heuristic_cost, initial_num_dets, push_counter, initial_state),
    )
    push_counter += 1

    num_pq_pushed = 1
    num_nodes_popped = 0
    max_queue_size = 1
    min_num_dets = initial_num_dets
    max_num_dets = INF if det_beam == INF else min_num_dets + det_beam

    if singleton_solver is None:
        heuristic_name = "plain_detcost"
    else:
        heuristic_name = f"opt_singleton_{singleton_solver.exact_mode}_lazy_{singleton_solver.projection_mode}"
        if singleton_solver.projection_combine_max_plain:
            heuristic_name += "_maxplain"

    while priority_queue:
        max_queue_size = max(max_queue_size, len(priority_queue))
        f_cost, num_dets, _, state = heapq.heappop(priority_queue)
        num_nodes_popped += 1

        # Beam: drop nodes whose residual detector count exceeds the window
        # around the best (smallest) count seen so far.
        if num_dets > max_num_dets:
            continue

        if num_dets < min_num_dets:
            min_num_dets = num_dets
            max_num_dets = INF if det_beam == INF else min_num_dets + det_beam

        if verbose_search:
            print(
                f"nodes_popped={num_nodes_popped} len(pq)={len(priority_queue)} "
                f"lp_calls={0 if singleton_solver is None else singleton_solver.stats.lp_calls} "
                f"lp_reinserts={lp_reinserts} proj_generated={projected_nodes_generated} "
                f"proj_refined={projected_nodes_refined} "
                f"proj_unrefined_so_far={projected_nodes_generated - projected_nodes_refined} "
                f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} "
                f"f={f_cost:.12g} g={state.path_cost:.12g} h={state.heuristic_cost:.12g} "
                f"h_source={state.heuristic_source} exact_refined={state.exact_refined}"
            )

        # Goal: no residual detectors remain.
        if num_dets == 0:
            elapsed_seconds = time.perf_counter() - start_time
            stats = DecodeStats(
                num_pq_pushed=num_pq_pushed,
                num_nodes_popped=num_nodes_popped,
                max_queue_size=max_queue_size,
                heuristic_calls=heuristic_calls,
                plain_heuristic_calls=plain_heuristic_calls,
                projection_heuristic_calls=projection_heuristic_calls,
                exact_refinement_calls=exact_refinement_calls,
                lp_calls=0 if singleton_solver is None else singleton_solver.stats.lp_calls,
                lp_reinserts=lp_reinserts,
                projected_nodes_generated=projected_nodes_generated,
                projected_nodes_refined=projected_nodes_refined,
                projected_nodes_unrefined_at_finish=projected_nodes_generated - projected_nodes_refined,
                total_lp_refinement_gain=total_lp_refinement_gain,
                max_lp_refinement_gain=max_lp_refinement_gain,
                lp_total_seconds=0.0 if singleton_solver is None else singleton_solver.stats.lp_total_seconds,
                projection_local_lp_calls=0 if singleton_solver is None else singleton_solver.stats.projection_local_lp_calls,
                projection_local_lp_seconds=0.0 if singleton_solver is None else singleton_solver.stats.projection_local_lp_seconds,
                restricted_total_rounds=0 if singleton_solver is None else singleton_solver.stats.restricted_total_rounds,
                restricted_total_added_supports=0 if singleton_solver is None else singleton_solver.stats.restricted_total_added_supports,
                restricted_total_fallbacks=0 if singleton_solver is None else singleton_solver.stats.restricted_total_fallbacks,
                full_check_calls=0 if singleton_solver is None else singleton_solver.stats.full_check_calls,
                full_check_max_abs_delta=0.0 if singleton_solver is None else singleton_solver.stats.full_check_max_abs_delta,
                elapsed_seconds=elapsed_seconds,
                heuristic_name=heuristic_name,
            )
            return DecodeResult(
                activated_errors=state.activated_errors,
                path_cost=state.path_cost,
                stats=stats,
            )

        # Lazy exact refinement: nodes generated with only a projected bound
        # get the exact singleton LP on first pop.  If the exact bound grew
        # past the stored one, the node is reinserted rather than expanded.
        if singleton_solver is not None and not state.exact_refined:
            heuristic_calls += 1
            exact_refinement_calls += 1
            previous_h = state.heuristic_cost
            previous_source = state.heuristic_source
            exact_solution, exact_payload = singleton_solver.solve_exact(
                active_detectors=state.active_detectors,
                blocked_errors=state.blocked_errors,
                active_detector_counts=state.active_detector_counts,
                warm_start_solution=state.warm_start_solution,
                changed_detectors=state.changed_detectors_from_parent,
            )
            exact_h = exact_solution.value
            reinserted = False
            discarded = False

            if exact_h == INF:
                # Exactly infeasible: drop the node entirely.
                discarded = True
                if previous_source == "projected":
                    projected_nodes_refined += 1
            else:
                # The projected value is a lower bound on the exact one, so
                # the exact value may only move up (modulo tolerance).
                if exact_h + 1e-7 < previous_h:
                    raise AssertionError(
                        f"Exact singleton LP lower bound {exact_h} is below stored {previous_source} lower bound {previous_h}."
                    )
                delta = exact_h - previous_h
                total_lp_refinement_gain += delta
                max_lp_refinement_gain = max(max_lp_refinement_gain, delta)
                state.heuristic_cost = exact_h
                state.heuristic_source = "exact"
                state.exact_refined = True
                state.lp_solution = exact_solution
                if previous_source == "projected":
                    projected_nodes_refined += 1
                if delta > HEURISTIC_EPS:
                    # The bound grew enough that the node may no longer be
                    # best; push it back with the tightened f-value.
                    reinserted = True
                    lp_reinserts += 1
                    heapq.heappush(
                        priority_queue,
                        (state.path_cost + state.heuristic_cost, num_dets, push_counter, state),
                    )
                    push_counter += 1

            if singleton_solver.logger is not None:
                payload = dict(exact_payload)
                payload.update(
                    {
                        "call_index": exact_refinement_calls,
                        "phase": "exact_refinement",
                        "depth": len(state.activated_errors),
                        "nodes_popped": num_nodes_popped,
                        "path_cost": state.path_cost,
                        "active_detector_count": num_dets,
                        "approx_h": previous_h,
                        "exact_h": exact_h,
                        "delta": INF if exact_h == INF else exact_h - previous_h,
                        "heuristic_source_before": previous_source,
                        "reinserted": reinserted,
                        "discarded": discarded,
                    }
                )
                singleton_solver.logger.maybe_log(call_index=exact_refinement_calls, payload=payload)

            if verbose_search:
                delta_text = "INF" if exact_h == INF else f"{exact_h - previous_h:.12g}"
                exact_text = "INF" if exact_h == INF else f"{exact_h:.12g}"
                print(
                    f" lp_refine approx_h={previous_h:.12g} exact_h={exact_text} delta={delta_text} "
                    f"vars={exact_solution.num_variables} constraints={exact_solution.num_constraints} "
                    f"selected={exact_solution.num_selected_constraints} rounds={exact_solution.num_rounds} "
                    f"tight={len(exact_solution.tight_supports)} reinserted={reinserted} discarded={discarded}"
                )

            if discarded or reinserted:
                continue

        # Expansion: branch on the smallest-index active detector.  Each
        # child blocks the errors tried before it (precedence pruning), so a
        # given subset of errors is generated along at most one path.
        min_detector = int(np.flatnonzero(state.active_detectors)[0])
        blocked_prefix = state.blocked_errors.copy()
        children_generated = 0
        children_projected = 0
        children_beam_pruned = 0
        children_infeasible = 0

        for error_index in data.detector_to_errors[min_detector]:
            # Block this error for all later siblings, whether or not it is
            # usable here.
            blocked_prefix[error_index] = True
            if state.blocked_errors[error_index]:
                continue

            child_active_detectors, child_active_counts = apply_error(
                data=data,
                active_detectors=state.active_detectors,
                active_detector_counts=state.active_detector_counts,
                error_index=error_index,
            )
            child_num_dets = int(child_active_detectors.sum())
            if child_num_dets > max_num_dets:
                children_beam_pruned += 1
                continue

            child_blocked = blocked_prefix.copy()
            child_path_cost = state.path_cost + float(data.error_costs[error_index])
            changed_detectors = tuple(sorted(data.error_detectors[error_index]))

            if singleton_solver is None:
                heuristic_calls += 1
                plain_heuristic_calls += 1
                child_heuristic = plain_detcost_heuristic(
                    data=data,
                    active_detectors=child_active_detectors,
                    blocked_errors=child_blocked,
                    active_detector_counts=child_active_counts,
                )
                child_source = "plain"
                child_exact_refined = True
                child_lp_solution = None
                child_warm_start_solution = None
            else:
                # Children get a cheap projected bound derived from the
                # parent's exact LP solution; exact refinement is deferred
                # until the child is popped.
                if state.lp_solution is None:
                    raise AssertionError("Projected singleton heuristic requires an exact-refined parent solution.")
                heuristic_calls += 1
                projection_heuristic_calls += 1
                projected_nodes_generated += 1
                children_projected += 1
                child_heuristic = singleton_solver.project_to_child(
                    parent_solution=state.lp_solution,
                    child_active_detectors=child_active_detectors,
                    child_blocked_errors=child_blocked,
                    child_active_detector_counts=child_active_counts,
                    changed_detectors=changed_detectors,
                )
                child_source = "projected"
                child_exact_refined = False
                child_lp_solution = None
                child_warm_start_solution = state.lp_solution

            if child_heuristic == INF:
                children_infeasible += 1
                continue

            child_state = SearchState(
                activated_errors=state.activated_errors + (error_index,),
                blocked_errors=child_blocked,
                active_detectors=child_active_detectors,
                active_detector_counts=child_active_counts,
                path_cost=child_path_cost,
                heuristic_cost=child_heuristic,
                heuristic_source=child_source,
                exact_refined=child_exact_refined,
                lp_solution=child_lp_solution,
                warm_start_solution=child_warm_start_solution,
                changed_detectors_from_parent=changed_detectors,
            )
            heapq.heappush(
                priority_queue,
                (child_path_cost + child_heuristic, child_num_dets, push_counter, child_state),
            )
            push_counter += 1
            num_pq_pushed += 1
            children_generated += 1

        if verbose_search:
            print(
                f" expanded children_generated={children_generated} children_projected={children_projected} "
                f"beam_pruned={children_beam_pruned} infeasible={children_infeasible}"
            )

    raise RuntimeError("Decoding failed to find any completion.")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser (continues below with the restricted-LP options)."""
    parser = argparse.ArgumentParser(
        description=(
            "Prototype A* decoder for Stim detector error models. "
            "Supports plain detcost, lazy full singleton LP, and a restricted-master singleton LP."
        )
    )
    parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.")
    parser.add_argument("--shot", type=int, default=0, help="Zero-based sampled shot index to decode.")
    parser.add_argument(
        "--sample-num-shots",
        type=int,
        default=100,
        help="Number of shots to sample before selecting --shot.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=27123839530,
        help="Seed passed to stim.compile_detector_sampler(...).sample(...).",
    )
    parser.add_argument(
        "--det-beam",
        type=parse_beam,
        default=INF,
        help="Beam cutoff on residual detector count. Use an integer or 'inf'.",
    )
    parser.add_argument(
        "--opt-singleton-detcost-mode",
        choices=["plain", "full", "restricted"],
        default="plain",
        help="Heuristic mode: plain detcost, lazy full singleton LP, or lazy restricted singleton LP.",
    )
    parser.add_argument(
        "--projection-mode",
        choices=["plain", "parent_y", "new_only", "changed_neighborhood"],
        default="changed_neighborhood",
        help=(
            "How to score child nodes before exact refinement. "
            "'parent_y' reuses parent primal detector prices, 'new_only' solves a tiny residual LP on newly active detectors, "
            "and 'changed_neighborhood' solves a tiny residual LP on a local region around the changed detectors."
        ),
    )
    parser.add_argument(
        "--projection-combine-max-plain",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Take max(projected child lower bound, plain detcost).",
    )
    parser.add_argument(
        "--merge-errors",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Merge indistinguishable DEM errors before decoding (default: enabled).",
    )
    parser.add_argument(
        "--show-shot-detectors",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Print the sampled shot's active detector IDs before decoding.",
    )
    parser.add_argument(
        "--show-error-indices",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Print the activated error indices in the final decoding.",
    )
    parser.add_argument("--verbose-search", action="store_true", help="Print per-node search diagnostics.")
    parser.add_argument(
        "--lp-log-path",
        type=Path,
        default=None,
        help="Optional JSONL file for logging exact singleton-LP refinements.",
    )
    parser.add_argument(
        "--lp-log-top-k",
        type=int,
        default=12,
        help="When logging exact LP refinements, include at most this many top supports.",
    )
    parser.add_argument(
        "--lp-log-every",
        type=int,
        default=1,
        help="When logging exact LP refinements, only write every k-th refinement.",
    )
parser.add_argument( + "--restricted-add-policy", + choices=["one", "topk", "all"], + default="topk", + help="Violation separation policy for restricted singleton LP mode.", + ) + parser.add_argument( + "--restricted-add-top-k", + type=int, + default=3, + help="When --restricted-add-policy=topk, add this many most violated supports.", + ) + parser.add_argument( + "--restricted-max-rounds", + type=int, + default=50, + help="Maximum separation rounds before optional fallback to the full singleton LP.", + ) + parser.add_argument( + "--restricted-fallback-full", + action=argparse.BooleanOptionalAction, + default=True, + help="If restricted mode hits the round limit, fall back to the full singleton LP.", + ) + parser.add_argument( + "--restricted-prune-slack", + action=argparse.BooleanOptionalAction, + default=True, + help="Prune slack supports from the restricted master between rounds.", + ) + parser.add_argument( + "--restricted-prune-tol", + type=float, + default=1e-8, + help="Keep selected supports whose slack is at most this value.", + ) + parser.add_argument( + "--restricted-violation-tol", + type=float, + default=1e-9, + help="Violation tolerance used during separation.", + ) + parser.add_argument( + "--restricted-tight-tol", + type=float, + default=1e-8, + help="Tolerance for tagging a support as tight in the exact solution.", + ) + parser.add_argument( + "--restricted-seed-normalized-global-k", + type=int, + default=0, + help="Add this many globally cheapest supports by cost/size to the initial restricted pool.", + ) + parser.add_argument( + "--restricted-seed-normalized-touching-changed-k", + type=int, + default=2, + help="Add this many cheapest cost/size supports touching changed detectors to the initial restricted pool.", + ) + parser.add_argument( + "--full-check-every", + type=int, + default=0, + help="In restricted mode, solve the full singleton LP every k exact refinements and assert equality (0 disables).", + ) + return parser + + +def main(argv: 
Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.lp_log_every <= 0: + parser.error("--lp-log-every must be positive.") + if args.lp_log_top_k <= 0: + parser.error("--lp-log-top-k must be positive.") + if args.restricted_add_top_k <= 0: + parser.error("--restricted-add-top-k must be positive.") + if args.restricted_max_rounds <= 0: + parser.error("--restricted-max-rounds must be positive.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + + singleton_solver = None + if args.opt_singleton_detcost_mode != "plain": + logger = None + if args.lp_log_path is not None: + logger = LPLogger( + args.lp_log_path, + every=args.lp_log_every, + top_k=args.lp_log_top_k, + ) + restricted_config = RestrictedMasterConfig( + add_policy=args.restricted_add_policy, + add_top_k=args.restricted_add_top_k, + violation_tol=args.restricted_violation_tol, + tight_tol=args.restricted_tight_tol, + prune_slack=args.restricted_prune_slack, + prune_tol=args.restricted_prune_tol, + seed_normalized_global_k=args.restricted_seed_normalized_global_k, + seed_normalized_touching_changed_k=args.restricted_seed_normalized_touching_changed_k, + max_rounds=args.restricted_max_rounds, + fallback_full=args.restricted_fallback_full, + full_check_every=args.full_check_every, + ) + singleton_solver = SingletonLPHeuristic( + data, + exact_mode=args.opt_singleton_detcost_mode, + projection_mode=args.projection_mode, + projection_combine_max_plain=args.projection_combine_max_plain, + restricted_config=restricted_config, + logger=logger, + ) + + dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.sample_num_shots, + 
separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + + if args.shot >= detections.shape[0]: + parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.") + + shot_detections = detections[args.shot] + shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool) + + print(f"circuit = {args.circuit}") + if singleton_solver is None: + print("heuristic = plain_detcost") + else: + print( + "heuristic = " + + f"opt_singleton_{args.opt_singleton_detcost_mode}_lazy_{args.projection_mode}" + + ("_maxplain" if args.projection_combine_max_plain else "") + ) + print(f"shot = {args.shot}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"num_detectors = {data.num_detectors}") + print(f"num_observables = {data.num_observables}") + print(f"num_errors = {len(data.errors)}") + print(f"beam = {args.det_beam}") + if args.show_shot_detectors: + print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}") + + result = decode( + data=data, + detections=shot_detections, + det_beam=args.det_beam, + singleton_solver=singleton_solver, + verbose_search=args.verbose_search, + ) + + predicted_observables = observables_from_solution(data, result.activated_errors) + reproduced_detectors = detectors_from_solution(data, result.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError("Decoded error set does not reproduce the shot's syndrome.") + + print(f"solution_size = {len(result.activated_errors)}") + print(f"solution_cost = {result.path_cost:.12g}") + if args.show_error_indices: + print(f"activated_errors = {format_indices(result.activated_errors, 'E')}") + print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}") + print(f"sample_observables = 
{format_indices(np.flatnonzero(shot_observables), 'L')}") + print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}") + print(f"num_pq_pushed = {result.stats.num_pq_pushed}") + print(f"num_nodes_popped = {result.stats.num_nodes_popped}") + print(f"max_queue_size = {result.stats.max_queue_size}") + print(f"heuristic_calls = {result.stats.heuristic_calls}") + print(f"plain_heuristic_calls = {result.stats.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.stats.projection_heuristic_calls}") + print(f"exact_refinement_calls = {result.stats.exact_refinement_calls}") + print(f"lp_calls = {result.stats.lp_calls}") + print(f"lp_reinserts = {result.stats.lp_reinserts}") + print(f"projected_nodes_generated = {result.stats.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.stats.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.stats.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain = {result.stats.total_lp_refinement_gain:.12g}") + print(f"max_lp_refinement_gain = {result.stats.max_lp_refinement_gain:.12g}") + print(f"lp_total_seconds = {result.stats.lp_total_seconds:.6f}") + print(f"projection_local_lp_calls = {result.stats.projection_local_lp_calls}") + print(f"projection_local_lp_seconds = {result.stats.projection_local_lp_seconds:.6f}") + print(f"restricted_total_rounds = {result.stats.restricted_total_rounds}") + print(f"restricted_total_added_supports = {result.stats.restricted_total_added_supports}") + print(f"restricted_total_fallbacks = {result.stats.restricted_total_fallbacks}") + print(f"full_check_calls = {result.stats.full_check_calls}") + print(f"full_check_max_abs_delta = {result.stats.full_check_max_abs_delta:.12g}") + print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_prototype_subset_detcost.py 
b/src/py/astar/astar_prototype_subset_detcost.py new file mode 100644 index 0000000..f405240 --- /dev/null +++ b/src/py/astar/astar_prototype_subset_detcost.py @@ -0,0 +1,1071 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder with detcost and subset-LP heuristics. + +This script keeps the basic search structure of the original prototype while +adding a small CLI and a family of stronger admissible heuristics. + +Heuristic modes: + --opt-subset-detcost-size 0 plain detcost + --opt-subset-detcost-size 1 optimal singleton LP + --opt-subset-detcost-size 2 optimal LP over singletons and 2-detector subsets + --opt-subset-detcost-size 3 optimal LP over singletons and 2/3-detector subsets + +The subset library is the small-subset closure of DEM supports: + * every singleton detector subset, and + * every nonempty subset of D(e) of size at most N, for each DEM error e. + +For a library subset S, the local decoder only sees the restriction of errors to +S. Because N is intended to be small (<= 3 in practice), all minimal local +pattern resolutions can be precomputed once. +""" + +from __future__ import annotations + +import argparse +import heapq +import itertools +import json +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
@dataclass
class DecoderData:
    """Preprocessed lookup tables derived from a Stim detector error model."""

    num_detectors: int
    num_observables: int
    errors: List[MergedError]
    # detector id -> indices of errors whose support contains that detector
    detector_to_errors: List[List[int]]
    # per-error -log(p/(1-p)) likelihood costs, aligned with `errors`
    error_costs: np.ndarray
    error_detectors: List[Tuple[int, ...]]
    error_detector_sets: List[frozenset[int]]
    error_observables: List[Tuple[int, ...]]


@dataclass(frozen=True)
class SubsetLibraryEntry:
    """One detector subset S with its precomputed local error patterns.

    Local patterns D(e) ∩ S are stored as bit-masks over S's detectors.
    """

    subset_id: int
    detectors: Tuple[int, ...]
    # local pattern mask -> indices of errors realizing that pattern
    pattern_to_errors: Dict[int, Tuple[int, ...]]
    # target mask -> inclusion-minimal combos of local patterns XOR-ing to it
    resolution_combos: Dict[int, Tuple[Tuple[int, ...], ...]]


@dataclass
class ActiveSubsetRecord:
    """A library subset restricted to the currently active detectors/errors."""

    subset_id: int
    detectors: Tuple[int, ...]
    size: int
    target_mask: int
    available_patterns: Dict[int, Tuple[int, ...]]
    feasible_combos: Tuple[Tuple[int, ...], ...]


@dataclass
class SearchState:
    """One node of the A* search tree."""

    activated_errors: Tuple[int, ...]
    blocked_errors: np.ndarray
    active_detectors: np.ndarray
    active_detector_counts: np.ndarray
    path_cost: float


@dataclass
class DecodeStats:
    """Diagnostics accumulated over a single decode() call."""

    num_pq_pushed: int
    num_nodes_popped: int
    max_queue_size: int
    heuristic_calls: int
    lp_calls: int
    lp_total_seconds: float
    elapsed_seconds: float
    heuristic_name: str


@dataclass
class DecodeResult:
    activated_errors: Tuple[int, ...]
    path_cost: float
    stats: DecodeStats


class UnionFind:
    """Disjoint-set forest with union-by-rank and path halving."""

    def __init__(self, size: int) -> None:
        self.parent = list(range(size))
        self.rank = [0] * size

    def find(self, x: int) -> int:
        # Path halving: redirect every visited node to its grandparent.
        node = x
        while self.parent[node] != node:
            self.parent[node] = self.parent[self.parent[node]]
            node = self.parent[node]
        return node

    def union(self, a: int, b: int) -> None:
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            return
        # Ensure root_a is the higher-rank root, then attach root_b below it.
        if self.rank[root_a] < self.rank[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        if self.rank[root_a] == self.rank[root_b]:
            self.rank[root_a] += 1


class LPLogger:
    """Writes JSON-lines records describing LP solves to a log file."""

    def __init__(self, path: Path, *, every: int = 1, top_k: int = 10) -> None:
        self.path = path
        self.every = max(1, every)
        self.top_k = max(1, top_k)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Truncate eagerly so repeated runs do not append by accident.
        self.path.write_text("")

    def maybe_log(self, *, call_index: int, payload: Dict[str, Any]) -> None:
        # Only every `self.every`-th call is written.
        if call_index % self.every:
            return
        line = json.dumps(payload, sort_keys=True)
        with self.path.open("a", encoding="utf-8") as handle:
            handle.write(line + "\n")


def parse_beam(text: str) -> float:
    """Parse a beam width: a non-negative integer or any spelling of infinity.

    Raises argparse.ArgumentTypeError for negative integers so argparse can
    report a clean usage error.
    """
    token = text.strip().lower()
    if token in ("inf", "+inf", "infinity", "+infinity"):
        return math.inf
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'")
    return float(value)


def format_indices(indices: Iterable[int], prefix: str) -> str:
    """Render indices as e.g. 'D3 D7'; '(none)' when the iterable is empty."""
    labels = [f"{prefix}{index}" for index in indices]
    return " ".join(labels) if labels else "(none)"


def xor_probability(p0: float, p1: float) -> float:
    """Probability that exactly one of two independent events occurs."""
    return p0 * (1 - p1) + (1 - p0) * p1


def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]:
    """Yield one MergedError per DEM ``error`` instruction.

    Zero-probability errors are dropped; repeated detector/observable targets
    within one instruction cancel pairwise (XOR semantics).
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                "This prototype assumes detector-error-model probabilities are in (0, 0.5)."
            )
        detectors: set[int] = set()
        observables: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # Toggle membership: a repeated target cancels itself.
                observables.symmetric_difference_update((target.val,))
            elif target.is_relative_detector_id():
                detectors.symmetric_difference_update((target.val,))
            else:
                raise ValueError(f"Unexpected DEM target: {target!r}")
        yield MergedError(
            probability=probability,
            likelihood_cost=float(-math.log(probability / (1 - probability))),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]:
    """Fuse DEM errors with identical (detectors, observables) symptoms.

    Probabilities of indistinguishable errors combine via XOR, since either
    one (but not both) firing produces the same syndrome.
    """
    probability_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors(dem):
        key = (error.detectors, error.observables)
        if key in probability_by_symptom:
            probability_by_symptom[key] = xor_probability(
                probability_by_symptom[key], error.probability
            )
        else:
            probability_by_symptom[key] = error.probability

    result: List[MergedError] = []
    for (detectors, observables), probability in probability_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                "Merged error has probability >= 0.5, which would give a non-positive cost."
            )
        result.append(
            MergedError(
                probability=probability,
                likelihood_cost=float(-math.log(probability / (1 - probability))),
                detectors=detectors,
                observables=observables,
            )
        )
    return result


def build_decoder_data(
    dem: stim.DetectorErrorModel,
    *,
    merge_errors_in_dem: bool = True,
) -> DecoderData:
    """Assemble the lookup tables the search needs from a Stim DEM."""
    if merge_errors_in_dem:
        errors = merged_errors(dem)
    else:
        errors = list(iter_dem_errors(dem))
    detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)]
    for error_index, error in enumerate(errors):
        for detector in error.detectors:
            detector_to_errors[detector].append(error_index)
    return DecoderData(
        num_detectors=dem.num_detectors,
        num_observables=dem.num_observables,
        errors=errors,
        detector_to_errors=detector_to_errors,
        error_costs=np.asarray([error.likelihood_cost for error in errors], dtype=np.float64),
        error_detectors=[error.detectors for error in errors],
        error_detector_sets=[frozenset(error.detectors) for error in errors],
        error_observables=[error.observables for error in errors],
    )


def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray:
    """Expand little-endian bit-packed rows into a boolean array with `count` columns."""
    unpacked = np.unpackbits(bits, bitorder="little", axis=1, count=count)
    return unpacked.astype(bool, copy=False)


def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray:
    """For each error, count how many currently active detectors it touches."""
    counts = np.zeros(len(data.errors), dtype=np.int32)
    for detector in np.flatnonzero(active_detectors):
        for error_index in data.detector_to_errors[int(detector)]:
            counts[error_index] += 1
    return counts


def apply_error(
    data: DecoderData,
    active_detectors: np.ndarray,
    active_detector_counts: np.ndarray,
    error_index: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Toggle one error's detectors; return fresh (detectors, counts) copies.

    The inputs are not mutated. For each detector flipped on (off), the
    active-count of every error touching it is incremented (decremented).
    """
    next_detectors = active_detectors.copy()
    next_counts = active_detector_counts.copy()
    for detector in data.error_detectors[error_index]:
        was_active = bool(next_detectors[detector])
        next_detectors[detector] = not was_active
        delta = -1 if was_active else 1
        for touched_error in data.detector_to_errors[detector]:
            next_counts[touched_error] += delta
    return next_detectors, next_counts
+def plain_detcost_for_detector( + data: DecoderData, + detector: int, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> float: + best = INF + for ei in data.detector_to_errors[detector]: + if blocked_errors[ei]: + continue + count = int(active_detector_counts[ei]) + assert count > 0 + candidate = float(data.error_costs[ei]) / count + if candidate < best: + best = candidate + return best + + +def plain_detcost_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, +) -> float: + total = 0.0 + for d in np.flatnonzero(active_detectors): + det_cost = plain_detcost_for_detector( + data=data, + detector=int(d), + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + if det_cost == INF: + return INF + total += det_cost + return total + + +def compute_minimal_resolution_combos( + available_pattern_masks: Iterable[int], + subset_size: int, +) -> Dict[int, Tuple[Tuple[int, ...], ...]]: + """Precompute inclusion-minimal local pattern combinations for each target. + + For a fixed subset S of size k, an error only matters through its nonzero local + pattern D(e)∩S, represented as a bit-mask in {1, ..., 2^k-1}. Because local + budgets are nonnegative, an optimal local resolution never needs to use the same + local pattern twice, and any combo that strictly contains another combo with the + same XOR target is dominated. 
+ """ + + patterns = tuple(sorted(set(available_pattern_masks))) + combos_by_target: Dict[int, List[Tuple[int, ...]]] = { + target: [] for target in range(1, 1 << subset_size) + } + for r in range(1, min(len(patterns), subset_size) + 1): + for combo in itertools.combinations(patterns, r): + target_mask = 0 + for pattern_mask in combo: + target_mask ^= pattern_mask + if target_mask == 0: + continue + combo_set = set(combo) + existing = combos_by_target[target_mask] + keep = True + survivors: List[Tuple[int, ...]] = [] + for old_combo in existing: + old_set = set(old_combo) + if combo_set.issuperset(old_set): + keep = False + survivors.append(old_combo) + elif old_set.issuperset(combo_set): + continue + else: + survivors.append(old_combo) + if keep: + survivors.append(combo) + survivors.sort(key=lambda x: (len(x), x)) + combos_by_target[target_mask] = survivors + return { + target_mask: tuple(combos) + for target_mask, combos in combos_by_target.items() + if combos + } + + +@dataclass +class SubsetLibrary: + max_subset_size: int + entries: List[SubsetLibraryEntry] + subsets_by_detector: List[List[int]] + num_subsets_by_size: Dict[int, int] + + +def build_subset_library(data: DecoderData, max_subset_size: int) -> SubsetLibrary: + library_keys: set[Tuple[int, ...]] = set() + if max_subset_size >= 1: + for detector in range(data.num_detectors): + library_keys.add((detector,)) + + for detectors in data.error_detectors: + limit = min(max_subset_size, len(detectors)) + for subset_size in range(1, limit + 1): + for subset_detectors in itertools.combinations(detectors, subset_size): + library_keys.add(tuple(subset_detectors)) + + subsets_by_detector: List[List[int]] = [[] for _ in range(data.num_detectors)] + entries: List[SubsetLibraryEntry] = [] + num_subsets_by_size: Dict[int, int] = defaultdict(int) + + for subset_id, subset_detectors in enumerate(sorted(library_keys, key=lambda t: (len(t), t))): + pattern_to_errors: Dict[int, List[int]] = defaultdict(list) + for 
error_index, detector_set in enumerate(data.error_detector_sets): + pattern_mask = 0 + for bit_index, detector in enumerate(subset_detectors): + if detector in detector_set: + pattern_mask |= 1 << bit_index + if pattern_mask != 0: + pattern_to_errors[pattern_mask].append(error_index) + frozen_pattern_to_errors = { + pattern_mask: tuple(error_indices) + for pattern_mask, error_indices in pattern_to_errors.items() + } + entry = SubsetLibraryEntry( + subset_id=subset_id, + detectors=subset_detectors, + pattern_to_errors=frozen_pattern_to_errors, + resolution_combos=compute_minimal_resolution_combos( + available_pattern_masks=frozen_pattern_to_errors.keys(), + subset_size=len(subset_detectors), + ), + ) + entries.append(entry) + num_subsets_by_size[len(subset_detectors)] += 1 + for detector in subset_detectors: + subsets_by_detector[detector].append(subset_id) + + return SubsetLibrary( + max_subset_size=max_subset_size, + entries=entries, + subsets_by_detector=subsets_by_detector, + num_subsets_by_size=dict(sorted(num_subsets_by_size.items())), + ) + + +@dataclass +class SubsetLPHeuristicStats: + call_count: int = 0 + lp_call_count: int = 0 + lp_total_seconds: float = 0.0 + + +class SubsetLPHeuristic: + def __init__( + self, + data: DecoderData, + subset_library: SubsetLibrary, + *, + logger: Optional[LPLogger] = None, + ) -> None: + self.data = data + self.subset_library = subset_library + self.logger = logger + self.stats = SubsetLPHeuristicStats() + + def evaluate( + self, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + *, + context: Optional[Dict[str, Any]] = None, + ) -> float: + self.stats.call_count += 1 + self.stats.lp_call_count += 1 + t0 = time.perf_counter() + + active_subset_ids: set[int] = set() + for detector in np.flatnonzero(active_detectors): + active_subset_ids.update(self.subset_library.subsets_by_detector[int(detector)]) + + subset_records: List[ActiveSubsetRecord] = [] + error_to_subset_positions: Dict[int, List[int]] = 
defaultdict(list) + + for subset_id in sorted(active_subset_ids): + entry = self.subset_library.entries[subset_id] + target_mask = 0 + for bit_index, detector in enumerate(entry.detectors): + if active_detectors[detector]: + target_mask |= 1 << bit_index + if target_mask == 0: + continue + + available_patterns: Dict[int, Tuple[int, ...]] = {} + relevant_errors: set[int] = set() + for pattern_mask, error_indices in entry.pattern_to_errors.items(): + kept = tuple(error_index for error_index in error_indices if not blocked_errors[error_index]) + if kept: + available_patterns[pattern_mask] = kept + relevant_errors.update(kept) + + feasible_combos = tuple( + combo + for combo in entry.resolution_combos.get(target_mask, ()) + if all(pattern_mask in available_patterns for pattern_mask in combo) + ) + if not feasible_combos: + self.stats.lp_total_seconds += time.perf_counter() - t0 + return INF + + record = ActiveSubsetRecord( + subset_id=subset_id, + detectors=entry.detectors, + size=len(entry.detectors), + target_mask=target_mask, + available_patterns=available_patterns, + feasible_combos=feasible_combos, + ) + subset_position = len(subset_records) + subset_records.append(record) + for error_index in sorted(relevant_errors): + error_to_subset_positions[error_index].append(subset_position) + + if not subset_records: + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + if self.logger is not None: + payload: Dict[str, Any] = { + "call_index": self.stats.call_count, + "objective": 0.0, + "solve_seconds": elapsed, + "num_active_subsets": 0, + "num_components": 0, + } + if context is not None: + payload.update(context) + self.logger.maybe_log(call_index=self.stats.call_count, payload=payload) + return 0.0 + + component_uf = UnionFind(len(subset_records)) + for subset_positions in error_to_subset_positions.values(): + for position in subset_positions[1:]: + component_uf.union(subset_positions[0], position) + component_to_subset_positions: Dict[int, 
List[int]] = defaultdict(list) + for subset_position in range(len(subset_records)): + component_to_subset_positions[component_uf.find(subset_position)].append(subset_position) + + total_objective = 0.0 + total_num_variables = 0 + total_num_constraints = 0 + contribution_by_size: Dict[int, float] = defaultdict(float) + budget_by_size: Dict[int, float] = defaultdict(float) + active_subset_count_by_size: Dict[int, int] = defaultdict(int) + top_subset_records: List[Dict[str, Any]] = [] + + for component_positions in component_to_subset_positions.values(): + y_var: Dict[int, int] = {} + u_var: Dict[Tuple[int, int], int] = {} + error_to_u_vars: Dict[int, List[int]] = defaultdict(list) + + next_var_index = 0 + for subset_position in component_positions: + y_var[subset_position] = next_var_index + next_var_index += 1 + for subset_position in component_positions: + record = subset_records[subset_position] + active_subset_count_by_size[record.size] += 1 + for pattern_mask, error_indices in sorted(record.available_patterns.items()): + variable_index = next_var_index + u_var[(subset_position, pattern_mask)] = variable_index + next_var_index += 1 + for error_index in error_indices: + error_to_u_vars[error_index].append(variable_index) + + row_indices: List[int] = [] + col_indices: List[int] = [] + values: List[float] = [] + rhs: List[float] = [] + + for error_index, variable_indices in sorted(error_to_u_vars.items()): + row = len(rhs) + rhs.append(float(self.data.error_costs[error_index])) + for variable_index in variable_indices: + row_indices.append(row) + col_indices.append(variable_index) + values.append(1.0) + + for subset_position in component_positions: + record = subset_records[subset_position] + y_index = y_var[subset_position] + for combo in record.feasible_combos: + row = len(rhs) + rhs.append(0.0) + row_indices.append(row) + col_indices.append(y_index) + values.append(1.0) + for pattern_mask in combo: + row_indices.append(row) + 
col_indices.append(u_var[(subset_position, pattern_mask)]) + values.append(-1.0) + + total_num_variables += next_var_index + total_num_constraints += len(rhs) + + a_ub = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(len(rhs), next_var_index), + dtype=np.float64, + ) + objective = np.zeros(next_var_index, dtype=np.float64) + for subset_position in component_positions: + objective[y_var[subset_position]] = -1.0 + + result = linprog( + c=objective, + A_ub=a_ub, + b_ub=np.asarray(rhs, dtype=np.float64), + bounds=[(0.0, None)] * next_var_index, + method="highs", + ) + if not result.success: + raise RuntimeError( + f"subset detcost LP solve failed: status={result.status} message={result.message}" + ) + total_objective += float(-result.fun) + solution = np.asarray(result.x, dtype=np.float64) + + for subset_position in component_positions: + record = subset_records[subset_position] + y_value = float(solution[y_var[subset_position]]) + total_budget = float( + sum(solution[u_var[(subset_position, pattern_mask)]] for pattern_mask in record.available_patterns) + ) + contribution_by_size[record.size] += y_value + budget_by_size[record.size] += total_budget + pattern_values = [ + { + "pattern_detectors": [ + detector + for bit_index, detector in enumerate(record.detectors) + if pattern_mask & (1 << bit_index) + ], + "u": float(solution[u_var[(subset_position, pattern_mask)]]), + "num_allowed_errors": len(record.available_patterns[pattern_mask]), + } + for pattern_mask in sorted(record.available_patterns) + if solution[u_var[(subset_position, pattern_mask)]] > 1e-12 + ] + top_subset_records.append( + { + "subset_detectors": list(record.detectors), + "subset_size": record.size, + "target_active_detectors": [ + detector + for bit_index, detector in enumerate(record.detectors) + if record.target_mask & (1 << bit_index) + ], + "y": y_value, + "total_budget": total_budget, + "num_available_patterns": len(record.available_patterns), + 
"num_feasible_resolution_combos": len(record.feasible_combos), + "patterns": pattern_values, + } + ) + + elapsed = time.perf_counter() - t0 + self.stats.lp_total_seconds += elapsed + + if self.logger is not None: + top_subset_records.sort(key=lambda item: (-item["y"], -item["total_budget"], item["subset_detectors"])) + payload = { + "call_index": self.stats.call_count, + "objective": total_objective, + "solve_seconds": elapsed, + "num_active_subsets": len(subset_records), + "num_active_subsets_by_size": { + str(size): active_subset_count_by_size[size] for size in sorted(active_subset_count_by_size) + }, + "num_components": len(component_to_subset_positions), + "num_variables": total_num_variables, + "num_constraints": total_num_constraints, + "contribution_by_subset_size": { + str(size): contribution_by_size[size] for size in sorted(contribution_by_size) + }, + "allocated_budget_by_subset_size": { + str(size): budget_by_size[size] for size in sorted(budget_by_size) + }, + "top_subsets": top_subset_records[: self.logger.top_k], + } + if context is not None: + payload.update(context) + self.logger.maybe_log(call_index=self.stats.call_count, payload=payload) + + return total_objective + +def compute_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + *, + opt_subset_solver: Optional[SubsetLPHeuristic], + context: Optional[Dict[str, Any]] = None, +) -> float: + if opt_subset_solver is None: + return plain_detcost_heuristic( + data=data, + active_detectors=active_detectors, + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + ) + del active_detector_counts + return opt_subset_solver.evaluate( + active_detectors=active_detectors, + blocked_errors=blocked_errors, + context=context, + ) + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in 
activated_errors: + for detector in data.error_detectors[error_index]: + detectors[detector] ^= True + return detectors + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in data.error_observables[error_index]: + observables[observable] ^= True + return observables + + +def decode( + data: DecoderData, + detections: np.ndarray, + *, + det_beam: float = INF, + opt_subset_solver: Optional[SubsetLPHeuristic] = None, + verbose_search: bool = False, +) -> DecodeResult: + start_time = time.perf_counter() + initial_active_detectors = np.asarray(detections, dtype=bool).copy() + initial_counts = initial_detector_counts(data, initial_active_detectors) + initial_blocked = np.zeros(len(data.errors), dtype=bool) + initial_path_cost = 0.0 + initial_heuristic = compute_heuristic( + data=data, + active_detectors=initial_active_detectors, + blocked_errors=initial_blocked, + active_detector_counts=initial_counts, + opt_subset_solver=opt_subset_solver, + context={ + "phase": "initial", + "depth": 0, + "nodes_popped": 0, + "path_cost": 0.0, + "active_detector_count": int(initial_active_detectors.sum()), + }, + ) + if initial_heuristic == INF: + raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.") + + initial_state = SearchState( + activated_errors=(), + blocked_errors=initial_blocked, + active_detectors=initial_active_detectors, + active_detector_counts=initial_counts, + path_cost=initial_path_cost, + ) + + priority_queue: List[Tuple[float, int, int, SearchState]] = [] + push_counter = 0 + initial_num_dets = int(initial_active_detectors.sum()) + heapq.heappush( + priority_queue, + (initial_path_cost + initial_heuristic, initial_num_dets, push_counter, initial_state), + ) + push_counter += 1 + + num_pq_pushed = 1 + num_nodes_popped = 0 + max_queue_size = 1 + min_num_dets = 
initial_num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + heuristic_name = ( + f"opt_subset_detcost_size_{opt_subset_solver.subset_library.max_subset_size}" + if opt_subset_solver is not None + else "plain_detcost" + ) + + while priority_queue: + max_queue_size = max(max_queue_size, len(priority_queue)) + f_cost, num_dets, _, state = heapq.heappop(priority_queue) + num_nodes_popped += 1 + + if num_dets > max_num_dets: + continue + + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + if verbose_search: + print( + f"nodes_popped={num_nodes_popped} len(pq)={len(priority_queue)} " + f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} " + f"f={f_cost:.12g} g={state.path_cost:.12g}" + ) + + if num_dets == 0: + elapsed_seconds = time.perf_counter() - start_time + heuristic_calls = 0 if opt_subset_solver is None else opt_subset_solver.stats.call_count + lp_calls = 0 if opt_subset_solver is None else opt_subset_solver.stats.lp_call_count + lp_total_seconds = 0.0 if opt_subset_solver is None else opt_subset_solver.stats.lp_total_seconds + return DecodeResult( + activated_errors=state.activated_errors, + path_cost=state.path_cost, + stats=DecodeStats( + num_pq_pushed=num_pq_pushed, + num_nodes_popped=num_nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=heuristic_calls, + lp_calls=lp_calls, + lp_total_seconds=lp_total_seconds, + elapsed_seconds=elapsed_seconds, + heuristic_name=heuristic_name, + ), + ) + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + for error_index in data.detector_to_errors[min_detector]: + blocked_prefix[error_index] = True + if state.blocked_errors[error_index]: + continue + + child_active_detectors, child_active_counts = apply_error( + data=data, + active_detectors=state.active_detectors, + 
                # (continuation of decode(): remaining kwargs of the
                # apply_error(...) call opened in the previous chunk)
                active_detector_counts=state.active_detector_counts,
                error_index=error_index,
            )
            child_num_dets = int(child_active_detectors.sum())
            # Beam pruning: discard children whose residual detector count
            # exceeds the current beam window.
            if child_num_dets > max_num_dets:
                continue

            child_blocked = blocked_prefix.copy()
            child_path_cost = state.path_cost + float(data.error_costs[error_index])
            child_heuristic = compute_heuristic(
                data=data,
                active_detectors=child_active_detectors,
                blocked_errors=child_blocked,
                active_detector_counts=child_active_counts,
                opt_subset_solver=opt_subset_solver,
                context={
                    "phase": "child",
                    "depth": len(state.activated_errors) + 1,
                    "nodes_popped": num_nodes_popped,
                    "path_cost": child_path_cost,
                    "active_detector_count": child_num_dets,
                    "chosen_error": error_index,
                    "min_detector": min_detector,
                },
            )
            # An infinite heuristic means the child's residual syndrome cannot
            # be completed under the pruning rule; drop it.
            if child_heuristic == INF:
                continue

            child_state = SearchState(
                activated_errors=state.activated_errors + (error_index,),
                blocked_errors=child_blocked,
                active_detectors=child_active_detectors,
                active_detector_counts=child_active_counts,
                path_cost=child_path_cost,
            )
            heapq.heappush(
                priority_queue,
                (
                    child_path_cost + child_heuristic,  # f = g + h
                    child_num_dets,
                    push_counter,  # monotone tiebreaker keeps tuple comparisons total
                    child_state,
                ),
            )
            push_counter += 1
            num_pq_pushed += 1

    # Queue drained without ever reaching an empty residual syndrome.
    raise RuntimeError("Decoding failed to find any completion.")


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the prototype A* decoder."""
    parser = argparse.ArgumentParser(
        description=(
            "Prototype A* decoder for Stim detector error models. "
            "Supports plain detcost and subset-based LP lower bounds."
        )
    )
    parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.")
    parser.add_argument(
        "--shot",
        type=int,
        default=0,
        help="Zero-based sampled shot index to decode.",
    )
    parser.add_argument(
        "--sample-num-shots",
        type=int,
        default=100,
        help="Number of shots to sample before selecting --shot.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=27123839530,
        help="Seed passed to stim.compile_detector_sampler(...).sample(...).",
    )
    parser.add_argument(
        "--det-beam",
        type=parse_beam,
        default=INF,
        help="Beam cutoff on the residual detector count. Use an integer or 'inf'.",
    )
    parser.add_argument(
        "--opt-subset-detcost-size",
        type=int,
        default=0,
        help=(
            "Use the subset-based LP heuristic with library subsets of size at most N. "
            "Use 0 for plain detcost, 1 for the optimal singleton LP, etc."
        ),
    )
    parser.add_argument(
        "--merge-errors",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Merge indistinguishable DEM errors before decoding (default: enabled).",
    )
    parser.add_argument(
        "--show-shot-detectors",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Print the sampled shot's active detector IDs before decoding.",
    )
    parser.add_argument(
        "--show-error-indices",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Print the activated error indices in the final decoding.",
    )
    parser.add_argument(
        "--verbose-search",
        action="store_true",
        help="Print per-node search diagnostics.",
    )
    parser.add_argument(
        "--lp-log-path",
        type=Path,
        default=None,
        help="Optional JSONL file for logging details of each subset-LP solve.",
    )
    parser.add_argument(
        "--lp-log-top-k",
        type=int,
        default=10,
        help="When logging LP solves, include at most this many top subsets per solve.",
    )
    parser.add_argument(
        "--lp-log-every",
        type=int,
        default=1,
        help="When logging LP solves, only write every k-th solve.",
    )
    return parser


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Sample a shot from the circuit, decode it, and print diagnostics.

    Returns 0 on success; invalid flag combinations exit via parser.error().
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    # Validate numeric flags up front so failures surface as argparse errors.
    if args.sample_num_shots <= 0:
        parser.error("--sample-num-shots must be positive.")
    if args.shot < 0:
        parser.error("--shot must be non-negative.")
    if args.opt_subset_detcost_size < 0:
        parser.error("--opt-subset-detcost-size must be non-negative.")
    if args.lp_log_every <= 0:
        parser.error("--lp-log-every must be positive.")
    if args.lp_log_top_k <= 0:
        parser.error("--lp-log-top-k must be positive.")

    # NOTE(review): `stim` must be importable at module scope for this call —
    # the import is not visible in this chunk; confirm it exists at file top.
    circuit = stim.Circuit.from_file(str(args.circuit))
    dem = circuit.detector_error_model(decompose_errors=False)
    data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors)

    # The subset-LP machinery is only constructed when requested.
    subset_library = None
    subset_solver = None
    if args.opt_subset_detcost_size > 0:
        subset_library = build_subset_library(data, args.opt_subset_detcost_size)
        lp_logger = None
        if args.lp_log_path is not None:
            lp_logger = LPLogger(
                args.lp_log_path,
                every=args.lp_log_every,
                top_k=args.lp_log_top_k,
            )
        subset_solver = SubsetLPHeuristic(data, subset_library, logger=lp_logger)

    dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample(
        shots=args.sample_num_shots,
        separate_observables=True,
        bit_packed=True,
    )
    detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors)
    observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables)

    if args.shot >= detections.shape[0]:
        parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.")

    shot_detections = detections[args.shot]
    # A circuit with no observables yields an empty packed array; normalize it.
    shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool)

    print(f"circuit = {args.circuit}")
    print(
        "heuristic = "
        + (
            "plain_detcost"
            if subset_solver is None
            else f"opt_subset_detcost_size_{subset_library.max_subset_size}"
        )
    )
    print(f"shot = {args.shot}")
    print(f"sample_num_shots = {args.sample_num_shots}")
    print(f"num_detectors = {data.num_detectors}")
    print(f"num_observables = {data.num_observables}")
    print(f"num_errors = {len(data.errors)}")
    print(f"beam = {args.det_beam}")
    if subset_library is not None:
        print(f"subset_library_size = {len(subset_library.entries)}")
        print(
            "subset_library_by_size = "
            + ", ".join(
                f"{size}:{count}" for size, count in subset_library.num_subsets_by_size.items()
            )
        )
    if args.show_shot_detectors:
        print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}")

    result = decode(
        data=data,
        detections=shot_detections,
        det_beam=args.det_beam,
        opt_subset_solver=subset_solver,
        verbose_search=args.verbose_search,
    )

    predicted_observables = observables_from_solution(data, result.activated_errors)
    reproduced_detectors = detectors_from_solution(data, result.activated_errors)
    # Sanity check: the decoded error set must reproduce the shot's syndrome.
    if not np.array_equal(reproduced_detectors, shot_detections):
        raise AssertionError("Decoded error set does not reproduce the shot's syndrome.")

    print(f"solution_size = {len(result.activated_errors)}")
    print(f"solution_cost = {result.path_cost:.12g}")
    if args.show_error_indices:
        print(f"activated_errors = {format_indices(result.activated_errors, 'E')}")
    print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}")
    print(f"sample_observables = {format_indices(np.flatnonzero(shot_observables), 'L')}")
    print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}")
    print(f"num_pq_pushed = {result.stats.num_pq_pushed}")
    print(f"num_nodes_popped = {result.stats.num_nodes_popped}")
    print(f"max_queue_size = {result.stats.max_queue_size}")
    print(f"heuristic_calls = {result.stats.heuristic_calls}")
    print(f"lp_calls = {result.stats.lp_calls}")
    print(f"lp_total_seconds = {result.stats.lp_total_seconds:.6f}")
    print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
diff --git a/src/py/astar/astar_prototype_subset_detcost_lazy.py b/src/py/astar/astar_prototype_subset_detcost_lazy.py
new file mode 100644
index 0000000..8bff114
--- /dev/null
+++ b/src/py/astar/astar_prototype_subset_detcost_lazy.py
@@ -0,0 +1,1314 @@
#!/usr/bin/env python3
"""Prototype A* decoder with lazy subset-LP refinement.

Heuristic modes:
    --opt-subset-detcost-size 0   plain detcost
    --opt-subset-detcost-size 1   lazy optimal singleton LP
    --opt-subset-detcost-size 2   lazy optimal LP over size-1/2 subsets
    --opt-subset-detcost-size 3   lazy optimal LP over size-1/2/3 subsets

For subset size N > 0, the search uses lazy refinement:
  * nodes are first inserted using a cheap lower bound;
  * when popped, the exact subset LP is solved;
  * if the exact LP raises the node key, the node is reinserted;
  * expanded nodes project their exact subset-pattern prices onto children.

The projection step is the main subtlety relative to the singleton case.
The exact parent LP stores prices u_{S,t} for subset/pattern pairs. For a child,
we keep inherited u_{S,t} values on patterns still available, zero out patterns
that have become unavailable, assign zero to newly active subsets, and recompute
child y_S values as the minimum cost of a feasible local signature decomposition
under those inherited prices.
"""

from __future__ import annotations

import argparse
import heapq
import itertools
import json
import math
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import stim
from scipy import sparse
from scipy.optimize import linprog

INF = math.inf
# Small tolerance for heuristic comparisons; presumably used by the refinement
# logic further below (usage is outside this chunk) — TODO confirm.
HEURISTIC_EPS = 1e-9


@dataclass(frozen=True)
class MergedError:
    # Combined probability of every DEM error sharing this symptom class.
    probability: float
    # -log(p / (1 - p)); additive path cost of activating this error.
    likelihood_cost: float
    # Sorted detector ids flipped by this error.
    detectors: Tuple[int, ...]
    # Sorted logical observable ids flipped by this error.
    observables: Tuple[int, ...]


@dataclass
class DecoderData:
    """Preprocessed detector-error-model tables used throughout the decoder."""

    num_detectors: int
    num_observables: int
    errors: List[MergedError]
    # detector id -> indices of errors that touch that detector.
    detector_to_errors: List[List[int]]
    # Per-error likelihood costs, aligned with `errors`.
    error_costs: np.ndarray
    # Per-error detector tuples / frozensets / observable tuples (aligned).
    error_detectors: List[Tuple[int, ...]]
    error_detector_sets: List[frozenset[int]]
    error_observables: List[Tuple[int, ...]]


@dataclass(frozen=True)
class SubsetLibraryEntry:
    """One detector subset of the LP library."""

    subset_id: int
    # The subset's detector ids; bit i of every mask below refers to detectors[i].
    detectors: Tuple[int, ...]
    # intersection-pattern bitmask -> error indices realizing that pattern.
    pattern_to_errors: Dict[int, Tuple[int, ...]]
    # target bitmask -> minimal pattern combos whose XOR equals the target.
    resolution_combos: Dict[int, Tuple[Tuple[int, ...], ...]]


@dataclass
class ActiveSubsetRecord:
    """A library subset instantiated against one search node's residual state."""

    subset_id: int
    detectors: Tuple[int, ...]
    size: int
    # Bitmask of the subset's detectors that are currently active.
    target_mask: int
    # Patterns still realizable by at least one unblocked error.
    available_patterns: Dict[int, Tuple[int, ...]]
    # Resolution combos whose every pattern is still available.
    feasible_combos: Tuple[Tuple[int, ...], ...]


@dataclass
class SearchState:
    """One node of the A* search."""

    activated_errors: Tuple[int, ...]
    # Errors excluded by the precedence-based tree pruning.
    blocked_errors: np.ndarray
    active_detectors: np.ndarray
    # Per-error count of currently active detectors it touches.
    active_detector_counts: np.ndarray
    path_cost: float
    heuristic_cost: float
    # "plain" / "projected" / exact — which bound produced heuristic_cost.
    heuristic_source: str
    # True once the exact subset LP has been solved for this node.
    exact_refined: bool
    lp_solution: Optional["SubsetLPSolution"] = None


@dataclass
class DecodeStats:
    """Search and LP counters reported alongside a decoding."""

    num_pq_pushed: int
    num_nodes_popped: int
    max_queue_size: int
    heuristic_calls: int
    plain_heuristic_calls: int
    projection_heuristic_calls: int
    exact_refinement_calls: int
    lp_calls: int
    lp_reinserts: int
    projected_nodes_generated: int
    projected_nodes_refined: int
    projected_nodes_unrefined_at_finish: int
    total_lp_refinement_gain: float
    max_lp_refinement_gain: float
    lp_total_seconds: float
    elapsed_seconds: float
    heuristic_name: str


@dataclass
class DecodeResult:
    activated_errors: Tuple[int, ...]
    path_cost: float
    stats: DecodeStats


class UnionFind:
    """Disjoint-set forest with path halving and union by rank."""

    def __init__(self, size: int) -> None:
        self.parent = list(range(size))
        self.rank = [0] * size

    def find(self, x: int) -> int:
        # Path halving: point every other node at its grandparent.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        ra = self.find(a)
        rb = self.find(b)
        if ra == rb:
            return
        if self.rank[ra] < self.rank[rb]:
            self.parent[ra] = rb
        elif self.rank[ra] > self.rank[rb]:
            self.parent[rb] = ra
        else:
            self.parent[rb] = ra
            self.rank[ra] += 1


class LPLogger:
    """Appends one JSON object per logged LP solve to a JSONL file."""

    def __init__(self, path: Path, *, every: int = 1, top_k: int = 10) -> None:
        self.path = path
        self.every = max(1, every)
        self.top_k = max(1, top_k)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Truncate any previous log so each run starts fresh.
        self.path.write_text("")

    def maybe_log(self, *, call_index: int, payload: Dict[str, Any]) -> None:
        # Only every `self.every`-th call is written.
        if call_index % self.every != 0:
            return
        with self.path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, sort_keys=True) + "\n")


@dataclass
class SubsetLibrary:
    max_subset_size: int
    entries: List[SubsetLibraryEntry]
    # detector id -> ids of library subsets containing that detector.
    subsets_by_detector: List[List[int]]
    num_subsets_by_size: Dict[int, int]


@dataclass
class SubsetLPSolution:
    """Result of one exact subset-LP solve."""

    # LP optimum: the lower bound on the residual completion cost.
    value: float
    # subset_id -> {pattern mask -> allocated price u_{S,t}} (nonzero entries only).
    subset_u_values: Dict[int, Dict[int, float]]
    num_active_subsets: int
    num_components: int
    num_variables: int
    num_constraints: int


@dataclass
class SubsetLPSolverStats:
    lp_calls: int = 0
    lp_total_seconds: float = 0.0


class SubsetLPHeuristic:
    """Exact subset-based LP lower bound plus parent-to-child projection."""

    def __init__(
        self,
        data: DecoderData,
        subset_library: SubsetLibrary,
        *,
        logger: Optional[LPLogger] = None,
    ) -> None:
        self.data = data
        self.subset_library = subset_library
        self.logger = logger
        self.stats = SubsetLPSolverStats()

    def reset_stats(self) -> None:
        self.stats = SubsetLPSolverStats()

    def _collect_active_subset_records(
        self,
        active_detectors: np.ndarray,
        blocked_errors: np.ndarray,
    ) -> Tuple[Optional[List[ActiveSubsetRecord]], Optional[Dict[int, List[int]]]]:
        """Instantiate every library subset touching an active detector.

        Returns (records, error->record-positions), or (None, None) when some
        active subset has no feasible resolution combo left (structural
        infeasibility under the current blocking).
        """
        active_subset_ids: set[int] = set()
        for detector in np.flatnonzero(active_detectors):
            active_subset_ids.update(self.subset_library.subsets_by_detector[int(detector)])

        subset_records: List[ActiveSubsetRecord] = []
        error_to_subset_positions: Dict[int, List[int]] = defaultdict(list)

        for subset_id in sorted(active_subset_ids):
            entry = self.subset_library.entries[subset_id]
            # Bitmask of this subset's detectors that are currently active.
            target_mask = 0
            for bit_index, detector in enumerate(entry.detectors):
                if active_detectors[detector]:
                    target_mask |= 1 << bit_index
            if target_mask == 0:
                continue

            # Keep only patterns still realizable by some unblocked error.
            available_patterns: Dict[int, Tuple[int, ...]] = {}
            relevant_errors: set[int] = set()
            for pattern_mask, error_indices in entry.pattern_to_errors.items():
                kept = tuple(error_index for error_index in error_indices if not blocked_errors[error_index])
                if kept:
                    available_patterns[pattern_mask] = kept
                    relevant_errors.update(kept)

            feasible_combos = tuple(
                combo
                for combo in entry.resolution_combos.get(target_mask, ())
                if all(pattern_mask in available_patterns for pattern_mask in combo)
            )
            if not feasible_combos:
                # No way to resolve this subset's target: infeasible node.
                return None, None

            subset_position = len(subset_records)
            subset_records.append(
                ActiveSubsetRecord(
                    subset_id=subset_id,
                    detectors=entry.detectors,
                    size=len(entry.detectors),
                    target_mask=target_mask,
                    available_patterns=available_patterns,
                    feasible_combos=feasible_combos,
                )
            )
            for error_index in sorted(relevant_errors):
                error_to_subset_positions[error_index].append(subset_position)

        return subset_records, error_to_subset_positions

    def solve_exact(
        self,
        active_detectors: np.ndarray,
        blocked_errors: np.ndarray,
    ) -> Tuple[SubsetLPSolution, Dict[str, Any]]:
        """Solve the exact subset LP for one node.

        The LP is decomposed into connected components (subsets linked by a
        shared error), each solved independently with scipy's HiGHS backend.
        Returns (solution, log payload); an INF-valued solution marks
        structural infeasibility.
        """
        t0 = time.perf_counter()
        subset_records, error_to_subset_positions = self._collect_active_subset_records(
            active_detectors=active_detectors,
            blocked_errors=blocked_errors,
        )
        if subset_records is None:
            # Structurally infeasible: report an INF bound without solving.
            elapsed = time.perf_counter() - t0
            self.stats.lp_total_seconds += elapsed
            payload = {
                "objective": INF,
                "solve_seconds": elapsed,
                "num_active_subsets": 0,
                "num_components": 0,
                "num_variables": 0,
                "num_constraints": 0,
                "num_active_subsets_by_size": {},
                "contribution_by_subset_size": {},
                "allocated_budget_by_subset_size": {},
                "top_subsets": [],
                "structurally_infeasible": True,
            }
            return (
                SubsetLPSolution(
                    value=INF,
                    subset_u_values={},
                    num_active_subsets=0,
                    num_components=0,
                    num_variables=0,
                    num_constraints=0,
                ),
                payload,
            )

        if not subset_records:
            # No active subsets at all (empty residual syndrome): bound is 0.
            elapsed = time.perf_counter() - t0
            self.stats.lp_total_seconds += elapsed
            payload = {
                "objective": 0.0,
                "solve_seconds": elapsed,
                "num_active_subsets": 0,
                "num_components": 0,
                "num_variables": 0,
                "num_constraints": 0,
                "num_active_subsets_by_size": {},
                "contribution_by_subset_size": {},
                "allocated_budget_by_subset_size": {},
                "top_subsets": [],
                "structurally_infeasible": False,
            }
            return (
                SubsetLPSolution(
                    value=0.0,
                    subset_u_values={},
                    num_active_subsets=0,
                    num_components=0,
                    num_variables=0,
                    num_constraints=0,
                ),
                payload,
            )

        # Union subsets that share an error so each LP component is independent.
        component_uf = UnionFind(len(subset_records))
        for subset_positions in error_to_subset_positions.values():
            for position in subset_positions[1:]:
                component_uf.union(subset_positions[0], position)
        component_to_subset_positions: Dict[int, List[int]] = defaultdict(list)
        for subset_position in range(len(subset_records)):
            component_to_subset_positions[component_uf.find(subset_position)].append(subset_position)

        total_objective = 0.0
        total_num_variables = 0
        total_num_constraints = 0
        subset_u_values: Dict[int, Dict[int, float]] = {}
        contribution_by_size: Dict[int, float] = defaultdict(float)
        budget_by_size: Dict[int, float] = defaultdict(float)
        active_subset_count_by_size: Dict[int, int] = defaultdict(int)
        top_subset_records: List[Dict[str, Any]] = []
        need_log_details = self.logger is not None

        for component_positions in component_to_subset_positions.values():
            # Variable layout: one y per subset, then one u per (subset, pattern).
            y_var: Dict[int, int] = {}
            u_var: Dict[Tuple[int, int], int] = {}
            error_to_u_vars: Dict[int, List[int]] = defaultdict(list)

            next_var_index = 0
            for subset_position in component_positions:
                y_var[subset_position] = next_var_index
                next_var_index += 1
            for subset_position in component_positions:
                record = subset_records[subset_position]
                active_subset_count_by_size[record.size] += 1
                for pattern_mask, error_indices in sorted(record.available_patterns.items()):
                    variable_index = next_var_index
                    u_var[(subset_position, pattern_mask)] = variable_index
                    next_var_index += 1
                    for error_index in error_indices:
                        error_to_u_vars[error_index].append(variable_index)

            # Sparse A_ub in triplet form; rhs is b_ub.
            row_indices: List[int] = []
            col_indices: List[int] = []
            values: List[float] = []
            rhs: List[float] = []

            # Per-error budget rows: sum of u's charged to an error <= its cost.
            for error_index, variable_indices in sorted(error_to_u_vars.items()):
                row = len(rhs)
                rhs.append(float(self.data.error_costs[error_index]))
                for variable_index in variable_indices:
                    row_indices.append(row)
                    col_indices.append(variable_index)
                    values.append(1.0)

            # Per-combo rows: y_S <= sum of the combo's pattern prices.
            for subset_position in component_positions:
                record = subset_records[subset_position]
                y_index = y_var[subset_position]
                for combo in record.feasible_combos:
                    row = len(rhs)
                    rhs.append(0.0)
                    row_indices.append(row)
                    col_indices.append(y_index)
                    values.append(1.0)
                    for pattern_mask in combo:
                        row_indices.append(row)
                        col_indices.append(u_var[(subset_position, pattern_mask)])
                        values.append(-1.0)

            total_num_variables += next_var_index
            total_num_constraints += len(rhs)

            a_ub = sparse.csr_matrix(
                (values, (row_indices, col_indices)),
                shape=(len(rhs), next_var_index),
                dtype=np.float64,
            )
            # Maximize sum(y) == minimize -sum(y).
            objective = np.zeros(next_var_index, dtype=np.float64)
            for subset_position in component_positions:
                objective[y_var[subset_position]] = -1.0

            self.stats.lp_calls += 1
            result = linprog(
                c=objective,
                A_ub=a_ub,
                b_ub=np.asarray(rhs, dtype=np.float64),
                bounds=[(0.0, None)] * next_var_index,
                method="highs",
            )
            if not result.success:
                raise RuntimeError(
                    f"subset detcost LP solve failed: status={result.status} message={result.message}"
                )
            # Negate back to the maximization optimum.
            total_objective += float(-result.fun)
            solution = np.asarray(result.x, dtype=np.float64)

            for subset_position in component_positions:
                record = subset_records[subset_position]
                subset_pattern_values: Dict[int, float] = {}
                total_budget = 0.0
                for pattern_mask in sorted(record.available_patterns):
                    u_value = float(solution[u_var[(subset_position, pattern_mask)]])
                    total_budget += u_value
                    # Keep only numerically nonzero prices for projection.
                    if u_value > 1e-12:
                        subset_pattern_values[pattern_mask] = u_value
                if subset_pattern_values:
                    subset_u_values[record.subset_id] = subset_pattern_values

                if need_log_details:
                    y_value = float(solution[y_var[subset_position]])
                    contribution_by_size[record.size] += y_value
                    budget_by_size[record.size] += total_budget
                    pattern_values = [
                        {
                            "pattern_detectors": [
                                detector
                                for bit_index, detector in enumerate(record.detectors)
                                if pattern_mask & (1 << bit_index)
                            ],
                            "u": float(solution[u_var[(subset_position, pattern_mask)]]),
                            "num_allowed_errors": len(record.available_patterns[pattern_mask]),
                        }
                        for pattern_mask in sorted(record.available_patterns)
                        if solution[u_var[(subset_position, pattern_mask)]] > 1e-12
                    ]
                    top_subset_records.append(
                        {
                            "subset_detectors": list(record.detectors),
                            "subset_size": record.size,
                            "target_active_detectors": [
                                detector
                                for bit_index, detector in enumerate(record.detectors)
                                if record.target_mask & (1 << bit_index)
                            ],
                            "y": y_value,
                            "total_budget": total_budget,
                            "num_available_patterns": len(record.available_patterns),
                            "num_feasible_resolution_combos": len(record.feasible_combos),
                            "patterns": pattern_values,
                        }
                    )

        elapsed = time.perf_counter() - t0
        self.stats.lp_total_seconds += elapsed

        if need_log_details:
            # Largest y first; ties broken by budget, then detector tuple.
            top_subset_records.sort(key=lambda item: (-item["y"], -item["total_budget"], item["subset_detectors"]))
        payload = {
            "objective": total_objective,
            "solve_seconds": elapsed,
            "num_active_subsets": len(subset_records),
            "num_components": len(component_to_subset_positions),
            "num_variables": total_num_variables,
            "num_constraints": total_num_constraints,
            "num_active_subsets_by_size": {
                str(size): active_subset_count_by_size[size] for size in sorted(active_subset_count_by_size)
            },
            "contribution_by_subset_size": (
                {str(size): contribution_by_size[size] for size in sorted(contribution_by_size)}
                if need_log_details
                else {}
            ),
            "allocated_budget_by_subset_size": (
                {str(size): budget_by_size[size] for size in sorted(budget_by_size)}
                if need_log_details
                else {}
            ),
            "top_subsets": top_subset_records[: self.logger.top_k] if self.logger is not None else [],
            "structurally_infeasible": False,
        }
        return (
            SubsetLPSolution(
                value=total_objective,
                subset_u_values=subset_u_values,
                num_active_subsets=len(subset_records),
                num_components=len(component_to_subset_positions),
                num_variables=total_num_variables,
                num_constraints=total_num_constraints,
            ),
            payload,
        )

    def project_from_parent(
        self,
        parent_solution: SubsetLPSolution,
        child_active_detectors: np.ndarray,
        child_blocked_errors: np.ndarray,
    ) -> float:
        """Cheap child lower bound from the parent's exact LP prices.

        For each active subset, y is recomputed as the minimum inherited price
        over the still-feasible resolution combos; INF marks infeasibility.
        """
        total = 0.0
        active_subset_ids: set[int] = set()
        for detector in np.flatnonzero(child_active_detectors):
            active_subset_ids.update(self.subset_library.subsets_by_detector[int(detector)])

        for subset_id in sorted(active_subset_ids):
            entry = self.subset_library.entries[subset_id]
            target_mask = 0
            for bit_index, detector in enumerate(entry.detectors):
                if child_active_detectors[detector]:
                    target_mask |= 1 << bit_index
            if target_mask == 0:
                continue

            combos = entry.resolution_combos.get(target_mask, ())
            if not combos:
                return INF

            parent_u = parent_solution.subset_u_values.get(subset_id,
                {})  # (continues the .get(subset_id, ...) call opened above)
            availability_cache: Dict[int, bool] = {}
            best = INF
            for combo in combos:
                combo_sum = 0.0
                feasible = True
                for pattern_mask in combo:
                    # A pattern is available iff some unblocked error realizes it.
                    is_available = availability_cache.get(pattern_mask)
                    if is_available is None:
                        is_available = any(
                            not child_blocked_errors[error_index]
                            for error_index in entry.pattern_to_errors.get(pattern_mask, ())
                        )
                        availability_cache[pattern_mask] = is_available
                    if not is_available:
                        feasible = False
                        break
                    # Patterns the parent did not price contribute 0.
                    combo_sum += parent_u.get(pattern_mask, 0.0)
                if feasible and combo_sum < best:
                    best = combo_sum
            if best == INF:
                return INF
            total += best

        return total


def parse_beam(text: str) -> float:
    """argparse type: non-negative integer or the string 'inf'."""
    lowered = text.strip().lower()
    if lowered in {"inf", "+inf", "infinity", "+infinity"}:
        return INF
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'")
    return float(value)


def format_indices(indices: Iterable[int], prefix: str) -> str:
    """Render indices as e.g. 'D0 D3 D7', or '(none)' when empty."""
    items = list(indices)
    if not items:
        return "(none)"
    return " ".join(f"{prefix}{i}" for i in items)


def xor_probability(p0: float, p1: float) -> float:
    """Probability that exactly one of two independent events occurs."""
    return p0 * (1 - p1) + (1 - p0) * p1


def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]:
    """Yield one MergedError per positive-probability DEM error instruction.

    Repeated detector/observable targets within one instruction cancel (XOR).
    Raises ValueError for probabilities >= 0.5 or unexpected target kinds.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                "This prototype assumes detector-error-model probabilities are in (0, 0.5)."
            )
        detectors: set[int] = set()
        observables: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # XOR semantics: a repeated target cancels itself.
                if target.val in observables:
                    observables.remove(target.val)
                else:
                    observables.add(target.val)
            else:
                if not target.is_relative_detector_id():
                    raise ValueError(f"Unexpected DEM target: {target!r}")
                if target.val in detectors:
                    detectors.remove(target.val)
                else:
                    detectors.add(target.val)
        yield MergedError(
            probability=probability,
            likelihood_cost=float(-math.log(probability / (1 - probability))),
            detectors=tuple(sorted(detectors)),
            observables=tuple(sorted(observables)),
        )


def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]:
    """Merge errors with identical (detectors, observables) symptoms.

    Probabilities combine with the XOR rule, since either error alone flips
    the same symptom set.
    """
    errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors(dem):
        key = (error.detectors, error.observables)
        previous = errors_by_symptom.get(key)
        if previous is None:
            errors_by_symptom[key] = error.probability
        else:
            errors_by_symptom[key] = xor_probability(previous, error.probability)

    merged: List[MergedError] = []
    for (detectors, observables), probability in errors_by_symptom.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                "Merged error has probability >= 0.5, which would give a non-positive cost."
            )
        merged.append(
            MergedError(
                probability=probability,
                likelihood_cost=float(-math.log(probability / (1 - probability))),
                detectors=detectors,
                observables=observables,
            )
        )
    return merged


def build_decoder_data(
    dem: stim.DetectorErrorModel,
    *,
    merge_errors_in_dem: bool = True,
) -> DecoderData:
    """Build the decoder's lookup tables from a DEM (optionally merged)."""
    errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem))
    detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)]
    for ei, error in enumerate(errors):
        for d in error.detectors:
            detector_to_errors[d].append(ei)
    return DecoderData(
        num_detectors=dem.num_detectors,
        num_observables=dem.num_observables,
        errors=errors,
        detector_to_errors=detector_to_errors,
        error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64),
        error_detectors=[e.detectors for e in errors],
        error_detector_sets=[frozenset(e.detectors) for e in errors],
        error_observables=[e.observables for e in errors],
    )


def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray:
    """Expand little-endian bit-packed sampler rows to a boolean matrix."""
    return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False)


def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray:
    """Per-error count of active detectors it touches, for the root state."""
    counts = np.zeros(len(data.errors), dtype=np.int32)
    for d in np.flatnonzero(active_detectors):
        for ei in data.detector_to_errors[int(d)]:
            counts[ei] += 1
    return counts


def apply_error(
    data: DecoderData,
    active_detectors: np.ndarray,
    active_detector_counts: np.ndarray,
    error_index: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return copies of (active_detectors, counts) after XOR-ing one error in.

    Each toggled detector also adjusts the active-detector count of every
    error touching it by +/-1.
    """
    next_detectors = active_detectors.copy()
    next_counts = active_detector_counts.copy()
    for d in data.error_detectors[error_index]:
        if next_detectors[d]:
            next_detectors[d] = False
            delta = -1
        else:
            next_detectors[d] = True
            delta = 1
        for other_error_index in data.detector_to_errors[d]:
            next_counts[other_error_index] += delta
    return next_detectors, next_counts
def plain_detcost_for_detector(
    data: DecoderData,
    detector: int,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> float:
    """Cheapest cost-per-covered-detector among this detector's unblocked errors.

    Each candidate error's likelihood cost is divided by the number of active
    detectors it currently touches; the minimum ratio is returned, or inf when
    every candidate is blocked.
    """
    ratios: List[float] = []
    for error_index in data.detector_to_errors[detector]:
        if not blocked_errors[error_index]:
            touched = int(active_detector_counts[error_index])
            # Any error reachable from an active detector touches >= 1 of them.
            assert touched > 0
            ratios.append(float(data.error_costs[error_index]) / touched)
    return min(ratios, default=math.inf)


def plain_detcost_heuristic(
    data: DecoderData,
    active_detectors: np.ndarray,
    blocked_errors: np.ndarray,
    active_detector_counts: np.ndarray,
) -> float:
    """Sum of per-detector lower bounds; inf if any detector is uncoverable."""
    total_bound = 0.0
    for detector_id in np.flatnonzero(active_detectors):
        per_detector = plain_detcost_for_detector(
            data=data,
            detector=int(detector_id),
            blocked_errors=blocked_errors,
            active_detector_counts=active_detector_counts,
        )
        if per_detector == math.inf:
            # One uncoverable detector makes the whole residual infeasible.
            return math.inf
        total_bound += per_detector
    return total_bound


def compute_minimal_resolution_combos(
    available_pattern_masks: Iterable[int],
    subset_size: int,
) -> Dict[int, Tuple[Tuple[int, ...], ...]]:
    """Map each nonzero target mask to its inclusion-minimal pattern combos.

    A combo is a sorted tuple of distinct pattern masks whose XOR equals the
    target. Combos whose pattern set strictly contains another kept combo's
    set are pruned; survivors are ordered by (length, lexicographic).
    """
    unique_patterns = tuple(sorted(set(available_pattern_masks)))
    minimal: Dict[int, List[Tuple[int, ...]]] = {
        target: [] for target in range(1, 1 << subset_size)
    }
    longest = min(len(unique_patterns), subset_size)
    for combo_len in range(1, longest + 1):
        for candidate in itertools.combinations(unique_patterns, combo_len):
            xor_total = 0
            for mask in candidate:
                xor_total ^= mask
            if xor_total == 0:
                continue
            candidate_set = set(candidate)
            dominated = False
            survivors: List[Tuple[int, ...]] = []
            for prior in minimal[xor_total]:
                prior_set = set(prior)
                if candidate_set.issuperset(prior_set):
                    # Candidate uses a superset of an existing combo: dominated.
                    dominated = True
                    survivors.append(prior)
                elif prior_set.issuperset(candidate_set):
                    # Existing combo is dominated by the candidate: drop it.
                    continue
                else:
                    survivors.append(prior)
            if not dominated:
                survivors.append(candidate)
            survivors.sort(key=lambda c: (len(c), c))
            minimal[xor_total] = survivors
    return {
        target: tuple(combos)
        for target, combos in minimal.items()
        if combos
    }


def build_subset_library(data: DecoderData, max_subset_size: int) -> SubsetLibrary:
    """Enumerate detector subsets for the LP heuristic.

    Included subsets: every singleton detector, plus every size <= N subset of
    some single error's detector set. Entries are ordered by (size, detectors).
    """
    library_keys: set[Tuple[int, ...]] = set()
    if max_subset_size >= 1:
        for detector in range(data.num_detectors):
            library_keys.add((detector,))

    for detectors in data.error_detectors:
        limit = min(max_subset_size, len(detectors))
        for subset_size in range(1, limit + 1):
            for subset_detectors in itertools.combinations(detectors, subset_size):
                library_keys.add(tuple(subset_detectors))

    subsets_by_detector: List[List[int]] = [[] for _ in range(data.num_detectors)]
    entries: List[SubsetLibraryEntry] = []
    num_subsets_by_size: Dict[int, int] = defaultdict(int)

    for subset_id, subset_detectors in enumerate(sorted(library_keys, key=lambda t: (len(t), t))):
        # Group every error by its intersection pattern with this subset.
        pattern_to_errors: Dict[int, List[int]] = defaultdict(list)
        for error_index, detector_set in enumerate(data.error_detector_sets):
            pattern_mask = 0
            for bit_index, detector in enumerate(subset_detectors):
                if detector in detector_set:
                    pattern_mask |= 1 << bit_index
            if pattern_mask != 0:
                pattern_to_errors[pattern_mask].append(error_index)
        frozen_pattern_to_errors = {
            pattern_mask: tuple(error_indices)
            for pattern_mask, error_indices in pattern_to_errors.items()
        }
        entry = SubsetLibraryEntry(
            subset_id=subset_id,
            detectors=subset_detectors,
            pattern_to_errors=frozen_pattern_to_errors,
            resolution_combos=compute_minimal_resolution_combos(
                available_pattern_masks=frozen_pattern_to_errors.keys(),
                subset_size=len(subset_detectors),
            ),
        )
        entries.append(entry)
        num_subsets_by_size[len(subset_detectors)] += 1
        for detector in subset_detectors:
            subsets_by_detector[detector].append(subset_id)

    return SubsetLibrary(
        max_subset_size=max_subset_size,
        entries=entries,
        subsets_by_detector=subsets_by_detector,
        num_subsets_by_size=dict(sorted(num_subsets_by_size.items())),
    )


def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray:
    # (body of detectors_from_solution, whose def line is in the previous chunk)
    # XOR together the detector footprints of the chosen errors.
    detectors = np.zeros(data.num_detectors, dtype=bool)
    for error_index in activated_errors:
        for detector in data.error_detectors[error_index]:
            detectors[detector] ^= True
    return detectors


def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray:
    """XOR of the observable footprints of the chosen errors."""
    observables = np.zeros(data.num_observables, dtype=bool)
    for error_index in activated_errors:
        for observable in data.error_observables[error_index]:
            observables[observable] ^= True
    return observables


def decode(
    data: DecoderData,
    detections: np.ndarray,
    *,
    det_beam: float = INF,
    opt_subset_solver: Optional[SubsetLPHeuristic] = None,
    verbose_search: bool = False,
) -> DecodeResult:
    """A* search with lazy subset-LP refinement (body continues past this chunk)."""
    start_time = time.perf_counter()
    if opt_subset_solver is not None:
        opt_subset_solver.reset_stats()

    # Search/LP bookkeeping counters reported in DecodeStats.
    heuristic_calls = 0
    plain_heuristic_calls = 0
    projection_heuristic_calls = 0
    exact_refinement_calls = 0
    lp_reinserts = 0
    projected_nodes_generated = 0
    projected_nodes_refined = 0
    total_lp_refinement_gain = 0.0
    max_lp_refinement_gain = 0.0

    initial_active_detectors = np.asarray(detections, dtype=bool).copy()
    initial_counts = initial_detector_counts(data, initial_active_detectors)
    initial_blocked = np.zeros(len(data.errors), dtype=bool)
    # The root node always starts from the cheap plain detcost bound.
    heuristic_calls += 1
    plain_heuristic_calls += 1
    initial_heuristic = plain_detcost_heuristic(
        data=data,
        active_detectors=initial_active_detectors,
        blocked_errors=initial_blocked,
        active_detector_counts=initial_counts,
    )
    if initial_heuristic == INF:
        raise RuntimeError("Initial residual syndrome is infeasible under the current pruning rule.")

    initial_state = SearchState(
        activated_errors=(),
        blocked_errors=initial_blocked,
        active_detectors=initial_active_detectors,
        active_detector_counts=initial_counts,
        path_cost=0.0,
        heuristic_cost=initial_heuristic,
        heuristic_source="plain",
        # With no LP solver there is nothing to refine lazily.
        exact_refined=(opt_subset_solver is None),
        lp_solution=None,
    )

    priority_queue:
List[Tuple[float, int, int, SearchState]] = [] + push_counter = 0 + initial_num_dets = int(initial_active_detectors.sum()) + heapq.heappush( + priority_queue, + (initial_state.path_cost + initial_state.heuristic_cost, initial_num_dets, push_counter, initial_state), + ) + push_counter += 1 + + num_pq_pushed = 1 + num_nodes_popped = 0 + max_queue_size = 1 + min_num_dets = initial_num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + heuristic_name = ( + f"opt_subset_detcost_size_{opt_subset_solver.subset_library.max_subset_size}_lazy_projection" + if opt_subset_solver is not None + else "plain_detcost" + ) + + while priority_queue: + max_queue_size = max(max_queue_size, len(priority_queue)) + f_cost, num_dets, _, state = heapq.heappop(priority_queue) + num_nodes_popped += 1 + + if num_dets > max_num_dets: + continue + + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if det_beam == INF else min_num_dets + det_beam + + if verbose_search: + print( + f"nodes_popped={num_nodes_popped} len(pq)={len(priority_queue)} " + f"lp_calls={0 if opt_subset_solver is None else opt_subset_solver.stats.lp_calls} " + f"lp_reinserts={lp_reinserts} proj_generated={projected_nodes_generated} " + f"proj_refined={projected_nodes_refined} " + f"proj_unrefined_so_far={projected_nodes_generated - projected_nodes_refined} " + f"active_dets={num_dets} beam_max={max_num_dets} depth={len(state.activated_errors)} " + f"f={f_cost:.12g} g={state.path_cost:.12g} h={state.heuristic_cost:.12g} " + f"h_source={state.heuristic_source} exact_refined={state.exact_refined}" + ) + + if num_dets == 0: + elapsed_seconds = time.perf_counter() - start_time + lp_calls = 0 if opt_subset_solver is None else opt_subset_solver.stats.lp_calls + lp_total_seconds = 0.0 if opt_subset_solver is None else opt_subset_solver.stats.lp_total_seconds + return DecodeResult( + activated_errors=state.activated_errors, + path_cost=state.path_cost, + stats=DecodeStats( + 
num_pq_pushed=num_pq_pushed, + num_nodes_popped=num_nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=heuristic_calls, + plain_heuristic_calls=plain_heuristic_calls, + projection_heuristic_calls=projection_heuristic_calls, + exact_refinement_calls=exact_refinement_calls, + lp_calls=lp_calls, + lp_reinserts=lp_reinserts, + projected_nodes_generated=projected_nodes_generated, + projected_nodes_refined=projected_nodes_refined, + projected_nodes_unrefined_at_finish=projected_nodes_generated - projected_nodes_refined, + total_lp_refinement_gain=total_lp_refinement_gain, + max_lp_refinement_gain=max_lp_refinement_gain, + lp_total_seconds=lp_total_seconds, + elapsed_seconds=elapsed_seconds, + heuristic_name=heuristic_name, + ), + ) + + if opt_subset_solver is not None and not state.exact_refined: + heuristic_calls += 1 + exact_refinement_calls += 1 + previous_h = state.heuristic_cost + previous_source = state.heuristic_source + exact_solution, exact_payload = opt_subset_solver.solve_exact( + active_detectors=state.active_detectors, + blocked_errors=state.blocked_errors, + ) + exact_h = exact_solution.value + reinserted = False + discarded = False + + if exact_h == INF: + discarded = True + if previous_source == "projected": + projected_nodes_refined += 1 + else: + if exact_h + 1e-7 < previous_h: + raise AssertionError( + f"Exact subset LP lower bound {exact_h} is below stored {previous_source} lower bound {previous_h}." 
+ ) + delta = exact_h - previous_h + total_lp_refinement_gain += delta + max_lp_refinement_gain = max(max_lp_refinement_gain, delta) + state.heuristic_cost = exact_h + state.heuristic_source = "exact" + state.exact_refined = True + state.lp_solution = exact_solution + if previous_source == "projected": + projected_nodes_refined += 1 + if delta > HEURISTIC_EPS: + reinserted = True + lp_reinserts += 1 + heapq.heappush( + priority_queue, + (state.path_cost + state.heuristic_cost, num_dets, push_counter, state), + ) + push_counter += 1 + + if opt_subset_solver.logger is not None: + payload = dict(exact_payload) + payload.update( + { + "call_index": exact_refinement_calls, + "phase": "exact_refinement", + "depth": len(state.activated_errors), + "nodes_popped": num_nodes_popped, + "path_cost": state.path_cost, + "active_detector_count": num_dets, + "approx_h": previous_h, + "exact_h": exact_h, + "delta": INF if exact_h == INF else exact_h - previous_h, + "heuristic_source_before": previous_source, + "reinserted": reinserted, + "discarded": discarded, + } + ) + opt_subset_solver.logger.maybe_log(call_index=exact_refinement_calls, payload=payload) + + if verbose_search: + delta_text = "INF" if exact_h == INF else f"{exact_h - previous_h:.12g}" + exact_text = "INF" if exact_h == INF else f"{exact_h:.12g}" + print( + f" lp_refine approx_h={previous_h:.12g} exact_h={exact_text} delta={delta_text} " + f"vars={exact_solution.num_variables} constraints={exact_solution.num_constraints} " + f"active_subsets={exact_solution.num_active_subsets} comps={exact_solution.num_components} " + f"reinserted={reinserted} discarded={discarded}" + ) + + if discarded or reinserted: + continue + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + children_generated = 0 + children_projected = 0 + children_beam_pruned = 0 + children_infeasible = 0 + + for error_index in data.detector_to_errors[min_detector]: + blocked_prefix[error_index] 
= True + if state.blocked_errors[error_index]: + continue + + child_active_detectors, child_active_counts = apply_error( + data=data, + active_detectors=state.active_detectors, + active_detector_counts=state.active_detector_counts, + error_index=error_index, + ) + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_blocked = blocked_prefix.copy() + child_path_cost = state.path_cost + float(data.error_costs[error_index]) + + if opt_subset_solver is None: + heuristic_calls += 1 + plain_heuristic_calls += 1 + child_heuristic = plain_detcost_heuristic( + data=data, + active_detectors=child_active_detectors, + blocked_errors=child_blocked, + active_detector_counts=child_active_counts, + ) + child_source = "plain" + child_exact_refined = True + child_lp_solution = None + else: + if state.lp_solution is None: + raise AssertionError("Subset-LP projection requires an exact-refined parent solution.") + heuristic_calls += 1 + projection_heuristic_calls += 1 + projected_nodes_generated += 1 + children_projected += 1 + child_heuristic = opt_subset_solver.project_from_parent( + parent_solution=state.lp_solution, + child_active_detectors=child_active_detectors, + child_blocked_errors=child_blocked, + ) + child_source = "projected" + child_exact_refined = False + child_lp_solution = None + + if child_heuristic == INF: + children_infeasible += 1 + continue + + child_state = SearchState( + activated_errors=state.activated_errors + (error_index,), + blocked_errors=child_blocked, + active_detectors=child_active_detectors, + active_detector_counts=child_active_counts, + path_cost=child_path_cost, + heuristic_cost=child_heuristic, + heuristic_source=child_source, + exact_refined=child_exact_refined, + lp_solution=child_lp_solution, + ) + heapq.heappush( + priority_queue, + (child_path_cost + child_heuristic, child_num_dets, push_counter, child_state), + ) + push_counter += 1 + num_pq_pushed += 1 + 
children_generated += 1 + + if verbose_search: + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible}" + ) + + raise RuntimeError("Decoding failed to find any completion.") + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim detector error models. " + "Supports plain detcost and lazy subset-based LP lower bounds." + ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a Stim circuit file.") + parser.add_argument( + "--shot", + type=int, + default=0, + help="Zero-based sampled shot index to decode.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample before selecting --shot.", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Seed passed to stim.compile_detector_sampler(...).sample(...).", + ) + parser.add_argument( + "--det-beam", + type=parse_beam, + default=INF, + help="Beam cutoff on the residual detector count. Use an integer or 'inf'.", + ) + parser.add_argument( + "--opt-subset-detcost-size", + type=int, + default=0, + help=( + "Use the lazy subset-based LP heuristic with library subsets of size at most N. " + "Use 0 for plain detcost, 1 for the optimal singleton LP, etc." 
+ ), + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the sampled shot's active detector IDs before decoding.", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the activated error indices in the final decoding.", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + parser.add_argument( + "--lp-log-path", + type=Path, + default=None, + help="Optional JSONL file for logging details of each exact subset-LP refinement.", + ) + parser.add_argument( + "--lp-log-top-k", + type=int, + default=10, + help="When logging exact LP refinements, include at most this many top subsets.", + ) + parser.add_argument( + "--lp-log-every", + type=int, + default=1, + help="When logging exact LP refinements, only write every k-th refinement.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.opt_subset_detcost_size < 0: + parser.error("--opt-subset-detcost-size must be non-negative.") + if args.lp_log_every <= 0: + parser.error("--lp-log-every must be positive.") + if args.lp_log_top_k <= 0: + parser.error("--lp-log-top-k must be positive.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + + subset_library = None + subset_solver = None + if args.opt_subset_detcost_size > 0: + subset_library = 
build_subset_library(data, args.opt_subset_detcost_size) + lp_logger = None + if args.lp_log_path is not None: + lp_logger = LPLogger( + args.lp_log_path, + every=args.lp_log_every, + top_k=args.lp_log_top_k, + ) + subset_solver = SubsetLPHeuristic(data, subset_library, logger=lp_logger) + + dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.sample_num_shots, + separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + + if args.shot >= detections.shape[0]: + parser.error(f"--shot={args.shot} is out of range for {detections.shape[0]} sampled shots.") + + shot_detections = detections[args.shot] + shot_observables = observables[args.shot] if observables.size else np.zeros(0, dtype=bool) + + print(f"circuit = {args.circuit}") + print( + "heuristic = " + + ( + "plain_detcost" + if subset_solver is None + else f"opt_subset_detcost_size_{subset_library.max_subset_size}_lazy_projection" + ) + ) + print(f"shot = {args.shot}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"num_detectors = {data.num_detectors}") + print(f"num_observables = {data.num_observables}") + print(f"num_errors = {len(data.errors)}") + print(f"beam = {args.det_beam}") + if subset_library is not None: + print(f"subset_library_size = {len(subset_library.entries)}") + print( + "subset_library_by_size = " + + ", ".join( + f"{size}:{count}" for size, count in subset_library.num_subsets_by_size.items() + ) + ) + if args.show_shot_detectors: + print(f"shot_detectors = {format_indices(np.flatnonzero(shot_detections), 'D')}") + + result = decode( + data=data, + detections=shot_detections, + det_beam=args.det_beam, + opt_subset_solver=subset_solver, + verbose_search=args.verbose_search, + ) + + predicted_observables = observables_from_solution(data, result.activated_errors) + reproduced_detectors = 
detectors_from_solution(data, result.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError("Decoded error set does not reproduce the shot's syndrome.") + + print(f"solution_size = {len(result.activated_errors)}") + print(f"solution_cost = {result.path_cost:.12g}") + if args.show_error_indices: + print(f"activated_errors = {format_indices(result.activated_errors, 'E')}") + print(f"predicted_observables = {format_indices(np.flatnonzero(predicted_observables), 'L')}") + print(f"sample_observables = {format_indices(np.flatnonzero(shot_observables), 'L')}") + print(f"observables_match = {bool(np.array_equal(predicted_observables, shot_observables))}") + print(f"num_pq_pushed = {result.stats.num_pq_pushed}") + print(f"num_nodes_popped = {result.stats.num_nodes_popped}") + print(f"max_queue_size = {result.stats.max_queue_size}") + print(f"heuristic_calls = {result.stats.heuristic_calls}") + print(f"plain_heuristic_calls = {result.stats.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.stats.projection_heuristic_calls}") + print(f"exact_refinement_calls = {result.stats.exact_refinement_calls}") + print(f"lp_calls = {result.stats.lp_calls}") + print(f"lp_reinserts = {result.stats.lp_reinserts}") + print(f"projected_nodes_generated = {result.stats.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.stats.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.stats.projected_nodes_unrefined_at_finish}") + print(f"total_lp_refinement_gain = {result.stats.total_lp_refinement_gain:.12g}") + print(f"max_lp_refinement_gain = {result.stats.max_lp_refinement_gain:.12g}") + print(f"lp_total_seconds = {result.stats.lp_total_seconds:.6f}") + print(f"elapsed_seconds = {result.stats.elapsed_seconds:.6f}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_singleton_lp_probe.py 
b/src/py/astar/astar_singleton_lp_probe.py new file mode 100644 index 0000000..8b321e0 --- /dev/null +++ b/src/py/astar/astar_singleton_lp_probe.py @@ -0,0 +1,1276 @@ +#!/usr/bin/env python3 +"""Instrumented A* prototype for studying the optimal singleton LP heuristic. + +This script is intentionally data-heavy and not heavily optimized. It decodes a +set of Stim circuits, samples several shots from each, and writes detailed logs +about every heuristic evaluation during search. + +Outputs (written under --output-dir): + manifest.json + shot_summaries.jsonl + node_summaries.jsonl.gz + component_summaries.jsonl.gz + sampled_instances.jsonl.gz + +The node/component logs are designed to answer questions such as: + * How often is the singleton LP graphlike (all distinct supports have size <= 2)? + * How many connected components does the residual support hypergraph have? + * How many raw allowed errors collapse to the same distinct active support? + * How sparse are primal/dual LP solutions? + * Are graphlike components common enough to justify a specialized solver? + +The search tree uses the same precedence-style pruning idea as the prototype and +Tesseract paper: at each node, only errors incident to the minimum active +residual detector are expanded, with earlier siblings blocked to keep a unique +path ordering. The A* heuristic can be plain detcost or the optimal singleton +LP; both values are logged for every created node. 
+""" + +from __future__ import annotations + +import argparse +import gzip +import heapq +import json +import math +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy import sparse +from scipy.optimize import linprog + +INF = math.inf +JSON_SEPARATORS = (",", ":") +LP_TOL = 1e-9 +RATIONAL_TOL = 1e-7 + + +@dataclass(frozen=True) +class MergedError: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class DecoderData: + num_detectors: int + num_observables: int + errors: List[MergedError] + detector_to_errors: List[List[int]] + error_costs: np.ndarray + error_detectors: List[Tuple[int, ...]] + error_observables: List[Tuple[int, ...]] + + +@dataclass +class SearchSettings: + det_beam: float + search_heuristic: str + respect_blocked_errors_in_heuristic: bool + max_nodes_popped: Optional[int] + max_nodes_pushed: Optional[int] + sample_raw_nodes_per_shot: int + verbose_search: bool + + +@dataclass +class SearchState: + node_id: int + parent_node_id: Optional[int] + incoming_error_index: Optional[int] + depth: int + activated_errors: Tuple[int, ...] 
+ activated_error_mask: np.ndarray + blocked_errors: np.ndarray + active_detectors: np.ndarray + active_detector_counts: np.ndarray + path_cost: float + search_h: float + plain_h: float + opt_h: float + + +class JsonlWriter: + def __init__(self, path: Path, *, gz: bool = False): + self.path = path + path.parent.mkdir(parents=True, exist_ok=True) + if gz: + self.file = gzip.open(path, "wt", encoding="utf-8") + else: + self.file = open(path, "wt", encoding="utf-8") + + def write(self, record: Dict[str, Any]) -> None: + self.file.write(json.dumps(record, separators=JSON_SEPARATORS, sort_keys=True)) + self.file.write("\n") + + def flush(self) -> None: + self.file.flush() + + def close(self) -> None: + self.file.close() + + +class UnionFind: + def __init__(self, size: int) -> None: + self.parent = list(range(size)) + self.rank = [0] * size + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + +class ShotAggregator: + def __init__(self) -> None: + self.nodes_created = 0 + self.nodes_pushed = 0 + self.nodes_infeasible = 0 + self.nodes_graphlike = 0 + self.nodes_with_lp = 0 + self.total_plain_h = 0.0 + self.total_opt_h = 0.0 + self.total_h_gain = 0.0 + self.total_lp_time_sec = 0.0 + self.total_lp_vars = 0 + self.total_lp_constraints = 0 + self.total_raw_allowed_errors = 0 + self.total_distinct_supports = 0 + self.total_components = 0 + self.total_graphlike_components = 0 + self.max_active_detectors = 0 + self.max_distinct_supports = 0 + self.max_component_variables = 0 + self.max_component_constraints = 0 + + def absorb_node(self, node_record: Dict[str, Any]) -> None: + self.nodes_created += 1 + 
self.nodes_pushed += int(bool(node_record["pushed"])) + self.nodes_infeasible += int(bool(node_record["opt_infeasible"])) + self.nodes_graphlike += int(bool(node_record["graphlike_all_components"])) + self.nodes_with_lp += int(node_record["lp_calls"] > 0) + self.total_plain_h += float(node_record["plain_h"]) + if not node_record["opt_infeasible"]: + self.total_opt_h += float(node_record["opt_h"]) + self.total_h_gain += float(node_record["opt_h"] - node_record["plain_h"]) + self.total_lp_time_sec += float(node_record["lp_time_sec"]) + self.total_lp_vars += int(node_record["total_lp_vars"]) + self.total_lp_constraints += int(node_record["total_lp_constraints"]) + self.total_raw_allowed_errors += int(node_record["raw_allowed_errors"]) + self.total_distinct_supports += int(node_record["distinct_supports"]) + self.total_components += int(node_record["num_components"]) + self.total_graphlike_components += int(node_record["num_graphlike_components"]) + self.max_active_detectors = max(self.max_active_detectors, int(node_record["num_active_detectors"])) + self.max_distinct_supports = max(self.max_distinct_supports, int(node_record["distinct_supports"])) + self.max_component_variables = max(self.max_component_variables, int(node_record["max_component_variables"])) + self.max_component_constraints = max(self.max_component_constraints, int(node_record["max_component_constraints"])) + + def finish(self, *, nodes_popped: int, status: str, elapsed_seconds: float) -> Dict[str, Any]: + n = max(self.nodes_created, 1) + c = max(self.total_components, 1) + return { + "status": status, + "nodes_created": self.nodes_created, + "nodes_pushed": self.nodes_pushed, + "nodes_popped": nodes_popped, + "nodes_infeasible": self.nodes_infeasible, + "graphlike_node_fraction": self.nodes_graphlike / n, + "mean_plain_h": self.total_plain_h / n, + "mean_opt_h_over_feasible": (self.total_opt_h / max(self.nodes_created - self.nodes_infeasible, 1)), + "mean_opt_minus_plain_over_feasible": 
(self.total_h_gain / max(self.nodes_created - self.nodes_infeasible, 1)), + "total_lp_time_sec": self.total_lp_time_sec, + "mean_lp_time_per_created_node_sec": self.total_lp_time_sec / n, + "mean_lp_vars_per_created_node": self.total_lp_vars / n, + "mean_lp_constraints_per_created_node": self.total_lp_constraints / n, + "mean_raw_allowed_errors": self.total_raw_allowed_errors / n, + "mean_distinct_supports": self.total_distinct_supports / n, + "mean_components": self.total_components / n, + "graphlike_component_fraction": self.total_graphlike_components / c, + "max_active_detectors": self.max_active_detectors, + "max_distinct_supports": self.max_distinct_supports, + "max_component_variables": self.max_component_variables, + "max_component_constraints": self.max_component_constraints, + "elapsed_seconds": elapsed_seconds, + } + + +class NodeSampler: + def __init__(self, sample_raw_nodes_per_shot: int): + self.sample_raw_nodes_per_shot = sample_raw_nodes_per_shot + self.seen = 0 + + def should_sample(self, node_id: int) -> bool: + del node_id + if self.seen < self.sample_raw_nodes_per_shot: + self.seen += 1 + return True + return False + + +class ProbeLogger: + def __init__(self, output_dir: Path): + self.output_dir = output_dir + self.shot_writer = JsonlWriter(output_dir / "shot_summaries.jsonl", gz=False) + self.node_writer = JsonlWriter(output_dir / "node_summaries.jsonl.gz", gz=True) + self.component_writer = JsonlWriter(output_dir / "component_summaries.jsonl.gz", gz=True) + self.sample_writer = JsonlWriter(output_dir / "sampled_instances.jsonl.gz", gz=True) + + def close(self) -> None: + self.shot_writer.close() + self.node_writer.close() + self.component_writer.close() + self.sample_writer.close() + + def flush(self) -> None: + self.shot_writer.flush() + self.node_writer.flush() + self.component_writer.flush() + self.sample_writer.flush() + + +def parse_optional_int(text: str) -> Optional[int]: + lowered = text.strip().lower() + if lowered in {"none", "inf", 
"infinity", "+inf", "+infinity"}: + return None + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("must be non-negative or one of: none, inf") + return value + + +def parse_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "+inf", "+infinity"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("beam must be non-negative or 'inf'") + return float(value) + + +def xor_probability(p0: float, p1: float) -> float: + return p0 * (1 - p1) + (1 - p0) * p1 + + +def iter_dem_errors(dem: stim.DetectorErrorModel) -> Iterable[MergedError]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + "This prototype assumes detector-error-model probabilities are in (0, 0.5)." + ) + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + assert target.is_relative_detector_id() + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + yield MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors(dem: stim.DetectorErrorModel) -> List[MergedError]: + probabilities: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors(dem): + key = (error.detectors, error.observables) + prev = probabilities.get(key) + probabilities[key] = error.probability if prev is None else xor_probability(prev, error.probability) + + out: List[MergedError] = [] + for (detectors, observables), 
probability in probabilities.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError("Merged error has probability >= 0.5.") + out.append( + MergedError( + probability=probability, + likelihood_cost=float(-math.log(probability / (1 - probability))), + detectors=detectors, + observables=observables, + ) + ) + return out + + +def build_decoder_data(dem: stim.DetectorErrorModel, *, merge_errors_in_dem: bool = True) -> DecoderData: + errors = merged_errors(dem) if merge_errors_in_dem else list(iter_dem_errors(dem)) + detector_to_errors: List[List[int]] = [[] for _ in range(dem.num_detectors)] + for ei, error in enumerate(errors): + for d in error.detectors: + detector_to_errors[d].append(ei) + return DecoderData( + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + errors=errors, + detector_to_errors=detector_to_errors, + error_costs=np.asarray([e.likelihood_cost for e in errors], dtype=np.float64), + error_detectors=[e.detectors for e in errors], + error_observables=[e.observables for e in errors], + ) + + +def unpack_bit_packed_rows(bits: np.ndarray, count: int) -> np.ndarray: + return np.unpackbits(bits, bitorder="little", axis=1, count=count).astype(bool, copy=False) + + +def initial_detector_counts(data: DecoderData, active_detectors: np.ndarray) -> np.ndarray: + counts = np.zeros(len(data.errors), dtype=np.int32) + for d in np.flatnonzero(active_detectors): + for ei in data.detector_to_errors[int(d)]: + counts[ei] += 1 + return counts + + +def apply_error( + data: DecoderData, + active_detectors: np.ndarray, + active_detector_counts: np.ndarray, + error_index: int, +) -> Tuple[np.ndarray, np.ndarray]: + next_detectors = active_detectors.copy() + next_counts = active_detector_counts.copy() + for d in data.error_detectors[error_index]: + if next_detectors[d]: + next_detectors[d] = False + delta = -1 + else: + next_detectors[d] = True + delta = 1 + for other_error_index in data.detector_to_errors[d]: + 
next_counts[other_error_index] += delta + return next_detectors, next_counts + + +def plain_detcost_for_detector( + data: DecoderData, + detector: int, + *, + activated_error_mask: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + respect_blocked_errors_in_heuristic: bool, +) -> float: + best = INF + for ei in data.detector_to_errors[detector]: + if respect_blocked_errors_in_heuristic: + if blocked_errors[ei]: + continue + else: + if activated_error_mask[ei]: + continue + count = int(active_detector_counts[ei]) + assert count > 0 + candidate = float(data.error_costs[ei]) / count + if candidate < best: + best = candidate + return best + + +def plain_detcost_heuristic( + data: DecoderData, + active_detectors: np.ndarray, + *, + activated_error_mask: np.ndarray, + blocked_errors: np.ndarray, + active_detector_counts: np.ndarray, + respect_blocked_errors_in_heuristic: bool, +) -> float: + total = 0.0 + for d in np.flatnonzero(active_detectors): + det_cost = plain_detcost_for_detector( + data=data, + detector=int(d), + activated_error_mask=activated_error_mask, + blocked_errors=blocked_errors, + active_detector_counts=active_detector_counts, + respect_blocked_errors_in_heuristic=respect_blocked_errors_in_heuristic, + ) + if det_cost == INF: + return INF + total += det_cost + return total + + +def grid_fraction(values: np.ndarray, denominator: int, tol: float = RATIONAL_TOL) -> float: + if values.size == 0: + return 0.0 + scaled = denominator * values + return float(np.mean(np.abs(scaled - np.round(scaled)) <= tol)) + + +@dataclass +class LPProbeResult: + opt_h: float + node_record: Dict[str, Any] + component_records: List[Dict[str, Any]] + sample_record: Optional[Dict[str, Any]] + + +def probe_opt_singleton_lp( + *, + run_id: str, + circuit_name: str, + shot_index: int, + state: SearchState, + data: DecoderData, + settings: SearchSettings, + plain_h: float, + sample_raw_instance: bool, +) -> LPProbeResult: + active_detector_ids = 
np.flatnonzero(state.active_detectors) + num_active_detectors = int(active_detector_ids.size) + global_to_local = np.full(data.num_detectors, -1, dtype=np.int32) + global_to_local[active_detector_ids] = np.arange(num_active_detectors, dtype=np.int32) + + support_to_cost: Dict[Tuple[int, ...], float] = {} + support_to_multiplicity: Dict[Tuple[int, ...], int] = {} + covered = np.zeros(num_active_detectors, dtype=bool) + + raw_allowed_errors = 0 + raw_support_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + + for ei, error_detectors in enumerate(data.error_detectors): + if settings.respect_blocked_errors_in_heuristic: + if state.blocked_errors[ei]: + continue + else: + if state.activated_error_mask[ei]: + continue + + count = int(state.active_detector_counts[ei]) + if count == 0: + continue + support = tuple(int(global_to_local[d]) for d in error_detectors if state.active_detectors[d]) + assert support + raw_allowed_errors += 1 + size = len(support) + if size == 1: + raw_support_size_hist["1"] += 1 + elif size == 2: + raw_support_size_hist["2"] += 1 + elif size == 3: + raw_support_size_hist["3"] += 1 + else: + raw_support_size_hist["4+"] += 1 + covered[list(support)] = True + support_to_multiplicity[support] = support_to_multiplicity.get(support, 0) + 1 + cost = float(data.error_costs[ei]) + prev = support_to_cost.get(support) + if prev is None or cost < prev: + support_to_cost[support] = cost + + distinct_support_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + for support in support_to_cost: + size = len(support) + if size == 1: + distinct_support_size_hist["1"] += 1 + elif size == 2: + distinct_support_size_hist["2"] += 1 + elif size == 3: + distinct_support_size_hist["3"] += 1 + else: + distinct_support_size_hist["4+"] += 1 + + uncovered_count = int(np.count_nonzero(~covered)) + base_node_record: Dict[str, Any] = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "parent_node_id": state.parent_node_id, + 
"incoming_error_index": state.incoming_error_index, + "depth": state.depth, + "num_active_detectors": num_active_detectors, + "path_cost": state.path_cost, + "plain_h": plain_h, + "raw_allowed_errors": raw_allowed_errors, + "raw_support_hist": raw_support_size_hist, + "distinct_supports": len(support_to_cost), + "distinct_support_hist": distinct_support_size_hist, + "support_multiplicity_mean": (float(np.mean(list(support_to_multiplicity.values()))) if support_to_multiplicity else 0.0), + "support_multiplicity_max": (max(support_to_multiplicity.values()) if support_to_multiplicity else 0), + "uncovered_active_detectors": uncovered_count, + } + + if uncovered_count > 0: + base_node_record.update( + { + "opt_h": INF, + "opt_infeasible": True, + "lp_calls": 0, + "lp_time_sec": 0.0, + "total_lp_vars": 0, + "total_lp_constraints": 0, + "num_components": 0, + "num_graphlike_components": 0, + "graphlike_all_components": False, + "max_support_size": 0, + "max_component_variables": 0, + "max_component_constraints": 0, + "positive_y_count": 0, + "tight_constraint_count": 0, + "positive_dual_count": 0, + } + ) + sample_record = None + if sample_raw_instance: + sample_record = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "parent_node_id": state.parent_node_id, + "depth": state.depth, + "opt_infeasible": True, + "active_detector_ids": active_detector_ids.tolist(), + "supports": [ + { + "local_support": list(support), + "global_support": [int(active_detector_ids[i]) for i in support], + "cost": support_to_cost[support], + "multiplicity": support_to_multiplicity[support], + } + for support in sorted(support_to_cost) + ], + } + return LPProbeResult( + opt_h=INF, + node_record=base_node_record, + component_records=[], + sample_record=sample_record, + ) + + union_find = UnionFind(num_active_detectors) + for support in support_to_cost: + first = support[0] + for detector in support[1:]: + union_find.union(first, detector) + + 
detectors_by_root: Dict[int, List[int]] = {} + for detector in range(num_active_detectors): + root = union_find.find(detector) + detectors_by_root.setdefault(root, []).append(detector) + + supports_by_root: Dict[int, List[Tuple[Tuple[int, ...], float, int]]] = {} + for support, cost in support_to_cost.items(): + root = union_find.find(support[0]) + supports_by_root.setdefault(root, []).append((support, cost, support_to_multiplicity[support])) + + component_records: List[Dict[str, Any]] = [] + sample_components: List[Dict[str, Any]] = [] + total_opt_h = 0.0 + total_lp_time = 0.0 + total_lp_vars = 0 + total_lp_constraints = 0 + total_positive_y = 0 + total_tight_constraints = 0 + total_positive_dual = 0 + num_graphlike_components = 0 + max_component_variables = 0 + max_component_constraints = 0 + max_support_size = max((len(support) for support in support_to_cost), default=0) + + for component_index, (root, component_detectors) in enumerate(sorted(detectors_by_root.items())): + local_reindex = {detector: i for i, detector in enumerate(component_detectors)} + component_supports = supports_by_root[root] + num_vars = len(component_detectors) + num_constraints = len(component_supports) + max_component_variables = max(max_component_variables, num_vars) + max_component_constraints = max(max_component_constraints, num_constraints) + total_lp_vars += num_vars + total_lp_constraints += num_constraints + + row_indices: List[int] = [] + col_indices: List[int] = [] + values: List[float] = [] + rhs = np.empty(num_constraints, dtype=np.float64) + component_global_supports: List[List[int]] = [] + support_sizes = np.empty(num_constraints, dtype=np.int32) + multiplicities = np.empty(num_constraints, dtype=np.int32) + + graphlike = True + support_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + for row, (support, cost, multiplicity) in enumerate(component_supports): + rhs[row] = cost + multiplicities[row] = multiplicity + reindexed_support = [local_reindex[d] for d in support] + 
support_sizes[row] = len(reindexed_support) + if support_sizes[row] == 1: + support_size_hist["1"] += 1 + elif support_sizes[row] == 2: + support_size_hist["2"] += 1 + elif support_sizes[row] == 3: + support_size_hist["3"] += 1 + graphlike = False + else: + support_size_hist["4+"] += 1 + graphlike = False + component_global_supports.append([int(active_detector_ids[d]) for d in support]) + for col in reindexed_support: + row_indices.append(row) + col_indices.append(col) + values.append(1.0) + + a_ub = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(num_constraints, num_vars), + dtype=np.float64, + ) + + t0 = time.perf_counter() + result = linprog( + c=-np.ones(num_vars, dtype=np.float64), + A_ub=a_ub, + b_ub=rhs, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + lp_time_sec = time.perf_counter() - t0 + total_lp_time += lp_time_sec + if not result.success: + raise RuntimeError( + f"LP solve failed for circuit={circuit_name} shot={shot_index} node={state.node_id} " + f"component={component_index}: {result.message}" + ) + + y = np.asarray(result.x, dtype=np.float64) + total_opt_h += float(-result.fun) + positive_y_mask = y > LP_TOL + positive_y_count = int(np.count_nonzero(positive_y_mask)) + total_positive_y += positive_y_count + + if hasattr(result, "ineqlin") and hasattr(result.ineqlin, "residual"): + residual = np.asarray(result.ineqlin.residual, dtype=np.float64) + else: + residual = rhs - a_ub.dot(y) + tight_mask = residual <= LP_TOL + tight_count = int(np.count_nonzero(tight_mask)) + total_tight_constraints += tight_count + + if hasattr(result, "ineqlin") and hasattr(result.ineqlin, "marginals"): + dual = -np.asarray(result.ineqlin.marginals, dtype=np.float64) + else: + dual = np.full(num_constraints, np.nan) + if np.isnan(dual).any(): + positive_dual_mask = np.zeros(num_constraints, dtype=bool) + positive_dual = np.zeros(0, dtype=np.float64) + else: + positive_dual_mask = dual > LP_TOL + positive_dual = dual[positive_dual_mask] + 
positive_dual_count = int(np.count_nonzero(positive_dual_mask)) + total_positive_dual += positive_dual_count + + if graphlike: + num_graphlike_components += 1 + + positive_dual_size_hist = {"1": 0, "2": 0, "3": 0, "4+": 0} + for size in support_sizes[positive_dual_mask]: + if size == 1: + positive_dual_size_hist["1"] += 1 + elif size == 2: + positive_dual_size_hist["2"] += 1 + elif size == 3: + positive_dual_size_hist["3"] += 1 + else: + positive_dual_size_hist["4+"] += 1 + + component_record = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "component_index": component_index, + "num_variables": num_vars, + "num_constraints": num_constraints, + "objective": float(-result.fun), + "lp_time_sec": lp_time_sec, + "graphlike": graphlike, + "max_support_size": int(np.max(support_sizes) if support_sizes.size else 0), + "support_hist": support_size_hist, + "positive_y_count": positive_y_count, + "tight_constraint_count": tight_count, + "positive_dual_count": positive_dual_count, + "dual_integral_fraction": grid_fraction(positive_dual, 1), + "dual_half_integral_fraction": grid_fraction(positive_dual, 2), + "dual_third_integral_fraction": grid_fraction(positive_dual, 3), + "dual_quarter_integral_fraction": grid_fraction(positive_dual, 4), + "positive_dual_support_hist": positive_dual_size_hist, + "support_multiplicity_mean": float(np.mean(multiplicities)) if multiplicities.size else 0.0, + "support_multiplicity_max": int(np.max(multiplicities) if multiplicities.size else 0), + } + component_records.append(component_record) + + if sample_raw_instance: + sample_components.append( + { + "component_index": component_index, + "global_detector_ids": [int(active_detector_ids[d]) for d in component_detectors], + "supports": [ + { + "global_support": component_global_supports[row], + "cost": float(rhs[row]), + "multiplicity": int(multiplicities[row]), + "dual": float(dual[row]) if not np.isnan(dual[row]) else None, + "slack": 
float(residual[row]), + } + for row in range(num_constraints) + ], + "y": [float(v) for v in y], + } + ) + + base_node_record.update( + { + "opt_h": total_opt_h, + "opt_infeasible": False, + "lp_calls": len(component_records), + "lp_time_sec": total_lp_time, + "total_lp_vars": total_lp_vars, + "total_lp_constraints": total_lp_constraints, + "num_components": len(component_records), + "num_graphlike_components": num_graphlike_components, + "graphlike_all_components": num_graphlike_components == len(component_records), + "max_support_size": max_support_size, + "max_component_variables": max_component_variables, + "max_component_constraints": max_component_constraints, + "positive_y_count": total_positive_y, + "tight_constraint_count": total_tight_constraints, + "positive_dual_count": total_positive_dual, + } + ) + + sample_record = None + if sample_raw_instance: + sample_record = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + "node_id": state.node_id, + "parent_node_id": state.parent_node_id, + "incoming_error_index": state.incoming_error_index, + "depth": state.depth, + "path_cost": state.path_cost, + "plain_h": plain_h, + "opt_h": total_opt_h, + "active_detector_ids": active_detector_ids.tolist(), + "components": sample_components, + } + + return LPProbeResult( + opt_h=total_opt_h, + node_record=base_node_record, + component_records=component_records, + sample_record=sample_record, + ) + + +def compute_node_metrics( + *, + run_id: str, + circuit_name: str, + shot_index: int, + state: SearchState, + data: DecoderData, + settings: SearchSettings, + sample_raw_instance: bool, +) -> LPProbeResult: + plain_h = plain_detcost_heuristic( + data=data, + active_detectors=state.active_detectors, + activated_error_mask=state.activated_error_mask, + blocked_errors=state.blocked_errors, + active_detector_counts=state.active_detector_counts, + respect_blocked_errors_in_heuristic=settings.respect_blocked_errors_in_heuristic, + ) + lp_probe = 
probe_opt_singleton_lp( + run_id=run_id, + circuit_name=circuit_name, + shot_index=shot_index, + state=state, + data=data, + settings=settings, + plain_h=plain_h, + sample_raw_instance=sample_raw_instance, + ) + return lp_probe + + +def observables_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + observables = np.zeros(data.num_observables, dtype=bool) + for error_index in activated_errors: + for observable in data.error_observables[error_index]: + observables[observable] ^= True + return observables + + +def detectors_from_solution(data: DecoderData, activated_errors: Sequence[int]) -> np.ndarray: + detectors = np.zeros(data.num_detectors, dtype=bool) + for error_index in activated_errors: + for detector in data.error_detectors[error_index]: + detectors[detector] ^= True + return detectors + + +def heuristic_for_search(settings: SearchSettings, plain_h: float, opt_h: float) -> float: + if settings.search_heuristic == "plain": + return plain_h + if settings.search_heuristic == "opt": + return opt_h + raise ValueError(f"Unknown search heuristic: {settings.search_heuristic}") + + +def decode_and_probe_shot( + *, + run_id: str, + circuit_name: str, + shot_index: int, + shot_detections: np.ndarray, + shot_observables: np.ndarray, + data: DecoderData, + settings: SearchSettings, + logger: ProbeLogger, +) -> Dict[str, Any]: + shot_start = time.perf_counter() + sampler = NodeSampler(settings.sample_raw_nodes_per_shot) + aggregator = ShotAggregator() + + initial_active_detectors = np.asarray(shot_detections, dtype=bool).copy() + initial_counts = initial_detector_counts(data, initial_active_detectors) + initial_activated_mask = np.zeros(len(data.errors), dtype=bool) + initial_blocked = np.zeros(len(data.errors), dtype=bool) + + root_state = SearchState( + node_id=0, + parent_node_id=None, + incoming_error_index=None, + depth=0, + activated_errors=(), + activated_error_mask=initial_activated_mask, + blocked_errors=initial_blocked, + 
active_detectors=initial_active_detectors, + active_detector_counts=initial_counts, + path_cost=0.0, + search_h=0.0, + plain_h=0.0, + opt_h=0.0, + ) + + root_probe = compute_node_metrics( + run_id=run_id, + circuit_name=circuit_name, + shot_index=shot_index, + state=root_state, + data=data, + settings=settings, + sample_raw_instance=sampler.should_sample(root_state.node_id), + ) + root_state.plain_h = float(root_probe.node_record["plain_h"]) + root_state.opt_h = float(root_probe.node_record["opt_h"]) + root_state.search_h = heuristic_for_search(settings, root_state.plain_h, root_state.opt_h) + if root_state.search_h == INF: + raise RuntimeError( + f"Root node is infeasible for circuit={circuit_name} shot={shot_index}." + ) + + root_record = { + **root_probe.node_record, + "search_h": root_state.search_h, + "f_cost": root_state.path_cost + root_state.search_h, + "pushed": True, + } + logger.node_writer.write(root_record) + for component_record in root_probe.component_records: + logger.component_writer.write(component_record) + if root_probe.sample_record is not None: + logger.sample_writer.write(root_probe.sample_record) + aggregator.absorb_node(root_record) + + queue: List[Tuple[float, int, int, SearchState]] = [] + heapq_push_counter = 0 + npush = 1 + popped = 0 + max_queue_size = 1 + min_num_dets = int(initial_active_detectors.sum()) + max_num_dets = INF if settings.det_beam == INF else min_num_dets + settings.det_beam + heapq.heappush(queue, (root_state.path_cost + root_state.search_h, min_num_dets, heapq_push_counter, root_state)) + heapq_push_counter += 1 + next_node_id = 1 + + solution_state: Optional[SearchState] = None + status = "unknown" + + while queue: + max_queue_size = max(max_queue_size, len(queue)) + f_cost, num_dets, _, state = heapq.heappop(queue) + popped += 1 + + if settings.max_nodes_popped is not None and popped > settings.max_nodes_popped: + status = "max_nodes_popped" + break + + if num_dets > max_num_dets: + continue + + if 
settings.verbose_search: + print( + f"[{circuit_name} shot={shot_index}] nodes_popped={popped} pq={len(queue)} " + f"active_dets={num_dets} max_active_dets={max_num_dets} depth={state.depth} " + f"g={state.path_cost:.12g} h={state.search_h:.12g} f={f_cost:.12g}", + flush=True, + ) + + if num_dets == 0: + solution_state = state + status = "success" + break + + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = INF if settings.det_beam == INF else min_num_dets + settings.det_beam + + min_detector = int(np.flatnonzero(state.active_detectors)[0]) + blocked_prefix = state.blocked_errors.copy() + + for error_index in data.detector_to_errors[min_detector]: + blocked_prefix[error_index] = True + if state.blocked_errors[error_index]: + continue + + child_active_detectors, child_counts = apply_error( + data=data, + active_detectors=state.active_detectors, + active_detector_counts=state.active_detector_counts, + error_index=error_index, + ) + child_num_dets = int(child_active_detectors.sum()) + if child_num_dets > max_num_dets: + continue + + child_activated_mask = state.activated_error_mask.copy() + child_activated_mask[error_index] = True + child_blocked = blocked_prefix.copy() + child_path_cost = state.path_cost + float(data.error_costs[error_index]) + + child_state = SearchState( + node_id=next_node_id, + parent_node_id=state.node_id, + incoming_error_index=error_index, + depth=state.depth + 1, + activated_errors=state.activated_errors + (error_index,), + activated_error_mask=child_activated_mask, + blocked_errors=child_blocked, + active_detectors=child_active_detectors, + active_detector_counts=child_counts, + path_cost=child_path_cost, + search_h=0.0, + plain_h=0.0, + opt_h=0.0, + ) + next_node_id += 1 + + child_probe = compute_node_metrics( + run_id=run_id, + circuit_name=circuit_name, + shot_index=shot_index, + state=child_state, + data=data, + settings=settings, + sample_raw_instance=sampler.should_sample(child_state.node_id), + ) + 
child_state.plain_h = float(child_probe.node_record["plain_h"]) + child_state.opt_h = float(child_probe.node_record["opt_h"]) + child_state.search_h = heuristic_for_search(settings, child_state.plain_h, child_state.opt_h) + + pushed = child_state.search_h != INF + child_record = { + **child_probe.node_record, + "search_h": child_state.search_h, + "f_cost": child_state.path_cost + child_state.search_h, + "pushed": pushed, + } + logger.node_writer.write(child_record) + for component_record in child_probe.component_records: + logger.component_writer.write(component_record) + if child_probe.sample_record is not None: + logger.sample_writer.write(child_probe.sample_record) + aggregator.absorb_node(child_record) + + if not pushed: + continue + + heapq.heappush( + queue, + ( + child_state.path_cost + child_state.search_h, + child_num_dets, + heapq_push_counter, + child_state, + ), + ) + heapq_push_counter += 1 + npush += 1 + if settings.max_nodes_pushed is not None and npush > settings.max_nodes_pushed: + status = "max_nodes_pushed" + queue.clear() + break + + if status == "max_nodes_pushed": + break + + if status == "unknown": + status = "empty_queue" + + elapsed_seconds = time.perf_counter() - shot_start + predicted_observables: Optional[np.ndarray] = None + solution_cost: Optional[float] = None + observables_match: Optional[bool] = None + solution_size: Optional[int] = None + + if solution_state is not None: + reproduced_detectors = detectors_from_solution(data, solution_state.activated_errors) + if not np.array_equal(reproduced_detectors, shot_detections): + raise AssertionError( + f"Decoded error set does not reproduce the shot syndrome for circuit={circuit_name} shot={shot_index}." 
+ ) + predicted_observables = observables_from_solution(data, solution_state.activated_errors) + observables_match = bool(np.array_equal(predicted_observables, shot_observables)) + solution_cost = float(solution_state.path_cost) + solution_size = len(solution_state.activated_errors) + + summary = { + "run_id": run_id, + "circuit": circuit_name, + "shot": shot_index, + **aggregator.finish(nodes_popped=popped, status=status, elapsed_seconds=elapsed_seconds), + "max_queue_size": max_queue_size, + "det_beam": settings.det_beam, + "search_heuristic": settings.search_heuristic, + "solution_cost": solution_cost, + "solution_size": solution_size, + "observables_match": observables_match, + "predicted_observables": (np.flatnonzero(predicted_observables).tolist() if predicted_observables is not None else None), + "sample_observables": np.flatnonzero(shot_observables).tolist(), + } + logger.shot_writer.write(summary) + logger.flush() + return summary + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Run the prototype decoder on several circuits and log detailed LP-structure data " + "for the optimal singleton heuristic." + ) + ) + parser.add_argument( + "circuits", + nargs="+", + type=Path, + help="Stim circuit files to analyze.", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Directory where logs will be written.", + ) + parser.add_argument( + "--shots-per-circuit", + type=int, + default=10, + help="Number of sampled shots to decode per circuit (default: 10).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Seed passed to stim.compile_detector_sampler(...).sample(...).", + ) + parser.add_argument( + "--det-beam", + type=parse_beam, + default=INF, + help="Beam cutoff on residual detector count. 
Use an integer or 'inf'.", + ) + parser.add_argument( + "--max-nodes-popped", + type=parse_optional_int, + default=5000, + help="Stop after this many popped nodes per shot (default: 5000; use 'none' for no limit).", + ) + parser.add_argument( + "--max-nodes-pushed", + type=parse_optional_int, + default=50000, + help="Stop after this many pushed nodes per shot (default: 50000; use 'none' for no limit).", + ) + parser.add_argument( + "--search-heuristic", + choices=["plain", "opt"], + default="opt", + help="Heuristic used for queue ordering. Both plain and optimal values are always logged.", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action="store_true", + help=( + "When set, both heuristics exclude precedence-blocked errors as well as already-activated errors. " + "By default, heuristics only exclude already-activated errors, matching the original prototype." + ), + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--sample-raw-nodes-per-shot", + type=int, + default=25, + help="How many raw LP instances to dump per shot (default: 25).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print one line per popped node.", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Suppress per-shot progress printing.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.shots_per_circuit <= 0: + parser.error("--shots-per-circuit must be positive.") + if args.sample_raw_nodes_per_shot < 0: + parser.error("--sample-raw-nodes-per-shot must be non-negative.") + + output_dir: Path = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + run_id = f"singleton_lp_probe_{int(time.time())}" + + manifest = { + "run_id": run_id, + "argv": 
list(argv) if argv is not None else sys.argv[1:], + "circuits": [str(p) for p in args.circuits], + "shots_per_circuit": args.shots_per_circuit, + "seed": args.seed, + "det_beam": args.det_beam, + "max_nodes_popped": args.max_nodes_popped, + "max_nodes_pushed": args.max_nodes_pushed, + "search_heuristic": args.search_heuristic, + "respect_blocked_errors_in_heuristic": args.respect_blocked_errors_in_heuristic, + "merge_errors": args.merge_errors, + "sample_raw_nodes_per_shot": args.sample_raw_nodes_per_shot, + "lp_tol": LP_TOL, + "rational_tol": RATIONAL_TOL, + } + (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + logger = ProbeLogger(output_dir) + settings = SearchSettings( + det_beam=args.det_beam, + search_heuristic=args.search_heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + max_nodes_popped=args.max_nodes_popped, + max_nodes_pushed=args.max_nodes_pushed, + sample_raw_nodes_per_shot=args.sample_raw_nodes_per_shot, + verbose_search=args.verbose_search, + ) + + try: + for circuit_path in args.circuits: + circuit = stim.Circuit.from_file(str(circuit_path)) + dem = circuit.detector_error_model(decompose_errors=False) + data = build_decoder_data(dem, merge_errors_in_dem=args.merge_errors) + dets_packed, obs_packed = circuit.compile_detector_sampler(seed=args.seed).sample( + shots=args.shots_per_circuit, + separate_observables=True, + bit_packed=True, + ) + detections = unpack_bit_packed_rows(dets_packed, count=dem.num_detectors) + observables = unpack_bit_packed_rows(obs_packed, count=dem.num_observables) + circuit_name = circuit_path.name + + for shot_index in range(args.shots_per_circuit): + if not args.quiet: + print( + f"[{run_id}] circuit={circuit_name} shot={shot_index} " + f"detectors={int(np.count_nonzero(detections[shot_index]))} ...", + flush=True, + ) + summary = decode_and_probe_shot( + run_id=run_id, + circuit_name=circuit_name, + 
shot_index=shot_index, + shot_detections=detections[shot_index], + shot_observables=observables[shot_index] if observables.size else np.zeros(0, dtype=bool), + data=data, + settings=settings, + logger=logger, + ) + if not args.quiet: + print( + f"[{run_id}] done circuit={circuit_name} shot={shot_index} status={summary['status']} " + f"nodes_popped={summary['nodes_popped']} nodes_created={summary['nodes_created']} " + f"total_lp_time_sec={summary['total_lp_time_sec']:.6f}", + flush=True, + ) + finally: + logger.close() + + if not args.quiet: + print(f"Wrote logs under {output_dir}", flush=True) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/tesseract.h b/src/tesseract.h index 528c43b..517d326 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -97,6 +97,8 @@ struct TesseractDecoder { std::vector>& obs_predicted); bool low_confidence_flag = false; + size_t num_pq_pushed = 0; + size_t num_pq_popped = 0; std::vector predicted_errors_buffer; std::vector dem_error_to_error; std::vector error_to_dem_error; @@ -108,7 +110,6 @@ struct TesseractDecoder { return eneighbors; } - private: std::vector> d2e; std::vector> eneighbors; std::vector> edets; diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index 3bdf477..eb01ed2 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -468,6 +468,10 @@ void add_tesseract_module(py::module& root) { "The configuration used to create this decoder.") .def_readwrite("low_confidence_flag", &TesseractDecoder::low_confidence_flag, "A flag indicating if the decoder's prediction has low confidence.") + .def_readwrite("num_pq_pushed", &TesseractDecoder::num_pq_pushed, + "The number of items pushed to the priority queue during the most recent decode.") + .def_readwrite("num_pq_popped", &TesseractDecoder::num_pq_popped, + "The number of items popped from the priority queue during the most recent decode.") .def_readwrite( "predicted_errors_buffer", 
&TesseractDecoder::predicted_errors_buffer, "A buffer containing the predicted errors from the most recent decode operation.") diff --git a/src/tesseract_ftl.cc b/src/tesseract_ftl.cc new file mode 100644 index 0000000..a9322c9 --- /dev/null +++ b/src/tesseract_ftl.cc @@ -0,0 +1,1170 @@ + +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tesseract_ftl.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +constexpr double INF_D = std::numeric_limits::infinity(); +constexpr double HEURISTIC_EPS = 1e-9; +constexpr double SIMPLEX_EPS = 1e-9; +constexpr double SEED_TIGHT_EPS = 1e-9; +constexpr double VIOLATION_EPS = 1e-9; +constexpr size_t VIOLATION_BATCH_SIZE = 4; + +struct UnionFind { + std::vector parent; + std::vector rank; + + explicit UnionFind(size_t n) : parent(n), rank(n, 0) { + std::iota(parent.begin(), parent.end(), 0); + } + + int find(int x) { + while (parent[x] != x) { + parent[x] = parent[parent[x]]; + x = parent[x]; + } + return x; + } + + void unite(int a, int b) { + a = find(a); + b = find(b); + if (a == b) return; + if (rank[a] < rank[b]) { + parent[a] = b; + } else if (rank[a] > rank[b]) { + parent[b] = a; + } else { + parent[b] = a; + rank[a]++; + } + } +}; + +template +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + os << "["; + bool is_first = true; + for (const auto& x : vec) { + if 
(!is_first) os << ", "; + is_first = false; + os << x; + } + os << "]"; + return os; +} + +struct IntVectorHash { + size_t operator()(const std::vector& values) const { + return boost::hash_range(values.begin(), values.end()); + } +}; + +struct DenseSimplexResult { + bool success = false; + bool unbounded = false; + double objective = 0.0; + size_t pivots = 0; + std::vector solution; +}; + +double dot_on_support(const std::vector& values, const std::vector& support) { + double total = 0.0; + for (int idx : support) total += values[(size_t)idx]; + return total; +} + +// Solves: +// maximize sum_i x_i +// subject to A x <= b +// x >= 0 +// where A is a 0/1 matrix given by row supports. +DenseSimplexResult solve_dense_primal_packing_lp( + size_t num_vars, const std::vector>& row_supports, + const std::vector& rhs, const std::vector* entering_priorities = nullptr) { + DenseSimplexResult result; + result.solution.assign(num_vars, 0.0); + + const size_t num_rows = row_supports.size(); + if (num_vars == 0) { + result.success = true; + return result; + } + if (num_rows == 0) { + result.unbounded = true; + return result; + } + + const size_t width = num_vars + num_rows + 1; + std::vector> tableau(num_rows + 1, std::vector(width, 0.0)); + std::vector basis(num_rows); + + for (size_t row = 0; row < num_rows; ++row) { + for (int col : row_supports[row]) { + tableau[row][(size_t)col] = 1.0; + } + tableau[row][num_vars + row] = 1.0; + tableau[row][width - 1] = rhs[row]; + basis[row] = num_vars + row; + if (rhs[row] < -SIMPLEX_EPS) { + throw std::runtime_error("Dense simplex received a negative RHS."); + } + } + for (size_t col = 0; col < num_vars; ++col) { + tableau[num_rows][col] = -1.0; + } + + auto pivot = [&](size_t pivot_row, size_t pivot_col) { + const double pivot_value = tableau[pivot_row][pivot_col]; + assert(std::abs(pivot_value) > SIMPLEX_EPS); + const double inv_pivot = 1.0 / pivot_value; + for (size_t col = 0; col < width; ++col) { + tableau[pivot_row][col] *= 
inv_pivot; + } + tableau[pivot_row][pivot_col] = 1.0; + + for (size_t row = 0; row <= num_rows; ++row) { + if (row == pivot_row) continue; + const double factor = tableau[row][pivot_col]; + if (std::abs(factor) <= SIMPLEX_EPS) { + tableau[row][pivot_col] = 0.0; + continue; + } + for (size_t col = 0; col < width; ++col) { + tableau[row][col] -= factor * tableau[pivot_row][col]; + } + tableau[row][pivot_col] = 0.0; + } + basis[pivot_row] = pivot_col; + result.pivots++; + }; + + while (true) { + size_t entering_col = width; + double entering_priority = -INF_D; + for (size_t col = 0; col + 1 < width; ++col) { + if (tableau[num_rows][col] >= -SIMPLEX_EPS) continue; + const bool current_is_original = entering_col < num_vars; + const bool candidate_is_original = col < num_vars; + const double candidate_priority = + candidate_is_original && entering_priorities != nullptr + ? (*entering_priorities)[col] + : -INF_D; + if (entering_col == width) { + entering_col = col; + entering_priority = candidate_priority; + continue; + } + if (candidate_is_original != current_is_original) { + if (candidate_is_original) { + entering_col = col; + entering_priority = candidate_priority; + } + continue; + } + if (candidate_priority > entering_priority + SIMPLEX_EPS || + (std::abs(candidate_priority - entering_priority) <= SIMPLEX_EPS && + col < entering_col)) { + entering_col = col; + entering_priority = candidate_priority; + } + } + if (entering_col == width) { + break; + } + + size_t leaving_row = num_rows; + double best_ratio = INF_D; + for (size_t row = 0; row < num_rows; ++row) { + const double coeff = tableau[row][entering_col]; + if (coeff <= SIMPLEX_EPS) continue; + const double ratio = tableau[row][width - 1] / coeff; + if (ratio + SIMPLEX_EPS < best_ratio) { + best_ratio = ratio; + leaving_row = row; + } else if (std::abs(ratio - best_ratio) <= SIMPLEX_EPS && leaving_row != num_rows && + basis[row] < basis[leaving_row]) { + leaving_row = row; + } + } + + if (leaving_row == 
      // (Tail of solve_dense_primal_packing_lp; the function head is outside this chunk.)
      num_rows) {
      result.unbounded = true;
      return result;
    }
    pivot(leaving_row, entering_col);
  }

  // Read the optimal primal solution out of the final tableau: each basic
  // variable's value is the rightmost column of its row.  Tiny values are
  // snapped to exactly zero to keep downstream comparisons stable.
  for (size_t row = 0; row < num_rows; ++row) {
    if (basis[row] < num_vars) {
      double value = tableau[row][width - 1];
      if (std::abs(value) <= SIMPLEX_EPS) value = 0.0;
      result.solution[basis[row]] = value;
    }
  }
  result.objective = tableau[num_rows][width - 1];
  if (std::abs(result.objective) <= SIMPLEX_EPS) result.objective = 0.0;
  result.success = true;
  return result;
}

// Returns the LP budget assigned to `detector` in `solution`, or 0.0 if the
// detector is not among the solution's (sorted) active detectors.
// Relies on solution.active_detectors being sorted ascending (binary search).
template
double lookup_detector_budget(const Solution& solution, int detector) {
  auto it = std::lower_bound(solution.active_detectors.begin(), solution.active_detectors.end(),
                             detector);
  if (it == solution.active_detectors.end() || *it != detector) return 0.0;
  const size_t pos = (size_t)(it - solution.active_detectors.begin());
  return solution.detector_budgets[pos];
}

// Outcome of solving one connected component of the singleton packing LP.
struct SingletonComponentSolveResult {
  bool success = false;             // LP solved to (relative) optimality
  bool unbounded = false;           // packing LP had no finite optimum
  double objective = 0.0;           // optimal value for this component
  size_t reduced_constraints = 0;   // constraints in the final reduced LP
  size_t simplex_solves = 0;        // number of simplex invocations
  std::vector detector_budgets;     // per-local-detector budget (component order)
};

// Solves one component's packing LP (maximize sum of detector budgets subject
// to per-error-pattern upper bounds `rhs`) via constraint generation: start
// from a small seed set of rows, solve the reduced LP with the dense simplex,
// then add the most-violated missing constraints until none are violated.
// `seed_budgets` warm-starts the simplex from a parent node's solution.
SingletonComponentSolveResult solve_singleton_component_lp(
    size_t num_local_detectors, const std::vector>& row_supports,
    const std::vector& rhs, const std::vector& cheapest_constraint_for_local_detector,
    const std::vector& seed_budgets) {
  SingletonComponentSolveResult result;
  result.detector_budgets.assign(num_local_detectors, 0.0);

  if (num_local_detectors == 0) {
    result.success = true;
    return result;
  }
  // No constraints at all means the maximization is unbounded.
  if (row_supports.empty()) {
    result.unbounded = true;
    return result;
  }

  const double seed_total = std::accumulate(seed_budgets.begin(), seed_budgets.end(), 0.0);

  // selected[row] == 1 iff `row` is in the current reduced constraint set;
  // selected_indices preserves insertion order for rebuilding the reduced LP.
  std::vector selected(row_supports.size(), 0);
  std::vector selected_indices;
  selected_indices.reserve(std::min(row_supports.size(), num_local_detectors * 2 + 4));

  auto add_constraint = [&](int idx) {
    if (idx < 0) return;
    if (!selected[(size_t)idx]) {
      selected[(size_t)idx] = 1;
      selected_indices.push_back(idx);
    }
  };

  // Seed 1: rows that are (nearly) tight at the warm-start point — those are
  // likely to be binding at the optimum as well.
  for (size_t row = 0; row < row_supports.size(); ++row) {
    const double slack = rhs[row] - dot_on_support(seed_budgets, row_supports[row]);
    if (slack <= SEED_TIGHT_EPS * (1.0 + rhs[row])) {
      add_constraint((int)row);
    }
  }

  // Seed 2: make sure every local detector is covered by at least one
  // constraint, otherwise its budget variable would be unbounded.
  std::vector covered(num_local_detectors, 0);
  for (int idx : selected_indices) {
    for (int local : row_supports[(size_t)idx]) covered[(size_t)local] = 1;
  }
  for (size_t local = 0; local < num_local_detectors; ++local) {
    if (!covered[local]) {
      const int idx = cheapest_constraint_for_local_detector[local];
      if (idx < 0) {
        throw std::runtime_error("Missing seed constraint for active detector.");
      }
      add_constraint(idx);
      for (int touched : row_supports[(size_t)idx]) covered[(size_t)touched] = 1;
    }
  }

  if (selected_indices.empty()) {
    throw std::runtime_error("Singleton LP seed set unexpectedly empty.");
  }

  size_t rounds = 0;
  while (true) {
    // Each round must add at least one new row, so more rounds than rows
    // indicates a logic error (e.g. cycling).
    if (++rounds > row_supports.size() + 1) {
      throw std::runtime_error("Constraint generation exceeded the number of unique constraints.");
    }

    // Materialize the reduced LP from the selected rows.
    std::vector> reduced_rows;
    std::vector reduced_rhs;
    reduced_rows.reserve(selected_indices.size());
    reduced_rhs.reserve(selected_indices.size());
    for (int idx : selected_indices) {
      reduced_rows.push_back(row_supports[(size_t)idx]);
      reduced_rhs.push_back(rhs[(size_t)idx]);
    }

    DenseSimplexResult simplex =
        solve_dense_primal_packing_lp(num_local_detectors, reduced_rows, reduced_rhs,
                                      &seed_budgets);
    result.simplex_solves++;
    if (simplex.unbounded) {
      result.unbounded = true;
      return result;
    }
    if (!simplex.success) {
      return result;
    }
    // The reduced LP is a relaxation, so its optimum can never drop below the
    // value of the feasible warm-start point.
    if (simplex.objective + 1e-7 < seed_total) {
      throw std::runtime_error("Reduced singleton LP optimum fell below the projected seed bound.");
    }

    // Separation: find constraints the reduced optimum violates, keeping only
    // the VIOLATION_BATCH_SIZE most violated rows.
    double max_violation = 0.0;
    std::vector> top_violated;
    top_violated.reserve(VIOLATION_BATCH_SIZE);

    for (size_t row = 0; row < row_supports.size(); ++row) {
      if (selected[row]) continue;
      const double lhs = dot_on_support(simplex.solution, row_supports[row]);
      const double violation = lhs - rhs[row];
      if (violation > max_violation) {
        max_violation = violation;
      }
      // NOTE(review): this skip uses a *relative* tolerance while the
      // convergence test below uses an *absolute* one.  A row with large rhs
      // and a violation between VIOLATION_EPS and VIOLATION_EPS*(1+rhs) would
      // update max_violation but be skipped here, so the loop could reach the
      // "added no rows" throw below.  Confirm the tolerances are intended to
      // differ.
      if (violation <= VIOLATION_EPS * (1.0 + rhs[row])) continue;

      top_violated.emplace_back(violation, (int)row);
      // Keep the batch sorted descending by violation; drop the smallest when
      // over capacity.  Batch is tiny, so the re-sort per insert is cheap.
      std::sort(top_violated.begin(), top_violated.end(),
                [](const auto& a, const auto& b) { return a.first > b.first; });
      if (top_violated.size() > VIOLATION_BATCH_SIZE) top_violated.pop_back();
    }

    // Converged: no constraint is violated beyond tolerance; the reduced
    // optimum is optimal for the full LP.
    if (max_violation <= VIOLATION_EPS) {
      result.success = true;
      result.objective = simplex.objective;
      result.reduced_constraints = selected_indices.size();
      result.detector_budgets = std::move(simplex.solution);
      return result;
    }

    bool added_any = false;
    for (const auto& [_, idx] : top_violated) {
      if (!selected[(size_t)idx]) {
        add_constraint(idx);
        added_any = true;
      }
    }
    if (!added_any) {
      throw std::runtime_error("Constraint generation identified violations but added no rows.");
    }
  }
}

// Human-readable label for a heuristic provenance tag (used in verbose logs).
std::string heuristic_source_to_string(FTLHeuristicSource source) {
  switch (source) {
    case FTLHeuristicSource::kPlain:
      return "plain";
    case FTLHeuristicSource::kProjected:
      return "projected";
    case FTLHeuristicSource::kExact:
      return "exact";
  }
  return "unknown";
}

}  // namespace

// Renders the configuration as a single-line debug string.
std::string TesseractFTLConfig::str() {
  std::stringstream ss;
  ss << "TesseractFTLConfig(";
  ss << "dem=DetectorErrorModel_Object, ";
  ss << "det_beam=" << det_beam << ", ";
  ss << "no_revisit_dets=" << no_revisit_dets << ", ";
  ss << "verbose=" << verbose << ", ";
  ss << "merge_errors=" << merge_errors << ", ";
  ss << "pqlimit=" << pqlimit << ", ";
  ss << "det_orders=" << det_orders << ", ";
  ss << "det_penalty=" << det_penalty << ", ";
  ss << "create_visualization=" << create_visualization << ", ";
  ss << "subset_detcost_size=" << subset_detcost_size;
  ss << ")";
  return ss.str();
}

void TesseractFTLStats::clear() {
  // Reset all counters by assigning a default-constructed instance.
  *this = TesseractFTLStats{};
}

// Adds another stats object's counters into this one (max-fields take the max).
void TesseractFTLStats::accumulate(const TesseractFTLStats& other) {
  num_pq_pushed += other.num_pq_pushed;
  num_nodes_popped += other.num_nodes_popped;
  max_queue_size = std::max(max_queue_size, other.max_queue_size);
  heuristic_calls += other.heuristic_calls;
  plain_heuristic_calls += other.plain_heuristic_calls;
  projection_heuristic_calls += other.projection_heuristic_calls;
  exact_refinement_calls += other.exact_refinement_calls;
  lp_calls += other.lp_calls;
  lp_reinserts += other.lp_reinserts;
  projected_nodes_generated += other.projected_nodes_generated;
  projected_nodes_refined += other.projected_nodes_refined;
  total_lp_refinement_gain += other.total_lp_refinement_gain;
  max_lp_refinement_gain = std::max(max_lp_refinement_gain, other.max_lp_refinement_gain);
  lp_total_seconds += other.lp_total_seconds;
}

// Ordering used by the min-priority-queue (std::greater): lower f_cost pops
// first.  NOTE(review): the tie-break makes a node with *fewer* residual
// detectors compare "greater", i.e. pop later at equal f — confirm that is
// the intended tie-break direction for the search.
bool TesseractFTLDecoder::FTLNode::operator>(const FTLNode& other) const {
  return f_cost > other.f_cost || (f_cost == other.f_cost && num_dets < other.num_dets);
}

// Hash for the visited-detector-pattern sets, delegating to Boost's hasher.
size_t TesseractFTLDecoder::DynamicBitsetHash::operator()(
    const boost::dynamic_bitset<>& bs) const {
  return boost::hash_value(bs);
}

// Constructs the decoder.  subset_detcost_size == 0 delegates entirely to the
// original TesseractDecoder; subset_detcost_size == 1 builds the structures
// for the singleton LP heuristic implemented in this file.
TesseractFTLDecoder::TesseractFTLDecoder(TesseractFTLConfig config_) : config(config_) {
  if (config.subset_detcost_size > 1) {
    throw std::invalid_argument(
        "tesseract_ftl singleton mode supports only subset_detcost_size of 0 or 1");
  }

  if (config.subset_detcost_size == 0) {
    // Delegate mode: forward every relevant option to a plain Tesseract
    // decoder and mirror its derived fields so our accessors stay valid.
    TesseractConfig delegate_config;
    delegate_config.dem = config.dem;
    delegate_config.det_beam = config.det_beam;
    delegate_config.beam_climbing = config.beam_climbing;
    delegate_config.no_revisit_dets = config.no_revisit_dets;
    delegate_config.verbose = config.verbose;
    delegate_config.merge_errors = config.merge_errors;
    delegate_config.pqlimit = config.pqlimit;
    delegate_config.det_orders = config.det_orders;
    delegate_config.det_penalty = config.det_penalty;
    delegate_config.create_visualization = config.create_visualization;
    plain_delegate = std::make_unique(delegate_config);
    errors = plain_delegate->errors;
    num_detectors = plain_delegate->num_detectors;
    num_observables = plain_delegate->num_observables;
    dem_error_to_error = plain_delegate->dem_error_to_error;
    error_to_dem_error = plain_delegate->error_to_dem_error;
    return;
  }

  // Build the mapping from original flattened-DEM error indices to the
  // decoder's retained error indices, chaining each DEM transformation.
  std::vector dem_error_map(config.dem.flattened().count_errors());
  std::iota(dem_error_map.begin(), dem_error_map.end(), 0);

  if (config.merge_errors) {
    std::vector merge_map;
    config.dem = common::merge_indistinguishable_errors(config.dem, merge_map);
    common::chain_error_maps(dem_error_map, merge_map);
  }

  std::vector nonzero_map;
  config.dem = common::remove_zero_probability_errors(config.dem, nonzero_map);
  common::chain_error_maps(dem_error_map, nonzero_map);

  dem_error_to_error = std::move(dem_error_map);
  error_to_dem_error = common::invert_error_map(dem_error_to_error, config.dem.count_errors());

  // Default detector order is the identity permutation; otherwise validate
  // that every supplied order covers all detectors.
  if (config.det_orders.empty()) {
    config.det_orders.emplace_back(config.dem.count_detectors());
    std::iota(config.det_orders[0].begin(), config.det_orders[0].end(), 0);
  } else {
    for (const auto& order : config.det_orders) {
      if (order.size() != config.dem.count_detectors()) {
        throw std::invalid_argument(
            "Each detector order list must have a size equal to the number of detectors.");
      }
    }
  }
  if (config.det_orders.empty()) {
    throw std::runtime_error("Detector order list must not be empty.");
  }

  errors = get_errors_from_dem(config.dem.flattened());
  num_detectors = config.dem.count_detectors();
  num_errors = config.dem.count_errors();
  num_observables = config.dem.count_observables();

  initialize_structures(num_detectors);

  if (config.create_visualization) {
    auto detectors = get_detector_coords(config.dem);
    visualizer.add_detector_coords(detectors);
    visualizer.add_errors(errors);
  }
}

TesseractFTLDecoder::~TesseractFTLDecoder() = default;

void
// Builds the per-detector and per-error lookup tables used by the search:
// d2e maps a detector to the errors touching it (sorted cheapest-per-detector
// first), edets maps an error to its detectors, error_costs caches each
// error's likelihood cost and its cost split evenly across its detectors.
TesseractFTLDecoder::initialize_structures(size_t num_detectors_) {
  d2e.resize(num_detectors_);
  edets.resize(num_errors);
  error_costs.resize(num_errors);

  for (size_t ei = 0; ei < num_errors; ++ei) {
    edets[ei] = errors[ei].symptom.detectors;
    for (int d : edets[ei]) {
      d2e[(size_t)d].push_back((int)ei);
    }
    error_costs[ei] = {errors[ei].likelihood_cost,
                       errors[ei].likelihood_cost / errors[ei].symptom.detectors.size()};
  }

  for (size_t d = 0; d < num_detectors_; ++d) {
    std::sort(d2e[d].begin(), d2e[d].end(), [this](int a, int b) {
      return error_costs[(size_t)a].min_cost < error_costs[(size_t)b].min_cost;
    });
  }
}

// Replays a node's error chain (a parent-linked list in error_chain_arena)
// onto `detectors`, toggling each chain error's detectors, and marks as
// blocked every error that precedes the chosen error in its min-detector's
// d2e order — the precedence-based tree pruning from the prototype.
void TesseractFTLDecoder::flip_detectors_and_block_errors(
    size_t detector_order, int64_t error_chain_idx, boost::dynamic_bitset<>& detectors,
    std::vector& detector_cost_tuples) const {
  (void)detector_order;
  int64_t walker_idx = error_chain_idx;
  while (walker_idx != -1) {
    const auto& node = error_chain_arena[(size_t)walker_idx];
    const size_t ei = node.error_index;
    const size_t min_detector = node.min_detector;

    // Block every error at this min-detector up to and including `ei`, so a
    // later expansion cannot re-derive the same error set in another order.
    for (int oei : d2e[min_detector]) {
      detector_cost_tuples[(size_t)oei].error_blocked = 1;
      if ((size_t)oei == ei) break;
    }
    for (int d : edets[ei]) detectors[(size_t)d] = !detectors[(size_t)d];
    walker_idx = node.parent_idx;
  }
}

// Partitions the active (fired) detectors into connected components linked by
// unblocked errors, and for each component collects the LP constraints: one
// row per distinct local-detector pattern with the minimum likelihood cost
// over errors sharing that pattern.  Marks result.feasible = false when some
// active detector has no unblocked error (node is a dead end).
TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_components(
    const boost::dynamic_bitset<>& detectors,
    const std::vector& detector_cost_tuples) const {
  SingletonBuildResult result;

  // Enumerate active detectors and give each a dense "active position".
  std::vector active_detectors;
  active_detectors.reserve(detectors.count());
  std::vector detector_to_active_pos(num_detectors, -1);
  for (size_t detector = detectors.find_first(); detector != boost::dynamic_bitset<>::npos;
       detector = detectors.find_next(detector)) {
    detector_to_active_pos[detector] = (int)active_detectors.size();
    active_detectors.push_back((int)detector);
  }
  if (active_detectors.empty()) return result;

  // Union active detectors that share an unblocked error.
  UnionFind uf(active_detectors.size());
  std::vector has_available(active_detectors.size(), 0);

  for (size_t ei = 0; ei < num_errors; ++ei) {
    if (detector_cost_tuples[ei].error_blocked) continue;
    int first_active = -1;
    for (int detector : edets[ei]) {
      const int active_pos = detector_to_active_pos[(size_t)detector];
      if (active_pos < 0) continue;
      has_available[(size_t)active_pos] = 1;
      if (first_active < 0) {
        first_active = active_pos;
      } else {
        uf.unite(first_active, active_pos);
      }
    }
  }

  // A fired detector with no unblocked error can never be cleared: infeasible.
  for (size_t active_pos = 0; active_pos < active_detectors.size(); ++active_pos) {
    if (!has_available[active_pos]) {
      result.feasible = false;
      return result;
    }
  }

  // Group active positions by union-find root, then sort positions within and
  // components between by original detector index for determinism.
  std::unordered_map> positions_by_root;
  positions_by_root.reserve(active_detectors.size());
  for (int active_pos = 0; active_pos < (int)active_detectors.size(); ++active_pos) {
    positions_by_root[uf.find(active_pos)].push_back(active_pos);
  }

  std::vector> component_positions;
  component_positions.reserve(positions_by_root.size());
  for (auto& [_, positions] : positions_by_root) {
    std::sort(positions.begin(), positions.end(), [&](int a, int b) {
      return active_detectors[(size_t)a] < active_detectors[(size_t)b];
    });
    component_positions.push_back(std::move(positions));
  }
  std::sort(component_positions.begin(), component_positions.end(),
            [&](const auto& a, const auto& b) {
              return active_detectors[(size_t)a[0]] < active_detectors[(size_t)b[0]];
            });

  std::vector active_pos_to_component(active_detectors.size(), -1);
  std::vector active_pos_to_local(active_detectors.size(), -1);

  result.components.reserve(component_positions.size());
  for (size_t component_index = 0; component_index < component_positions.size();
       ++component_index) {
    const auto& positions = component_positions[component_index];
    SingletonLPComponent component;
    component.detectors.reserve(positions.size());
    for (size_t local = 0; local < positions.size(); ++local) {
      const int active_pos = positions[local];
      active_pos_to_component[(size_t)active_pos] = (int)component_index;
      active_pos_to_local[(size_t)active_pos] = (int)local;
      component.detectors.push_back(active_detectors[(size_t)active_pos]);
    }
    result.components.push_back(std::move(component));
  }

  // For each component, keep only the minimum rhs (likelihood cost) per
  // distinct local-detector pattern — dominated constraints are redundant.
  std::vector, double, IntVectorHash>> min_rhs_by_pattern(
      result.components.size());

  for (size_t ei = 0; ei < num_errors; ++ei) {
    if (detector_cost_tuples[ei].error_blocked) continue;

    int component_index = -1;
    std::vector local_hits;
    local_hits.reserve(edets[ei].size());

    for (int detector : edets[ei]) {
      const int active_pos = detector_to_active_pos[(size_t)detector];
      if (active_pos < 0) continue;
      if (component_index < 0) {
        component_index = active_pos_to_component[(size_t)active_pos];
      } else {
        // By construction an unblocked error's active detectors all lie in
        // one component (they were united above).
        assert(component_index == active_pos_to_component[(size_t)active_pos]);
      }
      local_hits.push_back(active_pos_to_local[(size_t)active_pos]);
    }

    if (component_index < 0) continue;

    std::sort(local_hits.begin(), local_hits.end());
    auto& rhs_map = min_rhs_by_pattern[(size_t)component_index];
    const double rhs = errors[ei].likelihood_cost;
    auto it = rhs_map.find(local_hits);
    if (it == rhs_map.end() || rhs < it->second) {
      rhs_map[local_hits] = rhs;
    }
  }

  for (size_t component_index = 0; component_index < result.components.size(); ++component_index) {
    auto& component = result.components[component_index];
    const auto& rhs_map = min_rhs_by_pattern[component_index];

    component.constraints.reserve(rhs_map.size());
    for (const auto& [local_hits, rhs] : rhs_map) {
      component.constraints.push_back({local_hits, rhs});
    }
    // Deterministic constraint order: by support size, then support, then rhs.
    std::sort(component.constraints.begin(), component.constraints.end(),
              [](const auto& a, const auto& b) {
                if (a.local_detectors.size() != b.local_detectors.size()) {
                  return a.local_detectors.size() < b.local_detectors.size();
                }
                if (a.local_detectors != b.local_detectors) {
                  return a.local_detectors < b.local_detectors;
                }
                return a.rhs < b.rhs;
              });

    // Record, for every local detector, the cheapest constraint covering it —
    // used to seed the constraint-generation LP.
    component.cheapest_constraint_for_local_detector.assign(component.detectors.size(), -1);
    std::vector cheapest_rhs(component.detectors.size(), INF_D);
    for (size_t constraint_index = 0; constraint_index < component.constraints.size();
         ++constraint_index) {
      const auto& constraint = component.constraints[constraint_index];
      for (int local_detector : constraint.local_detectors) {
        if (constraint.rhs < cheapest_rhs[(size_t)local_detector]) {
          cheapest_rhs[(size_t)local_detector] = constraint.rhs;
          component.cheapest_constraint_for_local_detector[(size_t)local_detector] =
              (int)constraint_index;
        }
      }
    }
    for (size_t local = 0; local < component.detectors.size(); ++local) {
      if (component.cheapest_constraint_for_local_detector[local] < 0) {
        result.feasible = false;
        result.components.clear();
        return result;
      }
    }
  }

  return result;
}

// Computes the exact singleton LP lower bound for the given detector pattern:
// builds the components, then solves each component's packing LP (warm-started
// from `warm_solution_idx` when >= 0) and sums the objectives.  Returns
// value == INF_D when the pattern is infeasible.  Updates timing/LP stats.
TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset_lp(
    const boost::dynamic_bitset<>& detectors,
    const std::vector& detector_cost_tuples, int64_t warm_solution_idx) {
  stats.heuristic_calls++;
  stats.exact_refinement_calls++;
  const auto start_time = std::chrono::high_resolution_clock::now();

  ExactSubsetSolution solution;
  const auto build = build_singleton_components(detectors, detector_cost_tuples);
  if (!build.feasible) {
    solution.value = INF_D;
    const auto stop_time = std::chrono::high_resolution_clock::now();
    stats.lp_total_seconds +=
        std::chrono::duration_cast(stop_time - start_time).count() /
        1e6;
    return solution;
  }
  if (build.components.empty()) {
    // No active detectors: bound is zero.
    solution.value = 0.0;
    const auto stop_time = std::chrono::high_resolution_clock::now();
    stats.lp_total_seconds +=
        std::chrono::duration_cast(stop_time - start_time).count() /
        1e6;
    return solution;
  }

  const ExactSubsetSolution* warm_solution =
      warm_solution_idx >= 0 ? &exact_solution_arena[(size_t)warm_solution_idx] : nullptr;

  solution.value = 0.0;
  solution.num_components = build.components.size();

  for (const auto& component : build.components) {
    // Seed each component's budgets from the warm solution (0 when absent).
    std::vector seed_budgets(component.detectors.size(), 0.0);
    if (warm_solution != nullptr) {
      for (size_t local = 0; local < component.detectors.size(); ++local) {
        seed_budgets[local] = lookup_detector_budget(*warm_solution, component.detectors[local]);
      }
    }

    std::vector> row_supports;
    std::vector rhs;
    row_supports.reserve(component.constraints.size());
    rhs.reserve(component.constraints.size());
    for (const auto& constraint : component.constraints) {
      row_supports.push_back(constraint.local_detectors);
      rhs.push_back(constraint.rhs);
    }

    const auto component_result =
        solve_singleton_component_lp(component.detectors.size(), row_supports, rhs,
                                     component.cheapest_constraint_for_local_detector, seed_budgets);
    stats.lp_calls += component_result.simplex_solves;

    if (component_result.unbounded) {
      throw std::runtime_error("Singleton custom LP became unbounded.");
    }
    if (!component_result.success) {
      throw std::runtime_error("Singleton custom LP failed.");
    }

    solution.value += component_result.objective;
    solution.num_active_subsets += component.detectors.size();
    solution.num_variables += component.detectors.size();
    solution.num_constraints += component_result.reduced_constraints;

    for (size_t local = 0; local < component.detectors.size(); ++local) {
      solution.active_detectors.push_back(component.detectors[local]);
      solution.detector_budgets.push_back(component_result.detector_budgets[local]);
    }
  }

  const auto stop_time = std::chrono::high_resolution_clock::now();
  stats.lp_total_seconds +=
      std::chrono::duration_cast(stop_time - start_time).count() /
      1e6;
  return solution;
}

// Projects a parent's exact LP solution onto a child's detector pattern: the
// child's heuristic is the sum of the parent's per-detector budgets over the
// child's active detectors (a valid lower bound, cheaper than re-solving).
double TesseractFTLDecoder::project_from_exact_solution(
    const ExactSubsetSolution& solution, const boost::dynamic_bitset<>& detectors,
    const std::vector& detector_cost_tuples) {
  stats.heuristic_calls++;
  stats.projection_heuristic_calls++;

  double total = 0.0;
  for (size_t detector = detectors.find_first(); detector != boost::dynamic_bitset<>::npos;
       detector = detectors.find_next(detector)) {
    // A fired detector with every incident error blocked can never be
    // cleared, so the node is infeasible.
    bool has_available = false;
    for (int ei : d2e[detector]) {
      if (!detector_cost_tuples[(size_t)ei].error_blocked) {
        has_available = true;
        break;
      }
    }
    if (!has_available) return INF_D;
    total += lookup_detector_budget(solution, (int)detector);
  }
  return total;
}

// Clears all per-shot mutable state (buffers, arenas, stats) before a decode.
void TesseractFTLDecoder::reset_decode_state() {
  low_confidence_flag = false;
  predicted_errors_buffer.clear();
  error_chain_arena.clear();
  exact_solution_arena.clear();
  stats.clear();
}

// Top-level decode entry: runs the inner search across detector orders (and,
// with beam_climbing, across increasing beams), keeps the cheapest confident
// result, and aggregates stats across all attempts.
void TesseractFTLDecoder::decode_to_errors(const std::vector& detections) {
  if (plain_delegate) {
    // Delegate mode: forward to the plain Tesseract decoder.
    plain_delegate->decode_to_errors(detections);
    predicted_errors_buffer = plain_delegate->predicted_errors_buffer;
    low_confidence_flag = plain_delegate->low_confidence_flag;
    stats.clear();
    return;
  }

  std::vector best_errors;
  double best_cost = std::numeric_limits::max();
  bool any_success = false;
  TesseractFTLStats aggregate_stats;
  stats.clear();

  if (config.beam_climbing) {
    // Cycle through (beam, detector_order) pairs round-robin for enough
    // trials to visit every beam value and every detector order at least once.
    int beam = 0;
    int detector_order = 0;
    for (int trial = 0; trial < std::max(config.det_beam + 1, int(config.det_orders.size()));
         ++trial) {
      decode_to_errors(detections, (size_t)detector_order, (size_t)beam);
      aggregate_stats.accumulate(stats);
      const double local_cost = cost_from_errors(predicted_errors_buffer);
      if (!low_confidence_flag && local_cost < best_cost) {
        best_errors = predicted_errors_buffer;
        best_cost = local_cost;
        any_success = true;
      }
      if (config.verbose) {
        std::cout << "for detector_order " << detector_order << " beam " << beam
                  << " got low confidence " << low_confidence_flag << " and cost " << local_cost
                  << ". Best cost so far: " << best_cost << std::endl;
      }
      beam = (beam + 1) % (config.det_beam + 1);
      detector_order = (detector_order + 1) % config.det_orders.size();
    }
  } else {
    // Fixed beam: one attempt per detector order.
    for (size_t detector_order = 0; detector_order < config.det_orders.size(); ++detector_order) {
      decode_to_errors(detections, detector_order, config.det_beam);
      aggregate_stats.accumulate(stats);
      const double local_cost = cost_from_errors(predicted_errors_buffer);
      if (!low_confidence_flag && local_cost < best_cost) {
        best_errors = predicted_errors_buffer;
        best_cost = local_cost;
        any_success = true;
      }
      if (config.verbose) {
        std::cout << "for detector_order " << detector_order << " beam " << config.det_beam
                  << " got low confidence " << low_confidence_flag << " and cost " << local_cost
                  << ". Best cost so far: " << best_cost << std::endl;
      }
    }
  }
  predicted_errors_buffer = best_errors;
  low_confidence_flag = !any_success;
  stats = aggregate_stats;
}

// Inner decode: A* over residual detector patterns with the singleton LP
// heuristic, for one detector order and one beam width.
void TesseractFTLDecoder::decode_to_errors(const std::vector& detections,
                                           size_t detector_order, size_t detector_beam) {
  if (plain_delegate) {
    plain_delegate->decode_to_errors(detections, detector_order, detector_beam);
    predicted_errors_buffer = plain_delegate->predicted_errors_buffer;
    low_confidence_flag = plain_delegate->low_confidence_flag;
    return;
  }

  reset_decode_state();
  if (config.pqlimit != std::numeric_limits::max()) {
    // Bound the arena reservations so a huge pqlimit cannot blow out memory.
    const size_t reserve_size = std::min(config.pqlimit, 5000000);
    error_chain_arena.reserve(reserve_size);
    exact_solution_arena.reserve(reserve_size / 4 + 1);
  }

  std::priority_queue, std::greater> pq;
  std::unordered_map, DynamicBitsetHash>>
      visited_detectors;

  boost::dynamic_bitset<> initial_detectors(num_detectors, false);
  std::vector initial_detector_cost_tuples(num_errors);
  for (size_t detector : detections) {
    if (detector >= num_detectors) {
      throw std::runtime_error("Symptom references detector >= num_detectors");
    }
    initial_detectors[detector]
    = true;
  }

  size_t min_num_dets = detections.size();
  // The beam caps how far above the best-seen detector count a node may go.
  size_t max_num_dets = min_num_dets + detector_beam;

  FTLNode root;
  root.g_cost = 0.0;
  root.num_dets = min_num_dets;
  root.depth = 0;
  root.error_chain_idx = -1;
  root.warm_solution_idx = -1;
  root.exact_solution_idx = -1;

  // Solve the exact LP at the root; infeasible means no decoding exists.
  ExactSubsetSolution root_exact =
      solve_exact_subset_lp(initial_detectors, initial_detector_cost_tuples, -1);
  if (root_exact.value == INF_D) {
    low_confidence_flag = true;
    return;
  }
  exact_solution_arena.push_back(std::move(root_exact));
  root.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1;
  root.f_cost = exact_solution_arena.back().value;
  root.h_cost = exact_solution_arena.back().value;
  root.exact_refined = true;
  root.heuristic_source = FTLHeuristicSource::kExact;
  pq.push(root);
  stats.num_pq_pushed = 1;
  stats.max_queue_size = 1;

  while (!pq.empty()) {
    stats.max_queue_size = std::max(stats.max_queue_size, pq.size());
    FTLNode node = pq.top();
    pq.pop();
    stats.num_nodes_popped++;

    // Beam pruning: the window may have tightened since this node was pushed.
    if (node.num_dets > max_num_dets) continue;

    // Rebuild this node's detector pattern and blocked-error set by replaying
    // its error chain from the initial syndrome.
    boost::dynamic_bitset<> detectors = initial_detectors;
    std::vector detector_cost_tuples(num_errors);
    flip_detectors_and_block_errors(detector_order, node.error_chain_idx, detectors,
                                    detector_cost_tuples);

    if (config.verbose) {
      const size_t projected_unrefined =
          stats.projected_nodes_generated - stats.projected_nodes_refined;
      std::cout.precision(13);
      std::cout << "nodes_popped=" << stats.num_nodes_popped << " len(pq)=" << pq.size()
                << " nodes_pushed=" << stats.num_pq_pushed << " lp_calls=" << stats.lp_calls
                << " lp_reinserts=" << stats.lp_reinserts
                << " proj_generated=" << stats.projected_nodes_generated
                << " proj_refined=" << stats.projected_nodes_refined
                << " proj_unrefined_so_far=" << projected_unrefined
                << " num_dets=" << node.num_dets << " max_num_dets=" << max_num_dets
                << " f=" << node.f_cost << " g=" << node.g_cost << " h=" << node.h_cost
                << " h_source=" << heuristic_source_to_string(node.heuristic_source)
                << " exact_refined=" << node.exact_refined << std::endl;
    }

    // Goal: all detectors cleared.  Walk the chain backwards to reconstruct
    // the predicted errors in forward order (mapped to original DEM indices).
    if (node.num_dets == 0) {
      predicted_errors_buffer.resize(node.depth);
      int64_t walker_idx = node.error_chain_idx;
      for (size_t i = 0; i < node.depth; ++i) {
        predicted_errors_buffer[node.depth - 1 - i] =
            error_to_dem_error[error_chain_arena[(size_t)walker_idx].error_index];
        walker_idx = error_chain_arena[(size_t)walker_idx].parent_idx;
      }
      if (config.verbose) {
        std::cout << "Decoding complete. Cost: " << node.g_cost
                  << " num_pq_pushed = " << stats.num_pq_pushed << std::endl;
      }
      return;
    }

    // New best detector count: tighten the beam window and drop visited sets
    // for counts that fell outside it.
    if (node.num_dets < min_num_dets) {
      min_num_dets = node.num_dets;
      if (config.no_revisit_dets) {
        for (size_t count = min_num_dets + detector_beam + 1; count <= max_num_dets; ++count) {
          visited_detectors[count].clear();
        }
      }
      max_num_dets = std::min(max_num_dets, min_num_dets + detector_beam);
    }

    // Lazy refinement: a node pushed with the cheap projected heuristic gets
    // the exact LP bound on pop.  If the bound rose meaningfully, re-insert
    // the node instead of expanding it (its priority was too optimistic).
    if (!node.exact_refined) {
      const double prev_h = node.h_cost;
      const FTLHeuristicSource prev_source = node.heuristic_source;
      ExactSubsetSolution exact_solution =
          solve_exact_subset_lp(detectors, detector_cost_tuples, node.warm_solution_idx);
      if (prev_source == FTLHeuristicSource::kProjected) stats.projected_nodes_refined++;
      if (exact_solution.value == INF_D) {
        if (config.verbose) {
          std::cout << "  lp_refine exact_h=INF discarded=true" << std::endl;
        }
        continue;
      }
      // The exact bound must dominate the projected one (sanity check).
      if (exact_solution.value + 1e-7 < prev_h) {
        throw std::runtime_error("Exact singleton lower bound fell below stored lower bound.");
      }
      const double delta = exact_solution.value - prev_h;
      stats.total_lp_refinement_gain += delta;
      stats.max_lp_refinement_gain = std::max(stats.max_lp_refinement_gain, delta);
      exact_solution_arena.push_back(std::move(exact_solution));
      node.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1;
      node.h_cost = exact_solution_arena.back().value;
      node.f_cost = node.g_cost + node.h_cost;
      node.exact_refined = true;
      node.heuristic_source = FTLHeuristicSource::kExact;
      if (config.verbose) {
        std::cout << "  lp_refine approx_h=" << prev_h << " exact_h=" << node.h_cost
                  << " delta=" << delta << " vars=" << exact_solution_arena.back().num_variables
                  << " constraints=" << exact_solution_arena.back().num_constraints
                  << " reinserted=" << (delta > HEURISTIC_EPS) << std::endl;
      }
      if (delta > HEURISTIC_EPS) {
        stats.lp_reinserts++;
        pq.push(node);
        stats.num_pq_pushed++;
        if (stats.num_pq_pushed > config.pqlimit) {
          low_confidence_flag = true;
          return;
        }
        continue;
      }
    }

    // Skip detector patterns already expanded at this detector count.
    if (config.no_revisit_dets && !visited_detectors[node.num_dets].insert(detectors).second) {
      continue;
    }

    // Expand: pick the first fired detector in this order and branch on every
    // unblocked error touching it.
    size_t min_detector = std::numeric_limits::max();
    for (size_t offset = 0; offset < num_detectors; ++offset) {
      const size_t detector = config.det_orders[detector_order][offset];
      if (detectors[detector]) {
        min_detector = detector;
        break;
      }
    }

    // prefix_blocked carries the precedence blocking forward: each child also
    // blocks every error tried before it at this min_detector.
    std::vector prefix_blocked(num_errors, 0);
    for (size_t ei = 0; ei < num_errors; ++ei) {
      prefix_blocked[ei] = detector_cost_tuples[ei].error_blocked;
    }

    size_t children_generated = 0;
    size_t children_projected = 0;
    size_t children_beam_pruned = 0;
    size_t children_infeasible = 0;

    for (int ei : d2e[min_detector]) {
      prefix_blocked[(size_t)ei] = 1;
      if (detector_cost_tuples[(size_t)ei].error_blocked) continue;

      boost::dynamic_bitset<> child_detectors = detectors;
      for (int detector : edets[(size_t)ei]) child_detectors[(size_t)detector] = !child_detectors[(size_t)detector];
      const size_t child_num_dets = child_detectors.count();
      if (child_num_dets > max_num_dets) {
        children_beam_pruned++;
        continue;
      }

      std::vector child_detector_cost_tuples(num_errors);
      for (size_t j = 0; j < num_errors; ++j) {
        child_detector_cost_tuples[j].error_blocked = prefix_blocked[j];
      }

      // Cheap projected heuristic for the child; the exact LP is deferred
      // until (and unless) the child is actually popped.
      const double child_h =
          project_from_exact_solution(exact_solution_arena[(size_t)node.exact_solution_idx],
                                      child_detectors, child_detector_cost_tuples);
      stats.projected_nodes_generated++;
      children_projected++;
      if (child_h == INF_D) {
        children_infeasible++;
        continue;
      }

      error_chain_arena.emplace_back();
      auto& chain_node = error_chain_arena.back();
      chain_node.error_index = (size_t)ei;
      chain_node.min_detector = min_detector;
      chain_node.parent_idx = node.error_chain_idx;

      FTLNode child;
      child.g_cost = node.g_cost + errors[(size_t)ei].likelihood_cost;
      child.h_cost = child_h;
      child.f_cost = child.g_cost + child.h_cost;
      child.num_dets = child_num_dets;
      child.depth = node.depth + 1;
      child.error_chain_idx = (int64_t)error_chain_arena.size() - 1;
      child.warm_solution_idx = node.exact_solution_idx;
      child.exact_solution_idx = -1;
      child.exact_refined = false;
      child.heuristic_source = FTLHeuristicSource::kProjected;
      pq.push(child);
      stats.num_pq_pushed++;
      children_generated++;
      if (stats.num_pq_pushed > config.pqlimit) {
        low_confidence_flag = true;
        return;
      }
    }

    if (config.verbose) {
      const size_t projected_unrefined =
          stats.projected_nodes_generated - stats.projected_nodes_refined;
      std::cout << "  expanded children_generated=" << children_generated
                << " children_projected=" << children_projected
                << " beam_pruned=" << children_beam_pruned
                << " infeasible=" << children_infeasible << " lp_calls=" << stats.lp_calls
                << " proj_unrefined_so_far=" << projected_unrefined << std::endl;
    }
  }

  if (config.verbose) {
    std::cout << "Decoding failed to converge within beam limit." << std::endl;
  }
  low_confidence_flag = true;
}

// Sums the likelihood costs of predicted errors (indexed by original DEM
// error indices); throws if an index maps to a dropped (merged/zero-prob) error.
double TesseractFTLDecoder::cost_from_errors(const std::vector& predicted_errors) const {
  if (plain_delegate) return plain_delegate->cost_from_errors(predicted_errors);
  double total_cost = 0.0;
  for (size_t dem_error_index : predicted_errors) {
    const size_t error_index = dem_error_to_error.at(dem_error_index);
    if (error_index == std::numeric_limits::max()) {
      throw std::invalid_argument("error index does not map to a retained decoder error");
    }
    total_cost += errors[error_index].likelihood_cost;
  }
  return total_cost;
}

// XORs the observables flipped by the predicted errors: an observable flipped
// an even number of times cancels out.  Returns the sorted surviving indices.
std::vector TesseractFTLDecoder::get_flipped_observables(
    const std::vector& predicted_errors) const {
  if (plain_delegate) return plain_delegate->get_flipped_observables(predicted_errors);
  std::unordered_set flipped_observables_set;
  for (size_t dem_error_index : predicted_errors) {
    const size_t error_index = dem_error_to_error.at(dem_error_index);
    if (error_index == std::numeric_limits::max()) {
      throw std::invalid_argument("error index does not map to a retained decoder error");
    }
    for (int obs_index : errors[error_index].symptom.observables) {
      if (flipped_observables_set.count(obs_index)) {
        flipped_observables_set.erase(obs_index);
      } else {
        flipped_observables_set.insert(obs_index);
      }
    }
  }
  std::vector flipped_observables(flipped_observables_set.begin(),
                                  flipped_observables_set.end());
  std::sort(flipped_observables.begin(), flipped_observables.end());
  return flipped_observables;
}

// Convenience: decode a shot and return its flipped observables.
std::vector TesseractFTLDecoder::decode(const std::vector& detections) {
  decode_to_errors(detections);
  return get_flipped_observables(predicted_errors_buffer);
}

// Decodes a batch of shots, one prediction vector per shot.
void TesseractFTLDecoder::decode_shots(std::vector& shots,
                                       std::vector>& obs_predicted) {
  obs_predicted.resize(shots.size());
  for (size_t i = 0; i < shots.size(); ++i) {
    obs_predicted[i] = decode(shots[i].hits);
  }
}
diff --git a/src/tesseract_ftl.h b/src/tesseract_ftl.h new file
mode 100644 index 0000000..9b05e43 --- /dev/null +++ b/src/tesseract_ftl.h @@ -0,0 +1,203 @@

// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TESSERACT_FTL_DECODER_H
#define TESSERACT_FTL_DECODER_H

#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "common.h"
#include "stim.h"
#include "tesseract.h"
#include "utils.h"
#include "visualization.h"

constexpr size_t DEFAULT_FTL_SUBSET_DETCOST_SIZE = 0;

// Configuration for the FTL A* decoder.  Mirrors TesseractConfig plus the
// subset_detcost_size switch that selects the singleton LP heuristic.
struct TesseractFTLConfig {
  stim::DetectorErrorModel dem;
  int det_beam = DEFAULT_DET_BEAM;
  bool beam_climbing = false;
  bool no_revisit_dets = true;

  bool verbose = false;
  bool merge_errors = true;
  size_t pqlimit = DEFAULT_PQLIMIT;
  std::vector> det_orders;
  double det_penalty = 0;
  bool create_visualization = false;

  // 0 = delegate to the original Tesseract detcost heuristic.
  // 1 = use the singleton fractional lower bound implemented in this file.
  size_t subset_detcost_size = DEFAULT_FTL_SUBSET_DETCOST_SIZE;

  std::string str();
};

// Provenance of a search node's heuristic value (see TesseractFTLDecoder).
enum class FTLHeuristicSource : uint8_t { kPlain = 0, kProjected = 1, kExact = 2 };

// Counters collected during a decode; accumulate() merges attempts.
struct TesseractFTLStats {
  size_t num_pq_pushed = 0;
  size_t num_nodes_popped = 0;
  size_t max_queue_size = 0;

  size_t heuristic_calls = 0;
  size_t plain_heuristic_calls = 0;
  size_t projection_heuristic_calls = 0;
  size_t exact_refinement_calls = 0;
  size_t lp_calls = 0;
  size_t lp_reinserts = 0;
  size_t projected_nodes_generated = 0;
  size_t projected_nodes_refined = 0;
  double total_lp_refinement_gain = 0.0;
  double max_lp_refinement_gain = 0.0;
  double lp_total_seconds = 0.0;

  void clear();
  void accumulate(const TesseractFTLStats& other);
};

// A* decoder with an exact singleton LP lower bound (subset_detcost_size == 1)
// or delegation to the plain Tesseract decoder (subset_detcost_size == 0).
struct TesseractFTLDecoder {
  TesseractFTLConfig config;
  Visualizer visualizer;

  explicit TesseractFTLDecoder(TesseractFTLConfig config);
  ~TesseractFTLDecoder();

  // Clears the predicted_errors_buffer and fills it with the decoded errors for
  // these detection events.
  void decode_to_errors(const std::vector& detections);

  // Clears the predicted_errors_buffer and fills it with the decoded errors for
  // these detection events, using a specified detector ordering index.
  void decode_to_errors(const std::vector& detections, size_t detector_order,
                        size_t detector_beam);

  // Returns the bitwise XOR of the observables flipped by the errors in the given array, indexed by
  // the original flattened DEM error indices.
  std::vector get_flipped_observables(const std::vector& predicted_errors) const;

  // Returns the sum of likelihood costs of the errors in the given array, indexed by the original
  // flattened DEM error indices.
  double cost_from_errors(const std::vector& predicted_errors) const;

  std::vector decode(const std::vector& detections);
  void decode_shots(std::vector& shots,
                    std::vector>& obs_predicted);

  bool low_confidence_flag = false;
  std::vector predicted_errors_buffer;
  std::vector dem_error_to_error;
  std::vector error_to_dem_error;
  std::vector errors;
  size_t num_observables = 0;
  size_t num_detectors = 0;
  TesseractFTLStats stats;

 private:
  // Cached likelihood cost and per-detector-split cost for one error.
  struct ErrorCost {
    double likelihood_cost = 0;
    double min_cost = 0;
  };

  // One A* search node; g/h/f costs, chain and LP-solution arena indices.
  struct FTLNode {
    double f_cost = 0.0;
    double g_cost = 0.0;
    double h_cost = 0.0;
    size_t num_dets = 0;
    size_t depth = 0;
    int64_t error_chain_idx = -1;
    int64_t warm_solution_idx = -1;
    int64_t exact_solution_idx = -1;
    bool exact_refined = false;
    FTLHeuristicSource heuristic_source = FTLHeuristicSource::kPlain;

    bool operator>(const FTLNode& other) const;
  };

  // Per-error flag: 1 when the error is precedence-blocked at this node.
  struct DetectorCostTuple {
    uint32_t error_blocked = 0;
  };

  // One LP constraint: the local detectors an error touches and its cost bound.
  struct SingletonPatternConstraint {
    std::vector local_detectors;
    double rhs = 0.0;
  };

  // A connected component of active detectors with its LP constraints.
  struct SingletonLPComponent {
    std::vector detectors;
    std::vector constraints;
    std::vector cheapest_constraint_for_local_detector;
  };

  // Solved LP bound plus per-detector budgets (for projection/warm-starting).
  struct ExactSubsetSolution {
    double value = 0.0;
    size_t num_active_subsets = 0;
    size_t num_components = 0;
    size_t num_variables = 0;
    size_t num_constraints = 0;
    std::vector active_detectors;
    std::vector detector_budgets;
  };

  struct SingletonBuildResult {
    bool feasible = true;
    std::vector components;
  };

  struct DynamicBitsetHash {
    size_t operator()(const boost::dynamic_bitset<>& bs) const;
  };

  std::vector> d2e;      // detector -> incident error indices
  std::vector> edets;    // error -> detector indices
  size_t num_errors = 0;
  std::vector error_costs;
  std::vector error_chain_arena;       // parent-linked search chains
  std::vector exact_solution_arena;    // solved LP bounds, by index

  // If subset_detcost_size == 0, delegate to the original Tesseract decoder.
  std::unique_ptr plain_delegate;

  void initialize_structures(size_t num_detectors);

  void flip_detectors_and_block_errors(size_t detector_order, int64_t error_chain_idx,
                                       boost::dynamic_bitset<>& detectors,
                                       std::vector& detector_cost_tuples) const;

  SingletonBuildResult build_singleton_components(
      const boost::dynamic_bitset<>& detectors,
      const std::vector& detector_cost_tuples) const;

  ExactSubsetSolution solve_exact_subset_lp(const boost::dynamic_bitset<>& detectors,
                                            const std::vector& detector_cost_tuples,
                                            int64_t warm_solution_idx);

  double project_from_exact_solution(const ExactSubsetSolution& solution,
                                     const boost::dynamic_bitset<>& detectors,
                                     const std::vector& detector_cost_tuples);

  void reset_decode_state();
};

#endif  // TESSERACT_FTL_DECODER_H
diff --git a/src/tesseract_ftl_main.cc b/src/tesseract_ftl_main.cc new file mode 100644 index 0000000..b0876e9 --- /dev/null +++ b/src/tesseract_ftl_main.cc @@ -0,0 +1,495 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "stim.h" +#include "tesseract_ftl.h" +#include "utils.h" + +struct Args { + std::string circuit_path; + std::string dem_path; + bool no_merge_errors = false; + + uint64_t det_order_seed; + size_t num_det_orders = 10; + bool det_order_bfs = false; + bool det_order_index = false; + bool det_order_coordinate = false; + + size_t sample_num_shots = 0; + size_t max_errors = SIZE_MAX; + uint64_t sample_seed; + + size_t shot_range_begin = 0; + size_t shot_range_end = 0; + + std::string in_fname = ""; + std::string in_format = ""; + std::string obs_in_fname = ""; + std::string obs_in_format = ""; + bool append_observables = false; + std::string out_fname = ""; + std::string out_format = ""; + + std::string dem_out_fname = ""; + std::string stats_out_fname = ""; + + size_t num_threads = 1; + + size_t det_beam; + double det_penalty = 0; + bool beam_climbing = false; + bool no_revisit_dets = false; + size_t pqlimit; + + size_t subset_detcost_size = 0; + + bool verbose = false; + bool print_stats = false; + + bool has_observables() { + return append_observables || !obs_in_fname.empty() || (sample_num_shots > 0); + } + + void validate() { + if (circuit_path.empty() && dem_path.empty()) { + throw std::invalid_argument("Must provide at least one of --circuit or --dem"); + } + int det_order_flags = int(det_order_bfs) + int(det_order_index) + int(det_order_coordinate); + if (det_order_flags > 1) { + throw std::invalid_argument( + "Only one of --det-order-bfs, --det-order-index, or --det-order-coordinate may be set."); + } + int num_data_sources = int(sample_num_shots > 0) + int(!in_fname.empty()); + if (num_data_sources != 1) { + throw std::invalid_argument("Requires exactly 1 source of shots."); + } + if (!in_fname.empty() && in_format.empty()) { + throw std::invalid_argument("If --in is provided, must also specify --in-format."); + } + if (!out_fname.empty() 
&& out_format.empty()) { + throw std::invalid_argument("If --out is provided, must also specify --out-format."); + } + if (!in_format.empty() && !stim::format_name_to_enum_map().contains(in_format)) { + throw std::invalid_argument("Invalid format: " + in_format); + } + if (!obs_in_format.empty() && !stim::format_name_to_enum_map().contains(obs_in_format)) { + throw std::invalid_argument("Invalid format: " + obs_in_format); + } + if (!out_format.empty() && !stim::format_name_to_enum_map().contains(out_format)) { + throw std::invalid_argument("Invalid format: " + out_format); + } + if (!obs_in_fname.empty() && in_fname.empty()) { + throw std::invalid_argument( + "Cannot load observable flips without a corresponding detection event data file."); + } + if (num_threads == 0) { + throw std::invalid_argument("--threads must be at least 1."); + } + if (shot_range_begin || shot_range_end) { + if (shot_range_end < shot_range_begin) { + throw std::invalid_argument("Provided shot range must have end >= begin."); + } + } + if (sample_num_shots > 0 && circuit_path.empty()) { + throw std::invalid_argument("Cannot sample shots without a circuit."); + } + if (beam_climbing && det_beam == INF_DET_BEAM) { + throw std::invalid_argument("Beam climbing requires a finite beam"); + } + if (subset_detcost_size > 1) { + throw std::invalid_argument("This prototype currently supports --subset-detcost-size <= 1"); + } + } + + void extract(TesseractFTLConfig& config, std::vector& shots, + std::unique_ptr& writer) { + stim::Circuit circuit; + if (!circuit_path.empty()) { + FILE* file = fopen(circuit_path.c_str(), "r"); + if (!file) throw std::invalid_argument("Could not open the file: " + circuit_path); + circuit = stim::Circuit::from_file(file); + fclose(file); + } + + if (!dem_path.empty()) { + FILE* file = fopen(dem_path.c_str(), "r"); + if (!file) throw std::invalid_argument("Could not open the file: " + dem_path); + config.dem = stim::DetectorErrorModel::from_file(file); + fclose(file); + } 
else { + assert(!circuit_path.empty()); + config.dem = stim::ErrorAnalyzer::circuit_to_detector_error_model( + circuit, /*decompose_errors=*/false, /*fold_loops=*/true, + /*allow_gauge_detectors=*/true, + /*approximate_disjoint_errors_threshold=*/1, + /*ignore_decomposition_failures=*/false, + /*block_decomposition_from_introducing_remnant_edges=*/false); + } + + config.merge_errors = !no_merge_errors; + config.subset_detcost_size = subset_detcost_size; + + { + DetOrder order = DetOrder::DetBFS; + if (det_order_index) { + order = DetOrder::DetIndex; + } else if (det_order_coordinate) { + order = DetOrder::DetCoordinate; + } + config.det_orders = build_det_orders(config.dem, num_det_orders, order, det_order_seed); + } + + if (sample_num_shots > 0) { + assert(!circuit_path.empty()); + std::mt19937_64 rng(sample_seed); + size_t num_detectors = circuit.count_detectors(); + const auto [dets, obs] = + stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng); + stim::simd_bit_table<64> obs_T = obs.transposed(); + shots.resize(sample_num_shots); + for (size_t k = 0; k < sample_num_shots; k++) { + shots[k].obs_mask = obs_T[k]; + for (size_t d = 0; d < num_detectors; d++) { + if (dets[d][k]) shots[k].hits.push_back(d); + } + } + } + + if (!in_fname.empty()) { + FILE* shots_file = fopen(in_fname.c_str(), "r"); + if (!shots_file) throw std::invalid_argument("Could not open the file: " + in_fname); + stim::FileFormatData shots_in_format = stim::format_name_to_enum_map().at(in_format); + auto reader = stim::MeasureRecordReader::make( + shots_file, shots_in_format.id, 0, config.dem.count_detectors(), + append_observables * config.dem.count_observables()); + stim::SparseShot sparse_shot; + sparse_shot.clear(); + while (reader->start_and_read_entire_record(sparse_shot)) { + shots.push_back(sparse_shot); + sparse_shot.clear(); + } + fclose(shots_file); + } + + if (!obs_in_fname.empty()) { + FILE* obs_file = fopen(obs_in_fname.c_str(), "r"); + if (!obs_file) throw 
std::invalid_argument("Could not open the file: " + obs_in_fname); + stim::FileFormatData obs_format = stim::format_name_to_enum_map().at(obs_in_format); + auto obs_reader = stim::MeasureRecordReader::make( + obs_file, obs_format.id, 0, 0, config.dem.count_observables()); + stim::SparseShot sparse_shot; + sparse_shot.clear(); + size_t num_obs_shots = 0; + while (obs_reader->start_and_read_entire_record(sparse_shot)) { + if (num_obs_shots >= shots.size()) { + throw std::invalid_argument("Shot data ended before obs data."); + } + shots[num_obs_shots].obs_mask = sparse_shot.obs_mask; + sparse_shot.clear(); + ++num_obs_shots; + } + if (num_obs_shots != shots.size()) { + throw std::invalid_argument("Obs data ended before shot data ended."); + } + fclose(obs_file); + } + + if (shot_range_begin || shot_range_end) { + if (shot_range_end > shots.size()) { + throw std::invalid_argument("Shot range end is past end of shots array."); + } + std::vector shots_in_range(shots.begin() + shot_range_begin, + shots.begin() + shot_range_end); + std::swap(shots_in_range, shots); + } + + if (!out_fname.empty()) { + stim::FileFormatData predictions_out_format = stim::format_name_to_enum_map().at(out_format); + FILE* predictions_file = stdout; + if (out_fname != "-") predictions_file = fopen(out_fname.c_str(), "w"); + writer = stim::MeasureRecordWriter::make(predictions_file, predictions_out_format.id); + writer->begin_result_type('L'); + } + + config.det_beam = det_beam; + config.det_penalty = det_penalty; + config.beam_climbing = beam_climbing; + config.no_revisit_dets = no_revisit_dets; + config.pqlimit = pqlimit; + config.verbose = verbose; + } +}; + +int main(int argc, char* argv[]) { + std::cout.precision(16); + argparse::ArgumentParser program("tesseract_ftl"); + Args args; + + program.add_argument("--circuit").help("Stim circuit file path").store_into(args.circuit_path); + program.add_argument("--dem").help("Stim dem file path").store_into(args.dem_path); + 
program.add_argument("--no-merge-errors") + .help("If provided, will not merge identical error mechanisms.") + .store_into(args.no_merge_errors); + program.add_argument("--subset-detcost-size") + .help("0 = plain detcost delegate, 1 = singleton fractional lower bound") + .default_value(size_t(0)) + .store_into(args.subset_detcost_size); + + program.add_argument("--num-det-orders") + .help("Number of ways to orient the manifold when reordering the detectors") + .metavar("N") + .default_value(size_t(1)) + .store_into(args.num_det_orders); + program.add_argument("--det-order-bfs") + .help("Use BFS-based detector ordering") + .flag() + .store_into(args.det_order_bfs); + program.add_argument("--det-order-index") + .help("Randomly choose increasing or decreasing detector index order") + .flag() + .store_into(args.det_order_index); + program.add_argument("--det-order-coordinate") + .help("Random geometric detector orientation ordering") + .flag() + .store_into(args.det_order_coordinate); + program.add_argument("--det-order-seed") + .help("Seed used when initializing the random detector traversal orderings.") + .default_value(static_cast(518278944)) + .store_into(args.det_order_seed); + + program.add_argument("--sample-num-shots") + .help("Sample the requested number of shots from the Stim circuit.") + .store_into(args.sample_num_shots); + program.add_argument("--max-errors") + .help("Stop after at least this many errors have been observed.") + .store_into(args.max_errors); + program.add_argument("--sample-seed") + .help("Seed used when initializing the random number generator for sampling shots") + .default_value(static_cast(std::random_device()())) + .store_into(args.sample_seed); + + program.add_argument("--shot-range-begin").default_value(size_t(0)).store_into(args.shot_range_begin); + program.add_argument("--shot-range-end").default_value(size_t(0)).store_into(args.shot_range_end); + + 
program.add_argument("--in").default_value(std::string("")).store_into(args.in_fname); + std::string in_formats; + bool first = true; + for (const auto& [key, value] : stim::format_name_to_enum_map()) { + if (!first) in_formats += "/"; + first = false; + in_formats += key; + } + program.add_argument("--in-format", "--in_format") + .default_value(std::string("")) + .store_into(args.in_format); + program.add_argument("--in-includes-appended-observables", "--in_includes_appended_observables") + .default_value(false) + .store_into(args.append_observables) + .flag(); + program.add_argument("--obs_in", "--obs-in") + .default_value(std::string("")) + .store_into(args.obs_in_fname); + program.add_argument("--obs-in-format", "--obs_in_format") + .default_value(std::string("")) + .store_into(args.obs_in_format); + program.add_argument("--out").default_value(std::string("")) + .store_into(args.out_fname); + program.add_argument("--out-format").default_value(std::string("")) + .store_into(args.out_format); + program.add_argument("--dem-out").default_value(std::string("")) + .store_into(args.dem_out_fname); + program.add_argument("--stats-out").default_value(std::string("")) + .store_into(args.stats_out_fname); + + program.add_argument("--threads") + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 
1 : std::thread::hardware_concurrency())) + .store_into(args.num_threads); + program.add_argument("--beam") + .default_value(INF_DET_BEAM) + .store_into(args.det_beam); + program.add_argument("--det-penalty").default_value(0.0).store_into(args.det_penalty); + program.add_argument("--beam-climbing").flag().store_into(args.beam_climbing); + program.add_argument("--no-revisit-dets").flag().store_into(args.no_revisit_dets); + program.add_argument("--pqlimit") + .default_value(std::numeric_limits::max()) + .store_into(args.pqlimit); + program.add_argument("--verbose").flag().store_into(args.verbose); + program.add_argument("--print-stats").flag().store_into(args.print_stats); + + try { + program.parse_args(argc, argv); + } catch (const std::exception& err) { + std::cerr << err.what() << std::endl; + std::cerr << program; + return EXIT_FAILURE; + } + args.validate(); + + TesseractFTLConfig config; + std::vector shots; + std::unique_ptr writer; + args.extract(config, shots, writer); + + std::vector obs_predicted(shots.size()); + std::vector cost_predicted(shots.size()); + std::vector decoding_time_seconds(shots.size()); + std::vector> low_confidence(shots.size()); + const stim::DetectorErrorModel original_dem = config.dem.flattened(); + std::vector> decoders(args.num_threads); + std::vector> error_use_per_thread( + args.num_threads, std::vector(original_dem.count_errors())); + std::vector decoder_stats_per_thread(args.num_threads); + + bool has_obs = args.has_observables(); + size_t num_errors = 0; + size_t num_low_confidence = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&](size_t thread_index, size_t shot_index) { + if (!decoders[thread_index]) { + decoders[thread_index] = std::make_unique(config); + } + auto& decoder = *decoders[thread_index]; + auto& error_use = error_use_per_thread[thread_index]; + auto start_time = 
std::chrono::high_resolution_clock::now(); + decoder.decode_to_errors(shots[shot_index].hits); + auto stop_time = std::chrono::high_resolution_clock::now(); + decoding_time_seconds[shot_index] = + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + obs_predicted[shot_index] = + vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); + low_confidence[shot_index] = decoder.low_confidence_flag; + cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + decoder_stats_per_thread[thread_index].accumulate(decoder.stats); + if (!has_obs || shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { + for (size_t ei : decoder.predicted_errors_buffer) ++error_use[ei]; + } + }, + [&](size_t shot_index) { + if (writer) { + writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables); + writer->write_end(); + } + if (low_confidence[shot_index]) { + ++num_low_confidence; + } else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) { + ++num_errors; + } + total_time_seconds += decoding_time_seconds[shot_index]; + if (args.print_stats) { + std::cout << "num_shots = " << (shot_index + 1) + << " num_low_confidence = " << num_low_confidence + << " num_errors = " << num_errors + << " total_time_seconds = " << total_time_seconds << std::endl; + std::cout << "cost = " << cost_predicted[shot_index] << std::endl; + std::cout.flush(); + } + return num_errors < args.max_errors; + }); + + std::vector error_use_totals(original_dem.count_errors()); + for (const auto& error_use : error_use_per_thread) { + for (size_t ei = 0; ei < error_use_totals.size(); ++ei) error_use_totals[ei] += error_use[ei]; + } + TesseractFTLStats decoder_stats_total; + for (const auto& s : decoder_stats_per_thread) decoder_stats_total.accumulate(s); + + if (!args.dem_out_fname.empty()) { + size_t num_usage_dem_shots = shot; + if (has_obs) num_usage_dem_shots -= num_errors; + stim::DetectorErrorModel 
est_dem = + common::dem_from_counts(original_dem, error_use_totals, num_usage_dem_shots); + std::ofstream out(args.dem_out_fname, std::ofstream::out); + if (!out.is_open()) throw std::invalid_argument("Failed to open " + args.dem_out_fname); + out << est_dem << '\n'; + } + + bool print_final_stats = true; + if (!args.stats_out_fname.empty()) { + nlohmann::json stats_json = { + {"circuit_path", args.circuit_path}, + {"dem_path", args.dem_path}, + {"max_errors", args.max_errors}, + {"sample_seed", args.sample_seed}, + {"det_beam", args.det_beam}, + {"det_penalty", args.det_penalty}, + {"beam_climbing", args.beam_climbing}, + {"no_revisit_dets", args.no_revisit_dets}, + {"pqlimit", args.pqlimit}, + {"num_det_orders", args.num_det_orders}, + {"det_order_seed", args.det_order_seed}, + {"subset_detcost_size", args.subset_detcost_size}, + {"total_time_seconds", total_time_seconds}, + {"num_errors", num_errors}, + {"num_low_confidence", num_low_confidence}, + {"num_shots", shot}, + {"num_threads", args.num_threads}, + {"sample_num_shots", args.sample_num_shots}, + {"ftl_num_pq_pushed", decoder_stats_total.num_pq_pushed}, + {"ftl_num_nodes_popped", decoder_stats_total.num_nodes_popped}, + {"ftl_max_queue_size", decoder_stats_total.max_queue_size}, + {"ftl_heuristic_calls", decoder_stats_total.heuristic_calls}, + {"ftl_plain_heuristic_calls", decoder_stats_total.plain_heuristic_calls}, + {"ftl_projection_heuristic_calls", decoder_stats_total.projection_heuristic_calls}, + {"ftl_exact_refinement_calls", decoder_stats_total.exact_refinement_calls}, + {"ftl_lp_calls", decoder_stats_total.lp_calls}, + {"ftl_lp_reinserts", decoder_stats_total.lp_reinserts}, + {"ftl_projected_nodes_generated", decoder_stats_total.projected_nodes_generated}, + {"ftl_projected_nodes_refined", decoder_stats_total.projected_nodes_refined}, + {"ftl_total_lp_refinement_gain", decoder_stats_total.total_lp_refinement_gain}, + {"ftl_max_lp_refinement_gain", decoder_stats_total.max_lp_refinement_gain}, + 
{"ftl_lp_total_seconds", decoder_stats_total.lp_total_seconds}, + }; + + if (args.stats_out_fname == "-") { + std::cout << stats_json << std::endl; + print_final_stats = false; + } else { + std::ofstream out(args.stats_out_fname, std::ofstream::out); + out << stats_json << std::endl; + } + } + + if (print_final_stats) { + std::cout << "num_shots = " << shot; + std::cout << " num_low_confidence = " << num_low_confidence; + if (has_obs) std::cout << " num_errors = " << num_errors; + std::cout << " total_time_seconds = " << total_time_seconds; + if (args.subset_detcost_size > 0) { + std::cout << " lp_calls = " << decoder_stats_total.lp_calls; + std::cout << " lp_reinserts = " << decoder_stats_total.lp_reinserts; + std::cout << " projected_nodes_generated = " << decoder_stats_total.projected_nodes_generated; + std::cout << " projected_nodes_refined = " << decoder_stats_total.projected_nodes_refined; + } + std::cout << std::endl; + } + return 0; +} diff --git a/src/tesseract_main.cc b/src/tesseract_main.cc index ab5ed9c..ff7b3d0 100644 --- a/src/tesseract_main.cc +++ b/src/tesseract_main.cc @@ -483,6 +483,8 @@ int main(int argc, char* argv[]) { std::vector obs_predicted(shots.size()); std::vector cost_predicted(shots.size()); std::vector decoding_time_seconds(shots.size()); + std::vector num_pq_pushed_per_shot(shots.size()); + std::vector num_pq_popped_per_shot(shots.size()); std::vector> low_confidence(shots.size()); const stim::DetectorErrorModel original_dem = config.dem.flattened(); std::vector> decoders(args.num_threads); @@ -511,6 +513,8 @@ int main(int argc, char* argv[]) { vector_to_u64_mask(decoder.get_flipped_observables(decoder.predicted_errors_buffer)); low_confidence[shot_index] = decoder.low_confidence_flag; cost_predicted[shot_index] = decoder.cost_from_errors(decoder.predicted_errors_buffer); + num_pq_pushed_per_shot[shot_index] = decoder.num_pq_pushed; + num_pq_popped_per_shot[shot_index] = decoder.num_pq_popped; if (!has_obs or 
shots[shot_index].obs_mask_as_u64() == obs_predicted[shot_index]) { for (size_t ei : decoder.predicted_errors_buffer) { ++error_use[ei]; @@ -532,6 +536,8 @@ int main(int argc, char* argv[]) { std::cout << "num_shots = " << (shot_index + 1) << " num_low_confidence = " << num_low_confidence << " num_errors = " << num_errors + << " num_pq_pushed = " << num_pq_pushed_per_shot[shot_index] + << " num_pq_popped = " << num_pq_popped_per_shot[shot_index] << " total_time_seconds = " << total_time_seconds << std::endl; std::cout << "cost = " << cost_predicted[shot_index] << std::endl; std::cout.flush(); From 688338cfb28946862df88a1f10d5f15b04e4f411 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 30 Mar 2026 09:45:56 -0700 Subject: [PATCH 03/25] make pq pop push counters reflective --- src/py/tesseract_test.py | 14 ++++++++++++++ src/tesseract.cc | 5 ++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/py/tesseract_test.py b/src/py/tesseract_test.py index bb62bea..337bf91 100644 --- a/src/py/tesseract_test.py +++ b/src/py/tesseract_test.py @@ -195,6 +195,20 @@ def test_create_tesseract_decoder(): assert decoder.cost_from_errors([1]) == pytest.approx(0.5108256237659907) +def test_tesseract_priority_queue_counters_track_search(): + config = tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL) + decoder = tesseract_decoder.tesseract.TesseractDecoder(config) + + decoder.decode_to_errors(np.array([True, False], dtype=bool)) + assert decoder.num_pq_pushed >= 1 + assert decoder.num_pq_popped >= 1 + assert decoder.num_pq_pushed >= decoder.num_pq_popped + + decoder.decode_to_errors(np.array([False, False], dtype=bool)) + assert decoder.num_pq_pushed == 1 + assert decoder.num_pq_popped == 1 + + def test_tesseract_compile_decoder(): shared_test_compile_decoder( tesseract_decoder.tesseract.TesseractConfig, diff --git a/src/tesseract.cc b/src/tesseract.cc index cc92d28..6b715f5 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -283,6 +283,8 @@ 
void TesseractDecoder::decode_to_errors(const std::vector& detections, size_t detector_order, size_t detector_beam) { predicted_errors_buffer.clear(); low_confidence_flag = false; + num_pq_pushed = 0; + num_pq_popped = 0; error_chain_arena.clear(); // Can technically be larger than pqlimit, but we need an initial guess on how many nodes we // will process from the queue. @@ -323,11 +325,12 @@ void TesseractDecoder::decode_to_errors(const std::vector& detections, std::vector next_detector_cost_tuples; pq.push({initial_cost, min_num_dets, 0, -1}); - size_t num_pq_pushed = 1; + num_pq_pushed = 1; while (!pq.empty()) { const Node node = pq.top(); pq.pop(); + ++num_pq_popped; if (node.num_dets > max_num_dets) continue; From 5054aec8846b0208b7b5e49faada2a6dc8f9bc56 Mon Sep 17 00:00:00 2001 From: noajshu Date: Mon, 30 Mar 2026 20:53:50 +0000 Subject: [PATCH 04/25] run clang format and optimize the ftl prototype implementation --- ...totype_singleton_greedy_heuristics_lazy.py | 39 ++- src/tesseract.cc | 4 +- src/tesseract.h | 2 +- src/tesseract.pybind.h | 10 +- src/tesseract.test.cc | 2 +- src/tesseract_ftl.cc | 256 ++++++++++-------- src/tesseract_ftl.h | 15 +- src/tesseract_ftl_main.cc | 20 +- 8 files changed, 194 insertions(+), 154 deletions(-) diff --git a/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py index 1051961..346a974 100644 --- a/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics_lazy.py @@ -902,6 +902,12 @@ def build_arg_parser() -> argparse.ArgumentParser: ) ) parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--dets", + type=str, + default=None, + help="String of shot dets (e.g., 'shot D0 D1 L2') to parse instead of sampling.", + ) parser.add_argument( "--sample-num-shots", type=int, @@ -1004,15 +1010,30 @@ def main(argv: 
Optional[Sequence[str]] = None) -> int: dem = circuit.detector_error_model(decompose_errors=False) errors = merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) - dets, obs = sample_detections_and_observables( - circuit, - num_shots=args.sample_num_shots, - seed=args.seed, - num_detectors=dem.num_detectors, - num_observables=dem.num_observables, - ) - shot_dets = dets[args.shot] - shot_obs = obs[args.shot] + if args.dets is not None: + shot_dets = np.zeros(dem.num_detectors, dtype=bool) + shot_obs = np.zeros(dem.num_observables, dtype=bool) + for token in args.dets.split(): + if token == "shot": + continue + if token.startswith("D") and token[1:].isdigit(): + d_idx = int(token[1:]) + if d_idx < dem.num_detectors: + shot_dets[d_idx] = True + elif token.startswith("L") and token[1:].isdigit(): + l_idx = int(token[1:]) + if l_idx < dem.num_observables: + shot_obs[l_idx] = True + else: + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] decoder = GreedySingletonHeuristicDecoder( errors, diff --git a/src/tesseract.cc b/src/tesseract.cc index 6b715f5..c6ce011 100644 --- a/src/tesseract.cc +++ b/src/tesseract.cc @@ -95,7 +95,9 @@ double TesseractDecoder::get_detcost( for (int ei : d2e[d]) { ec = error_costs[ei]; - if (ec.likelihood_cost * min_det_cost_det_count >= min_cost * errors[ei].symptom.detectors.size()) break; + if (ec.likelihood_cost * min_det_cost_det_count >= + min_cost * errors[ei].symptom.detectors.size()) + break; dct = detector_cost_tuples[ei]; if (!dct.error_blocked) { diff --git a/src/tesseract.h b/src/tesseract.h index 517d326..fc4173d 100644 --- a/src/tesseract.h +++ b/src/tesseract.h @@ -60,7 +60,7 @@ class Node { }; struct DetectorCostTuple { - uint32_t error_blocked; + uint8_t error_blocked; uint32_t detectors_count; }; 
diff --git a/src/tesseract.pybind.h b/src/tesseract.pybind.h index eb01ed2..5781881 100644 --- a/src/tesseract.pybind.h +++ b/src/tesseract.pybind.h @@ -468,10 +468,12 @@ void add_tesseract_module(py::module& root) { "The configuration used to create this decoder.") .def_readwrite("low_confidence_flag", &TesseractDecoder::low_confidence_flag, "A flag indicating if the decoder's prediction has low confidence.") - .def_readwrite("num_pq_pushed", &TesseractDecoder::num_pq_pushed, - "The number of items pushed to the priority queue during the most recent decode.") - .def_readwrite("num_pq_popped", &TesseractDecoder::num_pq_popped, - "The number of items popped from the priority queue during the most recent decode.") + .def_readwrite( + "num_pq_pushed", &TesseractDecoder::num_pq_pushed, + "The number of items pushed to the priority queue during the most recent decode.") + .def_readwrite( + "num_pq_popped", &TesseractDecoder::num_pq_popped, + "The number of items popped from the priority queue during the most recent decode.") .def_readwrite( "predicted_errors_buffer", &TesseractDecoder::predicted_errors_buffer, "A buffer containing the predicted errors from the most recent decode operation.") diff --git a/src/tesseract.test.cc b/src/tesseract.test.cc index 3bb34fb..ae62460 100644 --- a/src/tesseract.test.cc +++ b/src/tesseract.test.cc @@ -409,7 +409,7 @@ TEST(TesseractDetcostTest, ComparesRatiosNotRawCosts) { std::vector tuples(dec.errors.size()); // residual x = {D0, D1} - std::cout <<"dec.d2e.size() = "< SIMPLEX_EPS); const double inv_pivot = 1.0 / pivot_value; for (size_t col = 0; col < width; ++col) { - tableau[pivot_row][col] *= inv_pivot; + tableau[pivot_row * width + col] *= inv_pivot; } - tableau[pivot_row][pivot_col] = 1.0; + tableau[pivot_row * width + pivot_col] = 1.0; - for (size_t row = 0; row <= num_rows; ++row) { + for (size_t row = 0; row < height; ++row) { if (row == pivot_row) continue; - const double factor = tableau[row][pivot_col]; + const double 
factor = tableau[row * width + pivot_col]; if (std::abs(factor) <= SIMPLEX_EPS) { - tableau[row][pivot_col] = 0.0; + tableau[row * width + pivot_col] = 0.0; continue; } for (size_t col = 0; col < width; ++col) { - tableau[row][col] -= factor * tableau[pivot_row][col]; + tableau[row * width + col] -= factor * tableau[pivot_row * width + col]; } - tableau[row][pivot_col] = 0.0; + tableau[row * width + pivot_col] = 0.0; } basis[pivot_row] = pivot_col; result.pivots++; @@ -170,13 +173,12 @@ DenseSimplexResult solve_dense_primal_packing_lp( size_t entering_col = width; double entering_priority = -INF_D; for (size_t col = 0; col + 1 < width; ++col) { - if (tableau[num_rows][col] >= -SIMPLEX_EPS) continue; + if (tableau[num_rows * width + col] >= -SIMPLEX_EPS) continue; const bool current_is_original = entering_col < num_vars; const bool candidate_is_original = col < num_vars; - const double candidate_priority = - candidate_is_original && entering_priorities != nullptr - ? (*entering_priorities)[col] - : -INF_D; + const double candidate_priority = candidate_is_original && entering_priorities != nullptr + ? 
(*entering_priorities)[col] + : -INF_D; if (entering_col == width) { entering_col = col; entering_priority = candidate_priority; @@ -190,8 +192,7 @@ DenseSimplexResult solve_dense_primal_packing_lp( continue; } if (candidate_priority > entering_priority + SIMPLEX_EPS || - (std::abs(candidate_priority - entering_priority) <= SIMPLEX_EPS && - col < entering_col)) { + (std::abs(candidate_priority - entering_priority) <= SIMPLEX_EPS && col < entering_col)) { entering_col = col; entering_priority = candidate_priority; } @@ -203,9 +204,9 @@ DenseSimplexResult solve_dense_primal_packing_lp( size_t leaving_row = num_rows; double best_ratio = INF_D; for (size_t row = 0; row < num_rows; ++row) { - const double coeff = tableau[row][entering_col]; + const double coeff = tableau[row * width + entering_col]; if (coeff <= SIMPLEX_EPS) continue; - const double ratio = tableau[row][width - 1] / coeff; + const double ratio = tableau[row * width + width - 1] / coeff; if (ratio + SIMPLEX_EPS < best_ratio) { best_ratio = ratio; leaving_row = row; @@ -224,12 +225,12 @@ DenseSimplexResult solve_dense_primal_packing_lp( for (size_t row = 0; row < num_rows; ++row) { if (basis[row] < num_vars) { - double value = tableau[row][width - 1]; + double value = tableau[row * width + width - 1]; if (std::abs(value) <= SIMPLEX_EPS) value = 0.0; result.solution[basis[row]] = value; } } - result.objective = tableau[num_rows][width - 1]; + result.objective = tableau[num_rows * width + width - 1]; if (std::abs(result.objective) <= SIMPLEX_EPS) result.objective = 0.0; result.success = true; return result; @@ -315,18 +316,8 @@ SingletonComponentSolveResult solve_singleton_component_lp( throw std::runtime_error("Constraint generation exceeded the number of unique constraints."); } - std::vector> reduced_rows; - std::vector reduced_rhs; - reduced_rows.reserve(selected_indices.size()); - reduced_rhs.reserve(selected_indices.size()); - for (int idx : selected_indices) { - 
reduced_rows.push_back(row_supports[(size_t)idx]); - reduced_rhs.push_back(rhs[(size_t)idx]); - } - - DenseSimplexResult simplex = - solve_dense_primal_packing_lp(num_local_detectors, reduced_rows, reduced_rhs, - &seed_budgets); + DenseSimplexResult simplex = solve_dense_primal_packing_lp( + num_local_detectors, row_supports, selected_indices, rhs, &seed_budgets); result.simplex_solves++; if (simplex.unbounded) { result.unbounded = true; @@ -435,8 +426,7 @@ bool TesseractFTLDecoder::FTLNode::operator>(const FTLNode& other) const { return f_cost > other.f_cost || (f_cost == other.f_cost && num_dets < other.num_dets); } -size_t TesseractFTLDecoder::DynamicBitsetHash::operator()( - const boost::dynamic_bitset<>& bs) const { +size_t TesseractFTLDecoder::DynamicBitsetHash::operator()(const boost::dynamic_bitset<>& bs) const { return boost::hash_value(bs); } @@ -537,7 +527,7 @@ void TesseractFTLDecoder::initialize_structures(size_t num_detectors_) { void TesseractFTLDecoder::flip_detectors_and_block_errors( size_t detector_order, int64_t error_chain_idx, boost::dynamic_bitset<>& detectors, - std::vector& detector_cost_tuples) const { + std::vector& blocked_flags) const { (void)detector_order; int64_t walker_idx = error_chain_idx; while (walker_idx != -1) { @@ -546,7 +536,7 @@ void TesseractFTLDecoder::flip_detectors_and_block_errors( const size_t min_detector = node.min_detector; for (int oei : d2e[min_detector]) { - detector_cost_tuples[(size_t)oei].error_blocked = 1; + blocked_flags[(size_t)oei] = 1; if ((size_t)oei == ei) break; } for (int d : edets[ei]) detectors[(size_t)d] = !detectors[(size_t)d]; @@ -555,8 +545,7 @@ void TesseractFTLDecoder::flip_detectors_and_block_errors( } TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_components( - const boost::dynamic_bitset<>& detectors, - const std::vector& detector_cost_tuples) const { + const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags) const { SingletonBuildResult 
result; std::vector active_detectors; @@ -573,7 +562,7 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c std::vector has_available(active_detectors.size(), 0); for (size_t ei = 0; ei < num_errors; ++ei) { - if (detector_cost_tuples[ei].error_blocked) continue; + if (blocked_flags[ei]) continue; int first_active = -1; for (int detector : edets[ei]) { const int active_pos = detector_to_active_pos[(size_t)detector]; @@ -594,15 +583,15 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } } - std::unordered_map> positions_by_root; - positions_by_root.reserve(active_detectors.size()); + std::vector> positions_by_root(active_detectors.size()); for (int active_pos = 0; active_pos < (int)active_detectors.size(); ++active_pos) { - positions_by_root[uf.find(active_pos)].push_back(active_pos); + positions_by_root[(size_t)uf.find(active_pos)].push_back(active_pos); } std::vector> component_positions; - component_positions.reserve(positions_by_root.size()); - for (auto& [_, positions] : positions_by_root) { + component_positions.reserve(active_detectors.size()); + for (auto& positions : positions_by_root) { + if (positions.empty()) continue; std::sort(positions.begin(), positions.end(), [&](int a, int b) { return active_detectors[(size_t)a] < active_detectors[(size_t)b]; }); @@ -635,7 +624,7 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c result.components.size()); for (size_t ei = 0; ei < num_errors; ++ei) { - if (detector_cost_tuples[ei].error_blocked) continue; + if (blocked_flags[ei]) continue; int component_index = -1; std::vector local_hits; @@ -708,28 +697,26 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset_lp( - const boost::dynamic_bitset<>& detectors, - const std::vector& detector_cost_tuples, int64_t warm_solution_idx) { + const 
boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags, + int64_t warm_solution_idx) { stats.heuristic_calls++; stats.exact_refinement_calls++; const auto start_time = std::chrono::high_resolution_clock::now(); ExactSubsetSolution solution; - const auto build = build_singleton_components(detectors, detector_cost_tuples); + const auto build = build_singleton_components(detectors, blocked_flags); if (!build.feasible) { solution.value = INF_D; const auto stop_time = std::chrono::high_resolution_clock::now(); stats.lp_total_seconds += - std::chrono::duration_cast(stop_time - start_time).count() / - 1e6; + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; return solution; } if (build.components.empty()) { solution.value = 0.0; const auto stop_time = std::chrono::high_resolution_clock::now(); stats.lp_total_seconds += - std::chrono::duration_cast(stop_time - start_time).count() / - 1e6; + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; return solution; } @@ -738,12 +725,23 @@ TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset solution.value = 0.0; solution.num_components = build.components.size(); + std::vector> detector_budget_pairs; + detector_budget_pairs.reserve(detectors.count()); for (const auto& component : build.components) { std::vector seed_budgets(component.detectors.size(), 0.0); if (warm_solution != nullptr) { + size_t warm_pos = 0; for (size_t local = 0; local < component.detectors.size(); ++local) { - seed_budgets[local] = lookup_detector_budget(*warm_solution, component.detectors[local]); + int det = component.detectors[local]; + while (warm_pos < warm_solution->active_detectors.size() && + warm_solution->active_detectors[warm_pos] < det) { + ++warm_pos; + } + if (warm_pos < warm_solution->active_detectors.size() && + warm_solution->active_detectors[warm_pos] == det) { + seed_budgets[local] = warm_solution->detector_budgets[warm_pos]; + } } } @@ -756,9 +754,9 @@ 
TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset rhs.push_back(constraint.rhs); } - const auto component_result = - solve_singleton_component_lp(component.detectors.size(), row_supports, rhs, - component.cheapest_constraint_for_local_detector, seed_budgets); + const auto component_result = solve_singleton_component_lp( + component.detectors.size(), row_supports, rhs, + component.cheapest_constraint_for_local_detector, seed_budgets); stats.lp_calls += component_result.simplex_solves; if (component_result.unbounded) { @@ -774,36 +772,55 @@ TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset solution.num_constraints += component_result.reduced_constraints; for (size_t local = 0; local < component.detectors.size(); ++local) { - solution.active_detectors.push_back(component.detectors[local]); - solution.detector_budgets.push_back(component_result.detector_budgets[local]); + detector_budget_pairs.emplace_back(component.detectors[local], + component_result.detector_budgets[local]); } } + if (detector_budget_pairs.size() > 1) { + std::sort(detector_budget_pairs.begin(), detector_budget_pairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + } + solution.active_detectors.reserve(detector_budget_pairs.size()); + solution.detector_budgets.reserve(detector_budget_pairs.size()); + for (const auto& [detector, budget] : detector_budget_pairs) { + solution.active_detectors.push_back(detector); + solution.detector_budgets.push_back(budget); + } + const auto stop_time = std::chrono::high_resolution_clock::now(); stats.lp_total_seconds += - std::chrono::duration_cast(stop_time - start_time).count() / - 1e6; + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; return solution; } -double TesseractFTLDecoder::project_from_exact_solution( - const ExactSubsetSolution& solution, const boost::dynamic_bitset<>& detectors, - const std::vector& detector_cost_tuples) { +double 
TesseractFTLDecoder::project_from_exact_solution(const ExactSubsetSolution& solution, + const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags) { stats.heuristic_calls++; stats.projection_heuristic_calls++; double total = 0.0; + size_t budget_pos = 0; for (size_t detector = detectors.find_first(); detector != boost::dynamic_bitset<>::npos; detector = detectors.find_next(detector)) { bool has_available = false; for (int ei : d2e[detector]) { - if (!detector_cost_tuples[(size_t)ei].error_blocked) { + if (!blocked_flags[(size_t)ei]) { has_available = true; break; } } if (!has_available) return INF_D; - total += lookup_detector_budget(solution, (int)detector); + + while (budget_pos < solution.active_detectors.size() && + solution.active_detectors[budget_pos] < (int)detector) { + ++budget_pos; + } + if (budget_pos < solution.active_detectors.size() && + solution.active_detectors[budget_pos] == (int)detector) { + total += solution.detector_budgets[budget_pos]; + } } return total; } @@ -817,6 +834,13 @@ void TesseractFTLDecoder::reset_decode_state() { } void TesseractFTLDecoder::decode_to_errors(const std::vector& detections) { + if (config.verbose) { + std::cout << "shot"; + for (const uint64_t& d : detections) { + std::cout << " D" << d; + } + std::cout << std::endl; + } if (plain_delegate) { plain_delegate->decode_to_errors(detections); predicted_errors_buffer = plain_delegate->predicted_errors_buffer; @@ -891,12 +915,11 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio } std::priority_queue, std::greater> pq; - std::unordered_map, DynamicBitsetHash>> - visited_detectors; + std::vector, DynamicBitsetHash>> visited_detectors( + num_detectors + 1); boost::dynamic_bitset<> initial_detectors(num_detectors, false); - std::vector initial_detector_cost_tuples(num_errors); + std::vector initial_blocked_flags(num_errors, 0); for (size_t detector : detections) { if (detector >= num_detectors) { throw std::runtime_error("Symptom 
references detector >= num_detectors"); @@ -905,7 +928,8 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio } size_t min_num_dets = detections.size(); - size_t max_num_dets = min_num_dets + detector_beam; + size_t max_num_dets = + detector_beam > num_detectors - min_num_dets ? num_detectors : min_num_dets + detector_beam; FTLNode root; root.g_cost = 0.0; @@ -916,7 +940,7 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio root.exact_solution_idx = -1; ExactSubsetSolution root_exact = - solve_exact_subset_lp(initial_detectors, initial_detector_cost_tuples, -1); + solve_exact_subset_lp(initial_detectors, initial_blocked_flags, -1); if (root_exact.value == INF_D) { low_confidence_flag = true; return; @@ -940,9 +964,8 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio if (node.num_dets > max_num_dets) continue; boost::dynamic_bitset<> detectors = initial_detectors; - std::vector detector_cost_tuples(num_errors); - flip_detectors_and_block_errors(detector_order, node.error_chain_idx, detectors, - detector_cost_tuples); + std::vector blocked_flags(num_errors, 0); + flip_detectors_and_block_errors(detector_order, node.error_chain_idx, detectors, blocked_flags); if (config.verbose) { const size_t projected_unrefined = @@ -953,9 +976,9 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio << " lp_reinserts=" << stats.lp_reinserts << " proj_generated=" << stats.projected_nodes_generated << " proj_refined=" << stats.projected_nodes_refined - << " proj_unrefined_so_far=" << projected_unrefined - << " num_dets=" << node.num_dets << " max_num_dets=" << max_num_dets - << " f=" << node.f_cost << " g=" << node.g_cost << " h=" << node.h_cost + << " proj_unrefined_so_far=" << projected_unrefined << " num_dets=" << node.num_dets + << " max_num_dets=" << max_num_dets << " f=" << node.f_cost << " g=" << node.g_cost + << " h=" << node.h_cost << " h_source=" << 
heuristic_source_to_string(node.heuristic_source) << " exact_refined=" << node.exact_refined << std::endl; } @@ -977,19 +1000,22 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio if (node.num_dets < min_num_dets) { min_num_dets = node.num_dets; + const size_t next_max_num_dets = detector_beam > num_detectors - min_num_dets + ? num_detectors + : min_num_dets + detector_beam; if (config.no_revisit_dets) { - for (size_t count = min_num_dets + detector_beam + 1; count <= max_num_dets; ++count) { + for (size_t count = next_max_num_dets + 1; count <= max_num_dets; ++count) { visited_detectors[count].clear(); } } - max_num_dets = std::min(max_num_dets, min_num_dets + detector_beam); + max_num_dets = std::min(max_num_dets, next_max_num_dets); } if (!node.exact_refined) { const double prev_h = node.h_cost; const FTLHeuristicSource prev_source = node.heuristic_source; ExactSubsetSolution exact_solution = - solve_exact_subset_lp(detectors, detector_cost_tuples, node.warm_solution_idx); + solve_exact_subset_lp(detectors, blocked_flags, node.warm_solution_idx); if (prev_source == FTLHeuristicSource::kProjected) stats.projected_nodes_refined++; if (exact_solution.value == INF_D) { if (config.verbose) { @@ -1040,10 +1066,7 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio } } - std::vector prefix_blocked(num_errors, 0); - for (size_t ei = 0; ei < num_errors; ++ei) { - prefix_blocked[ei] = detector_cost_tuples[ei].error_blocked; - } + std::vector prefix_blocked = blocked_flags; size_t children_generated = 0; size_t children_projected = 0; @@ -1052,24 +1075,25 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio for (int ei : d2e[min_detector]) { prefix_blocked[(size_t)ei] = 1; - if (detector_cost_tuples[(size_t)ei].error_blocked) continue; + if (blocked_flags[(size_t)ei]) continue; boost::dynamic_bitset<> child_detectors = detectors; - for (int detector : edets[(size_t)ei]) child_detectors[(size_t)detector] = 
!child_detectors[(size_t)detector]; - const size_t child_num_dets = child_detectors.count(); + size_t child_num_dets = node.num_dets; + for (int detector : edets[(size_t)ei]) { + if (detectors[(size_t)detector]) { + --child_num_dets; + } else { + ++child_num_dets; + } + child_detectors.flip((size_t)detector); + } if (child_num_dets > max_num_dets) { children_beam_pruned++; continue; } - std::vector child_detector_cost_tuples(num_errors); - for (size_t j = 0; j < num_errors; ++j) { - child_detector_cost_tuples[j].error_blocked = prefix_blocked[j]; - } - - const double child_h = - project_from_exact_solution(exact_solution_arena[(size_t)node.exact_solution_idx], - child_detectors, child_detector_cost_tuples); + const double child_h = project_from_exact_solution( + exact_solution_arena[(size_t)node.exact_solution_idx], child_detectors, prefix_blocked); stats.projected_nodes_generated++; children_projected++; if (child_h == INF_D) { @@ -1108,8 +1132,8 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio stats.projected_nodes_generated - stats.projected_nodes_refined; std::cout << " expanded children_generated=" << children_generated << " children_projected=" << children_projected - << " beam_pruned=" << children_beam_pruned - << " infeasible=" << children_infeasible << " lp_calls=" << stats.lp_calls + << " beam_pruned=" << children_beam_pruned << " infeasible=" << children_infeasible + << " lp_calls=" << stats.lp_calls << " proj_unrefined_so_far=" << projected_unrefined << std::endl; } } @@ -1124,7 +1148,7 @@ double TesseractFTLDecoder::cost_from_errors(const std::vector& predicte if (plain_delegate) return plain_delegate->cost_from_errors(predicted_errors); double total_cost = 0.0; for (size_t dem_error_index : predicted_errors) { - const size_t error_index = dem_error_to_error.at(dem_error_index); + const size_t error_index = dem_error_to_error[dem_error_index]; if (error_index == std::numeric_limits::max()) { throw std::invalid_argument("error 
index does not map to a retained decoder error"); } @@ -1136,23 +1160,21 @@ double TesseractFTLDecoder::cost_from_errors(const std::vector& predicte std::vector TesseractFTLDecoder::get_flipped_observables( const std::vector& predicted_errors) const { if (plain_delegate) return plain_delegate->get_flipped_observables(predicted_errors); - std::unordered_set flipped_observables_set; + std::vector toggled(num_observables, 0); for (size_t dem_error_index : predicted_errors) { - const size_t error_index = dem_error_to_error.at(dem_error_index); + const size_t error_index = dem_error_to_error[dem_error_index]; if (error_index == std::numeric_limits::max()) { throw std::invalid_argument("error index does not map to a retained decoder error"); } for (int obs_index : errors[error_index].symptom.observables) { - if (flipped_observables_set.count(obs_index)) { - flipped_observables_set.erase(obs_index); - } else { - flipped_observables_set.insert(obs_index); - } + toggled[(size_t)obs_index] ^= 1; } } - std::vector flipped_observables(flipped_observables_set.begin(), - flipped_observables_set.end()); - std::sort(flipped_observables.begin(), flipped_observables.end()); + std::vector flipped_observables; + flipped_observables.reserve(num_observables); + for (size_t obs_index = 0; obs_index < num_observables; ++obs_index) { + if (toggled[obs_index]) flipped_observables.push_back((int)obs_index); + } return flipped_observables; } diff --git a/src/tesseract_ftl.h b/src/tesseract_ftl.h index 9b05e43..c34eb09 100644 --- a/src/tesseract_ftl.h +++ b/src/tesseract_ftl.h @@ -135,10 +135,6 @@ struct TesseractFTLDecoder { bool operator>(const FTLNode& other) const; }; - struct DetectorCostTuple { - uint32_t error_blocked = 0; - }; - struct SingletonPatternConstraint { std::vector local_detectors; double rhs = 0.0; @@ -183,19 +179,18 @@ struct TesseractFTLDecoder { void flip_detectors_and_block_errors(size_t detector_order, int64_t error_chain_idx, boost::dynamic_bitset<>& detectors, - 
std::vector& detector_cost_tuples) const; + std::vector& blocked_flags) const; - SingletonBuildResult build_singleton_components( - const boost::dynamic_bitset<>& detectors, - const std::vector& detector_cost_tuples) const; + SingletonBuildResult build_singleton_components(const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags) const; ExactSubsetSolution solve_exact_subset_lp(const boost::dynamic_bitset<>& detectors, - const std::vector& detector_cost_tuples, + const std::vector& blocked_flags, int64_t warm_solution_idx); double project_from_exact_solution(const ExactSubsetSolution& solution, const boost::dynamic_bitset<>& detectors, - const std::vector& detector_cost_tuples); + const std::vector& blocked_flags); void reset_decode_state(); }; diff --git a/src/tesseract_ftl_main.cc b/src/tesseract_ftl_main.cc index b0876e9..bf129e1 100644 --- a/src/tesseract_ftl_main.cc +++ b/src/tesseract_ftl_main.cc @@ -291,7 +291,9 @@ int main(int argc, char* argv[]) { .default_value(static_cast(std::random_device()())) .store_into(args.sample_seed); - program.add_argument("--shot-range-begin").default_value(size_t(0)).store_into(args.shot_range_begin); + program.add_argument("--shot-range-begin") + .default_value(size_t(0)) + .store_into(args.shot_range_begin); program.add_argument("--shot-range-end").default_value(size_t(0)).store_into(args.shot_range_end); program.add_argument("--in").default_value(std::string("")).store_into(args.in_fname); @@ -315,22 +317,18 @@ int main(int argc, char* argv[]) { program.add_argument("--obs-in-format", "--obs_in_format") .default_value(std::string("")) .store_into(args.obs_in_format); - program.add_argument("--out").default_value(std::string("")) - .store_into(args.out_fname); - program.add_argument("--out-format").default_value(std::string("")) - .store_into(args.out_format); - program.add_argument("--dem-out").default_value(std::string("")) - .store_into(args.dem_out_fname); - 
program.add_argument("--stats-out").default_value(std::string("")) + program.add_argument("--out").default_value(std::string("")).store_into(args.out_fname); + program.add_argument("--out-format").default_value(std::string("")).store_into(args.out_format); + program.add_argument("--dem-out").default_value(std::string("")).store_into(args.dem_out_fname); + program.add_argument("--stats-out") + .default_value(std::string("")) .store_into(args.stats_out_fname); program.add_argument("--threads") .default_value(size_t( std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) .store_into(args.num_threads); - program.add_argument("--beam") - .default_value(INF_DET_BEAM) - .store_into(args.det_beam); + program.add_argument("--beam").default_value(INF_DET_BEAM).store_into(args.det_beam); program.add_argument("--det-penalty").default_value(0.0).store_into(args.det_penalty); program.add_argument("--beam-climbing").flag().store_into(args.beam_climbing); program.add_argument("--no-revisit-dets").flag().store_into(args.no_revisit_dets); From 175b4a71f73076b155b669c77b7daec511cd172f Mon Sep 17 00:00:00 2001 From: noajshu Date: Mon, 30 Mar 2026 22:28:35 +0000 Subject: [PATCH 05/25] add more lp prototypes to src/py/astar with additional dual constraints included --- ...eedy_heuristics_plus_inactive_lift_lazy.py | 1207 ++++++++++++++ src/py/astar/astar_qec_inactive_lp.py | 1399 +++++++++++++++++ 2 files changed, 2606 insertions(+) create mode 100644 src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py create mode 100644 src/py/astar/astar_qec_inactive_lp.py diff --git a/src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py b/src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py new file mode 100644 index 0000000..9674b48 --- /dev/null +++ b/src/py/astar/astar_prototype_singleton_greedy_heuristics_plus_inactive_lift_lazy.py @@ -0,0 +1,1207 @@ +#!/usr/bin/env python3 
+"""Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics. + +This version keeps the same Stim-facing API as the earlier greedy prototype but +adds lazy reinsertion / parent-y projection, in the same spirit as the lazy +optimal-singleton prototype: + + * nodes are seeded with a cheap feasible lower bound; + * when a node is popped, the selected heuristic is evaluated on that node; + * if the refined heuristic raises the node key, the node is reinserted; + * expanded nodes project their current feasible y-prices onto children; + * optionally, the projected child bound is maxed with plain detcost. + +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + lifted_sweep 1-pass inactive bounds transferred to plain_sweep + lifted_exact_lp 1-pass inactive bounds transferred to exact_lp + +When --lazy-reinsert-heuristics is enabled (the default), the root is seeded by +plain detcost and only popped nodes are refined with the selected heuristic. +This works for all of the above heuristics because each returns a feasible +singleton-budget vector y, and projecting that y to a child by keeping prices +on detectors that remain active and zeroing newly active detectors is still a +feasible child singleton-budget point. 
+""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] + + +@dataclass +class SupportData: + active_detectors: List[int] + supports: List[Tuple[Tuple[int, ...], float]] + incident: Dict[int, List[int]] + + +@dataclass +class SearchState: + errs: np.ndarray + blocked_errs: np.ndarray + dets: np.ndarray + det_counts: np.ndarray + g_cost: float + h_cost: float + h_source: str + refined: bool + y_prices: Optional[np.ndarray] + + +@dataclass +class DecodeResult: + success: bool + errs: np.ndarray + residual_dets: np.ndarray + cost: float + nodes_pushed: int + nodes_popped: int + max_queue_size: int + heuristic_calls: int + plain_heuristic_calls: int + projection_heuristic_calls: int + refinement_calls: int + lp_calls: int + reinserts: int + projected_nodes_generated: int + projected_nodes_refined: int + projected_nodes_unrefined_at_finish: int + total_refinement_gain: float + max_refinement_gain: float + elapsed_seconds: float + + +class UnionFind: + def __init__(self, n: int) -> None: + self.parent = list(range(n)) + self.rank = [0] * n + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] + x = self.parent[x] + return x + + def union(self, a: int, b: int) -> None: + ra = self.find(a) + rb = self.find(b) + if ra == rb: + return + if self.rank[ra] < self.rank[rb]: + self.parent[ra] = rb + elif self.rank[ra] > self.rank[rb]: + self.parent[rb] = ra + else: + self.parent[rb] = ra + self.rank[ra] += 1 + + 
+def xor_probability(p0: float, p1: float) -> float: + return p0 * (1.0 - p1) + (1.0 - p0) * p1 + + +def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]: + for instruction in dem.flattened(): + if instruction.type != "error": + continue + probability = float(instruction.args_copy()[0]) + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Expected flattened error probabilities in (0, 0.5), got {probability}." + ) + + detectors: set[int] = set() + observables: set[int] = set() + for target in instruction.targets_copy(): + if target.is_separator(): + continue + if target.is_logical_observable_id(): + if target.val in observables: + observables.remove(target.val) + else: + observables.add(target.val) + else: + if not target.is_relative_detector_id(): + raise ValueError(f"Unexpected DEM target: {target!r}") + if target.val in detectors: + detectors.remove(target.val) + else: + detectors.add(target.val) + + yield ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=tuple(sorted(detectors)), + observables=tuple(sorted(observables)), + ) + + +def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]: + errors_by_symptom: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {} + for error in iter_dem_errors_from_dem(dem): + key = (error.detectors, error.observables) + p_old = errors_by_symptom.get(key) + if p_old is None: + p_new = error.probability + else: + p_new = xor_probability(p_old, error.probability) + errors_by_symptom[key] = p_new + + merged: List[ErrorRecord] = [] + for (detectors, observables), probability in errors_by_symptom.items(): + if probability <= 0: + continue + if probability >= 0.5: + raise ValueError( + f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost." 
+ ) + merged.append( + ErrorRecord( + probability=probability, + likelihood_cost=-math.log(probability / (1.0 - probability)), + detectors=detectors, + observables=observables, + ) + ) + return merged + + +class GreedySingletonHeuristicDecoder: + def __init__( + self, + errors: Sequence[ErrorRecord], + num_detectors: int, + num_observables: int, + *, + heuristic: str = "best_of_two", + respect_blocked_errors_in_heuristic: bool = True, + lazy_reinsert_heuristics: bool = True, + projection_combine_max_plain: bool = True, + verbose_search: bool = False, + ) -> None: + self.errors = list(errors) + self.num_errors = len(self.errors) + self.num_detectors = int(num_detectors) + self.num_observables = int(num_observables) + self.heuristic_name = heuristic + self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic + self.lazy_reinsert_heuristics = lazy_reinsert_heuristics + self.projection_combine_max_plain = projection_combine_max_plain + self.verbose_search = verbose_search + + self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64) + self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64) + self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors] + self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors] + + d2e_lists: List[List[int]] = [[] for _ in range(self.num_detectors)] + for ei, dets in enumerate(self.error_detectors): + for d in dets: + d2e_lists[d].append(ei) + self.d2e: List[np.ndarray] = [np.array(v, dtype=np.int32) for v in d2e_lists] + + self.reset_stats() + + def reset_stats(self) -> None: + self.heuristic_calls = 0 + self.plain_heuristic_calls = 0 + self.projection_heuristic_calls = 0 + self.refinement_calls = 0 + self.lp_calls = 0 + self.reinserts = 0 + self.projected_nodes_generated = 0 + self.projected_nodes_refined = 0 + self.total_refinement_gain = 0.0 + self.max_refinement_gain = 0.0 + + 
@property
+    def mode_name(self) -> str:
+        if self.heuristic_name == "plain":
+            return "plain"
+        if self.lazy_reinsert_heuristics:
+            suffix = "-lazy-projection"
+            if self.projection_combine_max_plain:
+                suffix += "-maxplain"
+            return f"{self.heuristic_name}{suffix}"
+        return self.heuristic_name
+
+    def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray:
+        available = ~errs
+        if self.respect_blocked_errors_in_heuristic:
+            available &= ~blocked_errs
+        return available
+
+    def _has_cover_for_all_active_detectors(self, dets: np.ndarray, available_errs: np.ndarray) -> bool:
+        for d in np.flatnonzero(dets):
+            found = False
+            for ei in self.d2e[int(d)]:
+                if available_errs[int(ei)]:
+                    found = True
+                    break
+            if not found:
+                return False
+        return True
+
+    def _apply_inactive_lift(self, y: np.ndarray, active_dets: np.ndarray, available_errs: np.ndarray) -> float:
+        """Modifies y IN PLACE to increase values using a post-processing dual slack transfer."""
+        slacks = np.zeros(self.num_errors, dtype=np.float64)
+        available_indices = np.flatnonzero(available_errs)
+
+        # Calculate initial slacks
+        for ei in available_indices:
+            ei = int(ei)
+            slacks[ei] = self.weights[ei]
+            for d in self.error_detectors[ei]:
+                if active_dets[d]:
+                    slacks[ei] -= y[d]
+
+        active_list = np.flatnonzero(active_dets)
+        # Sort descending to attack the heaviest y-values first
+        order = sorted((int(d) for d in active_list), key=lambda d: y[d], reverse=True)
+
+        for d in order:
+            incident_eis = [int(ei) for ei in self.d2e[d] if available_errs[ei]]
+            if not incident_eis:
+                continue
+
+            min_s = min(slacks[ei] for ei in incident_eis)
+
+            # If the detector is bottlenecked, try to transfer slack from inactive neighbors
+            if min_s < 1e-9:
+                blocking_eis = [ei for ei in incident_eis if slacks[ei] < 1e-9]
+                for ei in blocking_eis:
+                    for d_inact in self.error_detectors[ei]:
+                        if not active_dets[d_inact]:
+                            siblings = [int(j) for j in self.d2e[d_inact] if available_errs[j] and j != ei] 
+ if not siblings: + continue + delta = min(slacks[j] for j in siblings) + if delta > 1e-9: + # Execute the transfer + slacks[ei] += delta + for j in siblings: + slacks[j] -= delta + break + + # Re-evaluate the bottleneck and lift if space was created + new_min_s = min(slacks[ei] for ei in incident_eis) + if new_min_s > 1e-9: + y[d] += new_min_s + for ei in incident_eis: + slacks[ei] -= new_min_s + + return float(y[active_dets].sum()) + + def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData: + active_list = sorted(map(int, np.flatnonzero(active_dets))) + incident: Dict[int, List[int]] = {d: [] for d in active_list} + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + support = tuple(d for d in self.error_detectors[ei] if active_dets[d]) + if not support: + continue + weight = float(self.weights[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + supports = list(support_to_weight.items()) + for i, (support, _weight) in enumerate(supports): + for d in support: + if d in incident: + incident[d].append(i) + + return SupportData(active_detectors=active_list, supports=supports, incident=incident) + + def _check_coverage(self, support_data: SupportData) -> bool: + return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors) + + def plain_detcost_from_counts( + self, + dets: np.ndarray, + available_errs: np.ndarray, + det_counts: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + active = np.flatnonzero(dets) + if active.size == 0: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for d in active: + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + 
assert count > 0 + value = self.weights[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF, None + y[int(d)] = best + total += best + return total, y + + def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + y = np.zeros(self.num_detectors, dtype=np.float64) + for d in support_data.active_detectors: + best = INF + for i in support_data.incident[d]: + support, weight = support_data.supports[i] + best = min(best, weight / len(support)) + y[d] = best + return float(y[support_data.active_detectors].sum()), y + + def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64) + y = np.zeros(self.num_detectors, dtype=np.float64) + + if order_kind == "asc_deg": + order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d)) + elif order_kind == "desc_plain": + _plain_value, y_plain = self.heuristic_plain(support_data) + if y_plain is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), reverse=True) + else: + raise ValueError(f"Unknown order_kind={order_kind!r}") + + for d in order: + value = min(slack[i] for i in support_data.incident[d]) + if value < 0: + value = 0.0 + y[d] = value + for i in support_data.incident[d]: + slack[i] -= value + return float(y[support_data.active_detectors].sum()), y + + def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + plain_value, y = 
self.heuristic_plain(support_data) + if y is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True) + for d in order: + max_feasible = min( + weight - sum(y[dd] for dd in support if dd != d) + for support, weight in support_data.supports + if d in support + ) + if max_feasible > y[d]: + y[d] = max_feasible + return float(y[support_data.active_detectors].sum()), y + + def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + active = support_data.active_detectors + if not active: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + detector_index = {d: i for i, d in enumerate(active)} + uf = UnionFind(len(active)) + for support, _weight in support_data.supports: + if len(support) > 1: + a = detector_index[support[0]] + for d in support[1:]: + uf.union(a, detector_index[d]) + + components: Dict[int, List[int]] = defaultdict(list) + for d in active: + components[uf.find(detector_index[d])].append(d) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for component in components.values(): + component_set = set(component) + local = {d: i for i, d in enumerate(sorted(component))} + component_supports: List[Tuple[Tuple[int, ...], float]] = [] + for support, weight in support_data.supports: + if support[0] in component_set: + component_supports.append((tuple(local[d] for d in support), weight)) + + rows: List[int] = [] + cols: List[int] = [] + data: List[float] = [] + rhs: List[float] = [] + for r, (support, weight) in enumerate(component_supports): + rhs.append(weight) + for c in support: + rows.append(r) + cols.append(c) + data.append(1.0) + + a_ub = csr_matrix( + (data, (rows, cols)), + shape=(len(component_supports), len(component)), + dtype=np.float64, + ) + self.lp_calls += 1 + result = linprog( + c=-np.ones(len(component), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, 
dtype=np.float64), + bounds=[(0.0, None)] * len(component), + method="highs", + ) + if not result.success: + return INF, None + total += -float(result.fun) + for d, value in zip(sorted(component), result.x): + y[d] = float(value) + return float(total), y + + def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]: + if name == "plain": + return self.heuristic_plain(support_data) + if name == "asc_deg": + return self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if name == "desc_plain": + return self.heuristic_saturation_zero(support_data, order_kind="desc_plain") + if name == "plain_sweep": + return self.heuristic_plain_sweep(support_data) + if name == "best_of_two": + v1, y1 = self.heuristic_plain_sweep(support_data) + v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if v1 >= v2: + return v1, y1 + return v2, y2 + if name == "best_of_three": + candidates = [ + self.heuristic_plain_sweep(support_data), + self.heuristic_saturation_zero(support_data, order_kind="asc_deg"), + self.heuristic_saturation_zero(support_data, order_kind="desc_plain"), + ] + return max(candidates, key=lambda t: t[0]) + if name == "exact_lp": + return self.heuristic_exact_lp(support_data) + raise ValueError(f"Unknown heuristic {name!r}") + + def compute_support_based_heuristic( + self, + dets: np.ndarray, + errs: np.ndarray, + blocked_errs: np.ndarray, + *, + name: Optional[str] = None, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + h_name = name or self.heuristic_name + + support_data = self.build_support_data(dets, available) + + if h_name in {"lifted_sweep", "lifted_exact_lp"}: + if h_name == "lifted_sweep": + safe_val, safe_y = self.heuristic_plain_sweep(support_data) + else: + safe_val, safe_y = self.heuristic_exact_lp(support_data) + + if safe_y is None: + return INF, None + + lifted_val = 
self._apply_inactive_lift(safe_y.copy(), dets, available) + # Return max to guarantee it is >= base heuristic, and safe_y to keep projection valid + return max(lifted_val, safe_val), safe_y + + return self.evaluate_named_heuristic(support_data, h_name) + + def project_child_y( + self, + parent_state: SearchState, + child_dets: np.ndarray, + child_errs: np.ndarray, + child_blocked_errs: np.ndarray, + child_det_counts: np.ndarray, + flipped_detectors: Sequence[int], + ) -> Tuple[float, Optional[np.ndarray], str]: + if parent_state.y_prices is None: + raise AssertionError("Expected a stored feasible y vector before projecting to a child.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + available = self._available_errors(child_errs, child_blocked_errs) + if not self._has_cover_for_all_active_detectors(child_dets, available): + return INF, None, "projected" + + y_projected = np.zeros(self.num_detectors, dtype=np.float64) + keep = parent_state.dets & child_dets + y_projected[keep] = parent_state.y_prices[keep] + projected_value = float(y_projected[np.flatnonzero(child_dets)].sum()) + best_value = projected_value + best_y = y_projected + best_source = "projected" + + if self.projection_combine_max_plain: + plain_value, plain_y = self.plain_detcost_from_counts(child_dets, available, child_det_counts) + if plain_y is None: + return INF, None, "plain" + if plain_value > best_value + HEURISTIC_EPS: + best_value = plain_value + best_y = plain_y + best_source = "plain" + + return best_value, best_y, best_source + + def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, available) + names = ["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp", "lifted_sweep", "lifted_exact_lp"] + out: List[Tuple[str, float]] = [] + saved_lp_calls = self.lp_calls + 
+ for name in names: + if name in {"lifted_sweep", "lifted_exact_lp"}: + if name == "lifted_sweep": + safe_val, safe_y = self.heuristic_plain_sweep(support_data) + else: + safe_val, safe_y = self.heuristic_exact_lp(support_data) + if safe_y is None: + out.append((name, INF)) + else: + lifted_val = self._apply_inactive_lift(safe_y.copy(), dets, available) + out.append((name, max(lifted_val, safe_val))) + else: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + + self.lp_calls = saved_lp_calls + return out + + def _maybe_refine_node(self, state: SearchState) -> Tuple[SearchState, bool]: + if state.refined or self.heuristic_name == "plain" or not self.lazy_reinsert_heuristics: + return state, False + + previous_source = state.h_source + self.refinement_calls += 1 + new_value, new_y = self.compute_support_based_heuristic( + state.dets, + state.errs, + state.blocked_errs, + name=self.heuristic_name, + ) + if new_y is None: + if previous_source == "projected": + self.projected_nodes_refined += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost:.6f} new_h=INF delta=INF reinserted=False discarded=True" + ) + state.refined = True + return state, True + + delta = new_value - state.h_cost + self.total_refinement_gain += max(0.0, delta) + self.max_refinement_gain = max(self.max_refinement_gain, max(0.0, delta)) + + if self.heuristic_name in {"exact_lp", "lifted_exact_lp"} and new_value + 1e-7 < state.h_cost: + raise AssertionError( + f"Exact LP value {new_value} below stored projected value {state.h_cost}." 
+ ) + + if new_value > state.h_cost + HEURISTIC_EPS: + if previous_source == "projected": + self.projected_nodes_refined += 1 + state.h_cost = new_value + state.h_source = "refined" + state.y_prices = new_y + state.refined = True + self.reinserts += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost - delta:.6f} new_h={new_value:.6f} delta={delta:.6f} reinserted=True discarded=False" + ) + return state, True + + if previous_source == "projected": + self.projected_nodes_refined += 1 + if abs(new_value - state.h_cost) <= HEURISTIC_EPS: + state.y_prices = new_y + state.refined = True + if self.verbose_search: + new_text = "INF" if math.isinf(new_value) else f"{new_value:.6f}" + print( + f" refine approx_h={state.h_cost:.6f} new_h={new_text} delta={delta:.6f} reinserted=False discarded=False" + ) + return state, False + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_h, root_y = self.plain_detcost_from_counts(dets0, self._available_errors(errs0, blocked0), det_counts0) + if root_y is None or math.isinf(root_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + 
projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + root_refined = (self.heuristic_name == "plain") or (not self.lazy_reinsert_heuristics) + if root_refined and self.heuristic_name != "plain": + eager_h, eager_y = self.compute_support_based_heuristic(dets0, errs0, blocked0, name=self.heuristic_name) + if eager_y is None or math.isinf(eager_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + root_h, root_y = eager_h, eager_y + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=root_h, + h_source="plain" if not root_refined else ("plain" if self.heuristic_name == "plain" else "refined"), + refined=root_refined, + y_prices=root_y, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + heapq.heappush(heap, (root_state.g_cost + root_state.h_cost, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(dets0.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + 
f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} reinserts={self.reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} refined={state.refined}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + state, should_reinsert = self._maybe_refine_node(state) + if should_reinsert: + if state.y_prices is None or math.isinf(state.h_cost): + if state.h_source == "projected": + self.projected_nodes_refined += 1 + continue + if state.h_source != "plain": + heapq.heappush(heap, (state.g_cost + state.h_cost, num_dets, 
counter, state)) + counter += 1 + continue + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + for d in self.error_detectors[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.weights[ei]) + if self.heuristic_name == "plain" or (not self.lazy_reinsert_heuristics): + child_h, child_y = self.compute_support_based_heuristic( + child_dets, child_errs, child_blocked, name=self.heuristic_name + ) + child_source = "plain" if self.heuristic_name == "plain" else "refined" + child_refined = True + else: + if state.y_prices is None: + raise AssertionError("Expected parent feasible y-prices before projecting to child.") + child_h, child_y, child_source = self.project_child_y( + state, + child_dets, + child_errs, + child_blocked, + child_det_counts, + self.error_detectors[ei], + ) + self.projected_nodes_generated += 1 + children_projected += 1 + child_refined = False + + if child_y is None or math.isinf(child_h): + children_infeasible += 1 + continue + + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_source, + refined=child_refined, + y_prices=child_y, + ) + heapq.heappush(heap, 
(child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible} " + f"lp_calls={self.lp_calls} proj_unrefined_so_far={projected_unrefined}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), 
dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--dets", + type=str, + default=None, + help="String of shot dets (e.g., 'shot D0 D1 L2') to parse instead of sampling.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=["plain", "asc_deg", "desc_plain", "plain_sweep", "best_of_two", "best_of_three", "exact_lp", "lifted_sweep", "lifted_exact_lp"], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--lazy-reinsert-heuristics", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For non-plain heuristics, seed nodes with plain detcost, refine on pop, and reinsert when the selected " + "heuristic improves the key (default: enabled)." 
+ ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="When projecting parent y-prices to a child, take max(projected, plain detcost) (default: enabled).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = circuit.detector_error_model(decompose_errors=False) + errors = 
merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + if args.dets is not None: + shot_dets = np.zeros(dem.num_detectors, dtype=bool) + shot_obs = np.zeros(dem.num_observables, dtype=bool) + for token in args.dets.split(): + if token == "shot": + continue + if token.startswith("D") and token[1:].isdigit(): + d_idx = int(token[1:]) + if d_idx < dem.num_detectors: + shot_dets[d_idx] = True + elif token.startswith("L") and token[1:].isdigit(): + l_idx = int(token[1:]) + if l_idx < dem.num_observables: + shot_obs[l_idx] = True + else: + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + lazy_reinsert_heuristics=args.lazy_reinsert_heuristics, + projection_combine_max_plain=args.projection_combine_max_plain, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"mode = {decoder.mode_name}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + print(f"lazy_reinsert_heuristics = {args.lazy_reinsert_heuristics}") + print(f"projection_combine_max_plain = {args.projection_combine_max_plain}") + + if args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + 
print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs = np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>12s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"max_queue_size = {result.max_queue_size}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"plain_heuristic_calls = {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.projection_heuristic_calls}") + print(f"refinement_calls = {result.refinement_calls}") + print(f"lp_calls = {result.lp_calls}") + print(f"reinserts = {result.reinserts}") + print(f"projected_nodes_generated = {result.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.projected_nodes_unrefined_at_finish}") + print(f"total_refinement_gain = {result.total_refinement_gain:.6f}") + print(f"max_refinement_gain = {result.max_refinement_gain:.6f}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + print("decoded_error_indices =", " ".join(map(str, 
np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/py/astar/astar_qec_inactive_lp.py b/src/py/astar/astar_qec_inactive_lp.py new file mode 100644 index 0000000..d9e2c36 --- /dev/null +++ b/src/py/astar/astar_qec_inactive_lp.py @@ -0,0 +1,1399 @@ +#!/usr/bin/env python3 +"""Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics. + +This version keeps the same Stim-facing API as the earlier greedy prototype but +adds lazy reinsertion / parent-y projection, in the same spirit as the lazy +optimal-singleton prototype: + + * nodes are seeded with a cheap feasible lower bound; + * when a node is popped, the selected heuristic is evaluated on that node; + * if the refined heuristic raises the node key, the node is reinserted; + * expanded nodes project their current feasible y-prices onto children; + * optionally, the projected child bound is maxed with plain detcost. 
+ +Supported heuristic choices: + plain original detector-wise feasible point + asc_deg zero-start saturation ordered by ascending detector degree + desc_plain zero-start saturation ordered by descending plain y_d + plain_sweep start from plain, then one descending saturation sweep + best_of_two max(plain_sweep, asc_deg) + best_of_three max(plain_sweep, asc_deg, desc_plain) + exact_lp exact optimal singleton LP lower bound + exact_lp_plus_inactive + exact LP lower bound with extra inactive-detector no-one-hot constraints + +When --lazy-reinsert-heuristics is enabled (the default), the root is seeded by +plain detcost and only popped nodes are refined with the selected heuristic. +This works directly for the support-only heuristics because each returns a +feasible singleton-budget vector y, and projecting that y to a child by +keeping prices on detectors that remain active and zeroing newly active +detectors is still a feasible child singleton-budget point. For +exact_lp_plus_inactive, the refined LP optimum is not directly projectable to a +child, so lazy mode keeps the current projectable singleton-budget prices for +child seeding and uses the tightened LP only when refining popped nodes. +""" + +from __future__ import annotations + +import argparse +import heapq +import math +import time +from collections import defaultdict, deque +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import stim +from scipy.optimize import linprog +from scipy.sparse import csr_matrix + +INF = float("inf") +HEURISTIC_EPS = 1e-9 + + +@dataclass(frozen=True) +class ErrorRecord: + probability: float + likelihood_cost: float + detectors: Tuple[int, ...] + observables: Tuple[int, ...] 
@dataclass
class SupportData:
    """Active-detector projection of the currently available errors.

    ``supports`` lists (active-detector tuple, cheapest weight) pairs and
    ``incident`` maps each active detector to the indices of the supports
    that touch it.
    """

    active_detectors: List[int]
    supports: List[Tuple[Tuple[int, ...], float]]
    incident: Dict[int, List[int]]


@dataclass
class SearchState:
    """One A* node: chosen errors, precedence blocks, residual detectors."""

    errs: np.ndarray            # bool mask of errors already applied
    blocked_errs: np.ndarray    # bool mask of errors pruned by precedence
    dets: np.ndarray            # bool mask of still-active detectors
    det_counts: np.ndarray      # per-error count of active detectors it touches
    g_cost: float               # accumulated likelihood cost of `errs`
    h_cost: float               # current heuristic lower bound
    h_source: str               # provenance of h_cost (e.g. "plain"/"projected"/"refined")
    refined: bool               # True once the selected heuristic ran on this node
    y_prices: Optional[np.ndarray]  # feasible singleton-budget prices, when projectable


@dataclass
class DecodeResult:
    """Outcome of one decode call plus search instrumentation counters."""

    success: bool
    errs: np.ndarray            # decoded error mask (all False on failure)
    residual_dets: np.ndarray   # detectors left unexplained
    cost: float                 # total likelihood cost of the decoded errors
    nodes_pushed: int
    nodes_popped: int
    max_queue_size: int
    heuristic_calls: int
    plain_heuristic_calls: int
    projection_heuristic_calls: int
    refinement_calls: int
    lp_calls: int
    reinserts: int
    projected_nodes_generated: int
    projected_nodes_refined: int
    projected_nodes_unrefined_at_finish: int
    total_refinement_gain: float
    max_refinement_gain: float
    elapsed_seconds: float


class UnionFind:
    """Disjoint-set forest with union by rank and path halving."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        """Return the representative of x's set, halving the path on the way up."""
        parent = self.parent
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(self, a: int, b: int) -> None:
        """Merge the sets containing a and b (no-op if already joined)."""
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            return
        # Attach the shallower tree beneath the deeper one; bump rank on ties.
        if self.rank[root_a] < self.rank[root_b]:
            root_a, root_b = root_b, root_a
        self.parent[root_b] = root_a
        if self.rank[root_a] == self.rank[root_b]:
            self.rank[root_a] += 1


def xor_probability(p0: float, p1: float) -> float:
    """Return the probability that exactly one of two independent events fires."""
    only_first = p0 * (1.0 - p1)
    only_second = (1.0 - p0) * p1
    return only_first + only_second


def iter_dem_errors_from_dem(dem: stim.DetectorErrorModel) -> Iterable[ErrorRecord]:
    """Yield one ErrorRecord per flattened ``error`` instruction in the DEM.

    Repeated detector/observable targets within a single instruction cancel
    pairwise (parity semantics). Zero-probability errors are dropped, and
    probabilities >= 0.5 are rejected because they admit no positive cost.
    """
    for instruction in dem.flattened():
        if instruction.type != "error":
            continue
        probability = float(instruction.args_copy()[0])
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Expected flattened error probabilities in (0, 0.5), got {probability}."
            )

        detector_parity: set[int] = set()
        observable_parity: set[int] = set()
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_logical_observable_id():
                # Symmetric difference toggles membership: duplicates cancel.
                observable_parity ^= {target.val}
            elif target.is_relative_detector_id():
                detector_parity ^= {target.val}
            else:
                raise ValueError(f"Unexpected DEM target: {target!r}")

        yield ErrorRecord(
            probability=probability,
            likelihood_cost=-math.log(probability / (1.0 - probability)),
            detectors=tuple(sorted(detector_parity)),
            observables=tuple(sorted(observable_parity)),
        )


def merged_errors_from_dem(dem: stim.DetectorErrorModel) -> List[ErrorRecord]:
    """Combine DEM errors that share the same (detectors, observables) symptom.

    Probabilities of mechanisms with identical symptoms are XOR-combined.
    Symptoms whose merged probability collapses to <= 0 are dropped; merged
    probabilities >= 0.5 are rejected since they admit no positive cost.
    """
    merged_probability: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], float] = {}
    for error in iter_dem_errors_from_dem(dem):
        key = (error.detectors, error.observables)
        if key in merged_probability:
            merged_probability[key] = xor_probability(merged_probability[key], error.probability)
        else:
            merged_probability[key] = error.probability

    records: List[ErrorRecord] = []
    for (detectors, observables), probability in merged_probability.items():
        if probability <= 0:
            continue
        if probability >= 0.5:
            raise ValueError(
                f"Merged error has probability >= 0.5 ({probability}); cannot assign positive cost."
            )
        records.append(
            ErrorRecord(
                probability=probability,
                likelihood_cost=-math.log(probability / (1.0 - probability)),
                detectors=detectors,
                observables=observables,
            )
        )
    return records


class GreedySingletonHeuristicDecoder:
    """A* decoder whose admissible heuristics are singleton-budget prices."""

    def __init__(
        self,
        errors: Sequence[ErrorRecord],
        num_detectors: int,
        num_observables: int,
        *,
        heuristic: str = "best_of_two",
        respect_blocked_errors_in_heuristic: bool = True,
        lazy_reinsert_heuristics: bool = True,
        projection_combine_max_plain: bool = True,
        verbose_search: bool = False,
    ) -> None:
        """Build the per-error arrays and detector incidence used by the search.

        Args:
            errors: merged error mechanisms (probability, cost, detectors, observables).
            num_detectors: total detector count of the underlying circuit.
            num_observables: total observable count of the underlying circuit.
            heuristic: name of the lower-bound heuristic to use when refining.
            respect_blocked_errors_in_heuristic: exclude precedence-blocked errors
                from heuristic evaluation when True.
            lazy_reinsert_heuristics: seed nodes cheaply and refine only on pop.
            projection_combine_max_plain: max projected child bounds with plain detcost.
            verbose_search: print per-node trace lines during decode.
        """
        self.errors = list(errors)
        self.num_errors = len(self.errors)
        self.num_detectors = int(num_detectors)
        self.num_observables = int(num_observables)
        self.heuristic_name = heuristic
        self.respect_blocked_errors_in_heuristic = respect_blocked_errors_in_heuristic
        self.lazy_reinsert_heuristics = lazy_reinsert_heuristics
        self.projection_combine_max_plain = projection_combine_max_plain
        self.verbose_search = verbose_search

        # Dense per-error arrays for fast access during the search.
        self.probabilities = np.array([err.probability for err in self.errors], dtype=np.float64)
        self.weights = np.array([err.likelihood_cost for err in self.errors], dtype=np.float64)
        self.error_detectors: List[Tuple[int, ...]] = [tuple(err.detectors) for err in self.errors]
        self.error_observables: List[Tuple[int, ...]] = [tuple(err.observables) for err in self.errors]

        # Detector -> incident error indices, stored as one int32 array per detector.
        incidence: List[List[int]] = [[] for _ in range(self.num_detectors)]
        for error_index, dets in enumerate(self.error_detectors):
            for det in dets:
                incidence[det].append(error_index)
        self.d2e: List[np.ndarray] = [np.array(ixs, dtype=np.int32) for ixs in incidence]

        self.reset_stats()

    def reset_stats(self) -> None:
        """Zero all per-decode instrumentation counters."""
        for counter_name in (
            "heuristic_calls",
            "plain_heuristic_calls",
            "projection_heuristic_calls",
            "refinement_calls",
            "lp_calls",
            "reinserts",
            "projected_nodes_generated",
            "projected_nodes_refined",
        ):
            setattr(self, counter_name, 0)
        self.total_refinement_gain = 0.0
        self.max_refinement_gain = 0.0
@property + def mode_name(self) -> str: + if self.heuristic_name == "plain": + return "plain" + if self.lazy_reinsert_heuristics: + suffix = "-lazy-projection" + if self.projection_combine_max_plain: + suffix += "-maxplain" + return f"{self.heuristic_name}{suffix}" + return self.heuristic_name + + @staticmethod + def heuristic_has_projectable_prices(name: str) -> bool: + return name != "exact_lp_plus_inactive" + + def _available_errors(self, errs: np.ndarray, blocked_errs: np.ndarray) -> np.ndarray: + available = ~errs + if self.respect_blocked_errors_in_heuristic: + available &= ~blocked_errs + return available + + def _has_cover_for_all_active_detectors(self, dets: np.ndarray, available_errs: np.ndarray) -> bool: + for d in np.flatnonzero(dets): + found = False + for ei in self.d2e[int(d)]: + if available_errs[int(ei)]: + found = True + break + if not found: + return False + return True + + def build_support_data(self, active_dets: np.ndarray, available_errs: np.ndarray) -> SupportData: + active_list = sorted(map(int, np.flatnonzero(active_dets))) + incident: Dict[int, List[int]] = {d: [] for d in active_list} + support_to_weight: Dict[Tuple[int, ...], float] = {} + + for ei in np.flatnonzero(available_errs): + ei = int(ei) + support = tuple(d for d in self.error_detectors[ei] if active_dets[d]) + if not support: + continue + weight = float(self.weights[ei]) + old = support_to_weight.get(support) + if old is None or weight < old: + support_to_weight[support] = weight + + supports = list(support_to_weight.items()) + for i, (support, _weight) in enumerate(supports): + for d in support: + if d in incident: + incident[d].append(i) + + return SupportData(active_detectors=active_list, supports=supports, incident=incident) + + def _check_coverage(self, support_data: SupportData) -> bool: + return all(len(support_data.incident[d]) > 0 for d in support_data.active_detectors) + + def plain_detcost_from_counts( + self, + dets: np.ndarray, + available_errs: np.ndarray, + 
det_counts: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + self.plain_heuristic_calls += 1 + active = np.flatnonzero(dets) + if active.size == 0: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for d in active: + best = INF + for ei in self.d2e[int(d)]: + ei = int(ei) + if not available_errs[ei]: + continue + count = int(det_counts[ei]) + assert count > 0 + value = self.weights[ei] / count + if value < best: + best = value + if math.isinf(best): + return INF, None + y[int(d)] = best + total += best + return total, y + + def heuristic_plain(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + y = np.zeros(self.num_detectors, dtype=np.float64) + for d in support_data.active_detectors: + best = INF + for i in support_data.incident[d]: + support, weight = support_data.supports[i] + best = min(best, weight / len(support)) + y[d] = best + return float(y[support_data.active_detectors].sum()), y + + def heuristic_saturation_zero(self, support_data: SupportData, *, order_kind: str) -> Tuple[float, Optional[np.ndarray]]: + if not support_data.active_detectors: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + slack = np.array([weight for _support, weight in support_data.supports], dtype=np.float64) + y = np.zeros(self.num_detectors, dtype=np.float64) + + if order_kind == "asc_deg": + order = sorted(support_data.active_detectors, key=lambda d: (len(support_data.incident[d]), d)) + elif order_kind == "desc_plain": + _plain_value, y_plain = self.heuristic_plain(support_data) + if y_plain is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y_plain[d], d), 
reverse=True) + else: + raise ValueError(f"Unknown order_kind={order_kind!r}") + + for d in order: + value = min(slack[i] for i in support_data.incident[d]) + if value < 0: + value = 0.0 + y[d] = value + for i in support_data.incident[d]: + slack[i] -= value + return float(y[support_data.active_detectors].sum()), y + + def heuristic_plain_sweep(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + plain_value, y = self.heuristic_plain(support_data) + if y is None: + return INF, None + order = sorted(support_data.active_detectors, key=lambda d: (y[d], d), reverse=True) + for d in order: + max_feasible = min( + weight - sum(y[dd] for dd in support if dd != d) + for support, weight in support_data.supports + if d in support + ) + if max_feasible > y[d]: + y[d] = max_feasible + return float(y[support_data.active_detectors].sum()), y + + def heuristic_exact_lp(self, support_data: SupportData) -> Tuple[float, Optional[np.ndarray]]: + active = support_data.active_detectors + if not active: + return 0.0, np.zeros(self.num_detectors, dtype=np.float64) + if not self._check_coverage(support_data): + return INF, None + + detector_index = {d: i for i, d in enumerate(active)} + uf = UnionFind(len(active)) + for support, _weight in support_data.supports: + if len(support) > 1: + a = detector_index[support[0]] + for d in support[1:]: + uf.union(a, detector_index[d]) + + components: Dict[int, List[int]] = defaultdict(list) + for d in active: + components[uf.find(detector_index[d])].append(d) + + y = np.zeros(self.num_detectors, dtype=np.float64) + total = 0.0 + for component in components.values(): + component_set = set(component) + local = {d: i for i, d in enumerate(sorted(component))} + component_supports: List[Tuple[Tuple[int, ...], float]] = [] + for support, weight in support_data.supports: + if support[0] in component_set: + component_supports.append((tuple(local[d] for d in support), weight)) + + rows: List[int] = [] + cols: List[int] = [] + data: 
List[float] = [] + rhs: List[float] = [] + for r, (support, weight) in enumerate(component_supports): + rhs.append(weight) + for c in support: + rows.append(r) + cols.append(c) + data.append(1.0) + + a_ub = csr_matrix( + (data, (rows, cols)), + shape=(len(component_supports), len(component)), + dtype=np.float64, + ) + self.lp_calls += 1 + result = linprog( + c=-np.ones(len(component), dtype=np.float64), + A_ub=a_ub, + b_ub=np.array(rhs, dtype=np.float64), + bounds=[(0.0, None)] * len(component), + method="highs", + ) + if not result.success: + return INF, None + total += -float(result.fun) + for d, value in zip(sorted(component), result.x): + y[d] = float(value) + return float(total), y + + + def _reachable_available_components( + self, + dets: np.ndarray, + available_errs: np.ndarray, + ) -> List[Tuple[List[int], List[int], List[int]]]: + active = sorted(map(int, np.flatnonzero(dets))) + if not active: + return [] + + det_visited = np.zeros(self.num_detectors, dtype=bool) + err_visited = np.zeros(self.num_errors, dtype=bool) + components: List[Tuple[List[int], List[int], List[int]]] = [] + + for seed in active: + if det_visited[seed]: + continue + det_visited[seed] = True + queue: deque[int] = deque([seed]) + component_dets: List[int] = [] + component_errs: List[int] = [] + while queue: + d = queue.popleft() + component_dets.append(d) + for ei in self.d2e[d]: + ei = int(ei) + if not available_errs[ei] or err_visited[ei]: + continue + err_visited[ei] = True + component_errs.append(ei) + for dd in self.error_detectors[ei]: + dd = int(dd) + if not det_visited[dd]: + det_visited[dd] = True + queue.append(dd) + + component_active = [d for d in component_dets if dets[d]] + if not component_active: + continue + component_inactive = [d for d in component_dets if not dets[d]] + components.append((component_active, component_inactive, component_errs)) + + return components + + def _solve_component_exact_lp_plus_inactive( + self, + component_active: Sequence[int], + 
component_inactive: Sequence[int], + component_errors: Sequence[int], + ) -> float: + if not component_active: + return 0.0 + if not component_errors: + return INF + + local_errors = list(component_errors) + det_to_local_errors: Dict[int, List[int]] = defaultdict(list) + for local_ei, ei in enumerate(local_errors): + for d in self.error_detectors[ei]: + det_to_local_errors[int(d)].append(local_ei) + + active_set = set(component_active) + inactive_set = set(component_inactive) + deg: Dict[int, int] = { + d: len(det_to_local_errors.get(d, [])) + for d in active_set | inactive_set + } + alive = np.ones(len(local_errors), dtype=bool) + queue: deque[int] = deque(d for d in component_inactive if deg.get(d, 0) == 1) + + while queue: + d = queue.popleft() + if deg.get(d, 0) != 1: + continue + forced_local = next( + (local_ei for local_ei in det_to_local_errors.get(d, []) if alive[local_ei]), + None, + ) + if forced_local is None: + deg[d] = 0 + continue + if not alive[forced_local]: + continue + alive[forced_local] = False + for dd in self.error_detectors[local_errors[forced_local]]: + dd = int(dd) + if dd not in deg or deg[dd] <= 0: + continue + deg[dd] -= 1 + if dd in active_set and deg[dd] == 0: + return INF + if dd in inactive_set and deg[dd] == 1: + queue.append(dd) + + for d in component_active: + if deg.get(d, 0) <= 0: + return INF + + reduced_errors = [ei for local_ei, ei in enumerate(local_errors) if alive[local_ei]] + if not reduced_errors: + return INF + + local_error_index = {ei: local_ei for local_ei, ei in enumerate(reduced_errors)} + det_to_reduced_errors: Dict[int, List[int]] = defaultdict(list) + for ei in reduced_errors: + local_ei = local_error_index[ei] + for d in self.error_detectors[ei]: + d = int(d) + if deg.get(d, 0) > 0: + det_to_reduced_errors[d].append(local_ei) + + inactive_with_incidence = [d for d in component_inactive if deg.get(d, 0) > 0] + num_x = len(reduced_errors) + num_s = len(inactive_with_incidence) + num_vars = num_x + num_s + c = 
np.zeros(num_vars, dtype=np.float64) + c[:num_x] = self.weights[reduced_errors] + + inactive_col = {d: num_x + i for i, d in enumerate(inactive_with_incidence)} + + ub_rows: List[int] = [] + ub_cols: List[int] = [] + ub_data: List[float] = [] + b_ub: List[float] = [] + ub_r = 0 + + for d in component_active: + incident = det_to_reduced_errors.get(d, []) + if not incident: + return INF + for local_ei in incident: + ub_rows.append(ub_r) + ub_cols.append(local_ei) + ub_data.append(-1.0) + b_ub.append(-1.0) + ub_r += 1 + + for d in inactive_with_incidence: + s_col = inactive_col[d] + for local_ei in det_to_reduced_errors[d]: + ub_rows.extend([ub_r, ub_r]) + ub_cols.extend([local_ei, s_col]) + ub_data.extend([2.0, -1.0]) + b_ub.append(0.0) + ub_r += 1 + + eq_rows: List[int] = [] + eq_cols: List[int] = [] + eq_data: List[float] = [] + b_eq: List[float] = [] + eq_r = 0 + + for d in inactive_with_incidence: + s_col = inactive_col[d] + eq_rows.append(eq_r) + eq_cols.append(s_col) + eq_data.append(1.0) + for local_ei in det_to_reduced_errors[d]: + eq_rows.append(eq_r) + eq_cols.append(local_ei) + eq_data.append(-1.0) + b_eq.append(0.0) + eq_r += 1 + + a_ub = None + if ub_r > 0: + a_ub = csr_matrix( + (ub_data, (ub_rows, ub_cols)), + shape=(ub_r, num_vars), + dtype=np.float64, + ) + + a_eq = None + if eq_r > 0: + a_eq = csr_matrix( + (eq_data, (eq_rows, eq_cols)), + shape=(eq_r, num_vars), + dtype=np.float64, + ) + + self.lp_calls += 1 + result = linprog( + c=c, + A_ub=a_ub, + b_ub=np.array(b_ub, dtype=np.float64) if b_ub else None, + A_eq=a_eq, + b_eq=np.array(b_eq, dtype=np.float64) if b_eq else None, + bounds=[(0.0, None)] * num_vars, + method="highs", + ) + if not result.success or result.fun is None: + return INF + return float(result.fun) + + def heuristic_exact_lp_plus_inactive( + self, + dets: np.ndarray, + available_errs: np.ndarray, + ) -> Tuple[float, Optional[np.ndarray]]: + if not np.any(dets): + return 0.0, None + if not 
self._has_cover_for_all_active_detectors(dets, available_errs): + return INF, None + + total = 0.0 + for component_active, component_inactive, component_errors in self._reachable_available_components( + dets, + available_errs, + ): + component_value = self._solve_component_exact_lp_plus_inactive( + component_active, + component_inactive, + component_errors, + ) + if math.isinf(component_value): + return INF, None + total += component_value + return float(total), None + + def evaluate_named_heuristic(self, support_data: SupportData, name: str) -> Tuple[float, Optional[np.ndarray]]: + if name == "plain": + return self.heuristic_plain(support_data) + if name == "asc_deg": + return self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if name == "desc_plain": + return self.heuristic_saturation_zero(support_data, order_kind="desc_plain") + if name == "plain_sweep": + return self.heuristic_plain_sweep(support_data) + if name == "best_of_two": + v1, y1 = self.heuristic_plain_sweep(support_data) + v2, y2 = self.heuristic_saturation_zero(support_data, order_kind="asc_deg") + if v1 >= v2: + return v1, y1 + return v2, y2 + if name == "best_of_three": + candidates = [ + self.heuristic_plain_sweep(support_data), + self.heuristic_saturation_zero(support_data, order_kind="asc_deg"), + self.heuristic_saturation_zero(support_data, order_kind="desc_plain"), + ] + return max(candidates, key=lambda t: t[0]) + if name == "exact_lp": + return self.heuristic_exact_lp(support_data) + raise ValueError(f"Unknown heuristic {name!r}") + + def compute_support_based_heuristic( + self, + dets: np.ndarray, + errs: np.ndarray, + blocked_errs: np.ndarray, + *, + name: Optional[str] = None, + ) -> Tuple[float, Optional[np.ndarray]]: + self.heuristic_calls += 1 + available = self._available_errors(errs, blocked_errs) + heuristic_name = name or self.heuristic_name + if heuristic_name == "exact_lp_plus_inactive": + return self.heuristic_exact_lp_plus_inactive(dets, available) + 
support_data = self.build_support_data(dets, available) + return self.evaluate_named_heuristic(support_data, heuristic_name) + + def project_child_y( + self, + parent_state: SearchState, + child_dets: np.ndarray, + child_errs: np.ndarray, + child_blocked_errs: np.ndarray, + child_det_counts: np.ndarray, + flipped_detectors: Sequence[int], + ) -> Tuple[float, Optional[np.ndarray], str]: + if parent_state.y_prices is None: + raise AssertionError("Expected a stored feasible y vector before projecting to a child.") + + self.heuristic_calls += 1 + self.projection_heuristic_calls += 1 + available = self._available_errors(child_errs, child_blocked_errs) + if not self._has_cover_for_all_active_detectors(child_dets, available): + return INF, None, "projected" + + y_projected = np.zeros(self.num_detectors, dtype=np.float64) + keep = parent_state.dets & child_dets + y_projected[keep] = parent_state.y_prices[keep] + projected_value = float(y_projected[np.flatnonzero(child_dets)].sum()) + best_value = projected_value + best_y = y_projected + best_source = "projected" + + if self.projection_combine_max_plain: + plain_value, plain_y = self.plain_detcost_from_counts(child_dets, available, child_det_counts) + if plain_y is None: + return INF, None, "plain" + if plain_value > best_value + HEURISTIC_EPS: + best_value = plain_value + best_y = plain_y + best_source = "plain" + + return best_value, best_y, best_source + + def report_root_heuristics(self, dets: np.ndarray, errs: np.ndarray, blocked_errs: np.ndarray) -> List[Tuple[str, float]]: + available = self._available_errors(errs, blocked_errs) + support_data = self.build_support_data(dets, available) + names = [ + "plain", + "asc_deg", + "desc_plain", + "plain_sweep", + "best_of_two", + "best_of_three", + "exact_lp", + "exact_lp_plus_inactive", + ] + out: List[Tuple[str, float]] = [] + saved_lp_calls = self.lp_calls + for name in names: + if name == "exact_lp_plus_inactive": + value, _ = self.heuristic_exact_lp_plus_inactive(dets, 
available) + else: + value, _ = self.evaluate_named_heuristic(support_data, name) + out.append((name, value)) + self.lp_calls = saved_lp_calls + return out + + def _maybe_refine_node(self, state: SearchState) -> Tuple[SearchState, bool]: + if state.refined or self.heuristic_name == "plain" or not self.lazy_reinsert_heuristics: + return state, False + + previous_source = state.h_source + projectable = self.heuristic_has_projectable_prices(self.heuristic_name) + self.refinement_calls += 1 + new_value, new_y = self.compute_support_based_heuristic( + state.dets, + state.errs, + state.blocked_errs, + name=self.heuristic_name, + ) + if math.isinf(new_value): + if previous_source == "projected": + self.projected_nodes_refined += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost:.6f} new_h=INF delta=INF reinserted=False discarded=True" + ) + state.h_cost = INF + state.h_source = "refined" + if projectable: + state.y_prices = None + state.refined = True + return state, True + if projectable and new_y is None: + raise AssertionError(f"Expected projectable y-prices from heuristic {self.heuristic_name!r}.") + + delta = new_value - state.h_cost + self.total_refinement_gain += max(0.0, delta) + self.max_refinement_gain = max(self.max_refinement_gain, max(0.0, delta)) + + if self.heuristic_name in {"exact_lp", "exact_lp_plus_inactive"} and new_value + 1e-7 < state.h_cost: + raise AssertionError( + f"Exact LP refinement {new_value} below stored projected value {state.h_cost}." 
+ ) + + if new_value > state.h_cost + HEURISTIC_EPS: + if previous_source == "projected": + self.projected_nodes_refined += 1 + state.h_cost = new_value + state.h_source = "refined" + if projectable: + state.y_prices = new_y + state.refined = True + self.reinserts += 1 + if self.verbose_search: + print( + f" refine approx_h={state.h_cost - delta:.6f} new_h={new_value:.6f} delta={delta:.6f} reinserted=True discarded=False" + ) + return state, True + + # Non-improving recomputation: keep the existing projectable feasible point unless the + # selected heuristic returned a fresh one that can still be projected to children. + if previous_source == "projected": + self.projected_nodes_refined += 1 + if projectable and abs(new_value - state.h_cost) <= HEURISTIC_EPS and new_y is not None: + state.y_prices = new_y + state.refined = True + if self.verbose_search: + new_text = "INF" if math.isinf(new_value) else f"{new_value:.6f}" + print( + f" refine approx_h={state.h_cost:.6f} new_h={new_text} delta={delta:.6f} reinserted=False discarded=False" + ) + return state, False + + def decode(self, shot_dets: np.ndarray, det_beam: float = INF) -> DecodeResult: + start_time = time.perf_counter() + self.reset_stats() + + dets0 = np.array(shot_dets, dtype=bool, copy=True) + errs0 = np.zeros(self.num_errors, dtype=bool) + blocked0 = np.zeros(self.num_errors, dtype=bool) + det_counts0 = np.zeros(self.num_errors, dtype=np.uint16) + for d in np.flatnonzero(dets0): + for ei in self.d2e[int(d)]: + det_counts0[int(ei)] += 1 + + root_h, root_y = self.plain_detcost_from_counts(dets0, self._available_errors(errs0, blocked0), det_counts0) + if root_y is None or math.isinf(root_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + 
refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + root_refined = (self.heuristic_name == "plain") or (not self.lazy_reinsert_heuristics) + if root_refined and self.heuristic_name != "plain": + # Eager mode: use the selected heuristic immediately. + eager_h, eager_y = self.compute_support_based_heuristic(dets0, errs0, blocked0, name=self.heuristic_name) + if math.isinf(eager_h): + return DecodeResult( + success=False, + errs=errs0, + residual_dets=dets0, + cost=INF, + nodes_pushed=1, + nodes_popped=0, + max_queue_size=1, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + if self.heuristic_has_projectable_prices(self.heuristic_name): + if eager_y is None: + raise AssertionError(f"Expected projectable y-prices from heuristic {self.heuristic_name!r}.") + root_y = eager_y + root_h = eager_h + + root_state = SearchState( + errs=errs0, + blocked_errs=blocked0, + dets=dets0, + det_counts=det_counts0, + g_cost=0.0, + h_cost=root_h, + h_source="plain" if not root_refined else ("plain" if 
self.heuristic_name == "plain" else "refined"), + refined=root_refined, + y_prices=root_y, + ) + + heap: List[Tuple[float, int, int, SearchState]] = [] + counter = 0 + heapq.heappush(heap, (root_state.g_cost + root_state.h_cost, int(dets0.sum()), counter, root_state)) + counter += 1 + nodes_pushed = 1 + nodes_popped = 0 + max_queue_size = 1 + min_num_dets = int(dets0.sum()) + + while heap: + max_queue_size = max(max_queue_size, len(heap)) + f_cost, num_dets, _entry_id, state = heapq.heappop(heap) + nodes_popped += 1 + max_num_dets = min_num_dets + det_beam + if num_dets > max_num_dets: + continue + if num_dets < min_num_dets: + min_num_dets = num_dets + max_num_dets = min_num_dets + det_beam + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f"len(heap)={len(heap)} nodes_pushed={nodes_pushed} nodes_popped={nodes_popped} " + f"lp_calls={self.lp_calls} reinserts={self.reinserts} proj_generated={self.projected_nodes_generated} " + f"proj_refined={self.projected_nodes_refined} proj_unrefined_so_far={projected_unrefined} " + f"num_dets={num_dets} max_num_dets={max_num_dets} f={f_cost:.6f} g={state.g_cost:.6f} " + f"h={state.h_cost:.6f} h_source={state.h_source} refined={state.refined}" + ) + + if num_dets == 0: + return DecodeResult( + success=True, + errs=state.errs, + residual_dets=state.dets, + cost=state.g_cost, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + 
total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + elapsed_seconds=time.perf_counter() - start_time, + ) + + state, should_reinsert = self._maybe_refine_node(state) + if should_reinsert: + if state.y_prices is None or math.isinf(state.h_cost): + if state.h_source == "projected": + self.projected_nodes_refined += 1 + continue + if state.h_source != "plain": + heapq.heappush(heap, (state.g_cost + state.h_cost, num_dets, counter, state)) + counter += 1 + continue + + min_det = int(np.flatnonzero(state.dets)[0]) + prefix_blocked = state.blocked_errs.copy() + children_generated = 0 + children_beam_pruned = 0 + children_infeasible = 0 + children_projected = 0 + + for ei in self.d2e[min_det]: + ei = int(ei) + prefix_blocked[ei] = True + if state.errs[ei] or state.blocked_errs[ei]: + continue + + child_errs = state.errs.copy() + child_errs[ei] = True + child_blocked = prefix_blocked.copy() + child_dets = state.dets.copy() + child_det_counts = state.det_counts.copy() + for d in self.error_detectors[ei]: + d = int(d) + if child_dets[d]: + child_dets[d] = False + for oei in self.d2e[d]: + child_det_counts[int(oei)] -= 1 + else: + child_dets[d] = True + for oei in self.d2e[d]: + child_det_counts[int(oei)] += 1 + + child_num_dets = int(child_dets.sum()) + if child_num_dets > max_num_dets: + children_beam_pruned += 1 + continue + + child_g = state.g_cost + float(self.weights[ei]) + if self.heuristic_name == "plain" or (not self.lazy_reinsert_heuristics): + child_h, child_y = self.compute_support_based_heuristic( + child_dets, child_errs, child_blocked, name=self.heuristic_name + ) + child_source = "plain" if self.heuristic_name == "plain" else "refined" + child_refined = True + else: + if state.y_prices is None: + raise AssertionError("Expected parent feasible y-prices before projecting to child.") + child_h, child_y, child_source = self.project_child_y( + state, + child_dets, + child_errs, + child_blocked, + child_det_counts, + 
self.error_detectors[ei], + ) + self.projected_nodes_generated += 1 + children_projected += 1 + child_refined = False + + if math.isinf(child_h): + children_infeasible += 1 + continue + if ( + child_refined + and self.heuristic_has_projectable_prices(self.heuristic_name) + and child_y is None + ): + raise AssertionError(f"Expected projectable y-prices from heuristic {self.heuristic_name!r}.") + + child_state = SearchState( + errs=child_errs, + blocked_errs=child_blocked, + dets=child_dets, + det_counts=child_det_counts, + g_cost=child_g, + h_cost=child_h, + h_source=child_source, + refined=child_refined, + y_prices=child_y, + ) + heapq.heappush(heap, (child_g + child_h, child_num_dets, counter, child_state)) + counter += 1 + nodes_pushed += 1 + children_generated += 1 + + if self.verbose_search: + projected_unrefined = self.projected_nodes_generated - self.projected_nodes_refined + print( + f" expanded children_generated={children_generated} children_projected={children_projected} " + f"beam_pruned={children_beam_pruned} infeasible={children_infeasible} " + f"lp_calls={self.lp_calls} proj_unrefined_so_far={projected_unrefined}" + ) + + return DecodeResult( + success=False, + errs=np.zeros(self.num_errors, dtype=bool), + residual_dets=np.array(shot_dets, dtype=bool, copy=True), + cost=INF, + nodes_pushed=nodes_pushed, + nodes_popped=nodes_popped, + max_queue_size=max_queue_size, + heuristic_calls=self.heuristic_calls, + plain_heuristic_calls=self.plain_heuristic_calls, + projection_heuristic_calls=self.projection_heuristic_calls, + refinement_calls=self.refinement_calls, + lp_calls=self.lp_calls, + reinserts=self.reinserts, + projected_nodes_generated=self.projected_nodes_generated, + projected_nodes_refined=self.projected_nodes_refined, + projected_nodes_unrefined_at_finish=self.projected_nodes_generated - self.projected_nodes_refined, + total_refinement_gain=self.total_refinement_gain, + max_refinement_gain=self.max_refinement_gain, + 
elapsed_seconds=time.perf_counter() - start_time, + ) + + def cost_from_errs(self, errs: np.ndarray) -> float: + return float(self.weights[errs].sum()) + + def detectors_from_errs(self, errs: np.ndarray) -> np.ndarray: + dets = np.zeros(self.num_detectors, dtype=bool) + for ei in np.flatnonzero(errs): + for d in self.error_detectors[int(ei)]: + dets[d] ^= True + return dets + + def observables_from_errs(self, errs: np.ndarray) -> np.ndarray: + parity: Dict[int, bool] = {} + for ei in np.flatnonzero(errs): + for obs in self.error_observables[int(ei)]: + parity[int(obs)] = not parity.get(int(obs), False) + return np.array(sorted(obs for obs, bit in parity.items() if bit), dtype=np.int32) + + +def sample_detections_and_observables( + circuit: stim.Circuit, + *, + num_shots: int, + seed: int, + num_detectors: int, + num_observables: int, +) -> Tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets_packed, obs_packed = sampler.sample( + shots=num_shots, + separate_observables=True, + bit_packed=True, + ) + dets_unpacked = np.unpackbits( + dets_packed, + bitorder="little", + axis=1, + count=num_detectors, + ) + obs_unpacked = np.unpackbits( + obs_packed, + bitorder="little", + axis=1, + count=num_observables, + ) + return dets_unpacked.astype(bool), obs_unpacked.astype(bool) + + +def parse_det_beam(text: str) -> float: + lowered = text.strip().lower() + if lowered in {"inf", "infinity", "none"}: + return INF + value = int(text) + if value < 0: + raise argparse.ArgumentTypeError("det-beam must be non-negative or 'inf'.") + return float(value) + + +def build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description=( + "Prototype A* decoder for Stim circuits using greedy singleton-budget heuristics." 
+ ) + ) + parser.add_argument("--circuit", type=Path, required=True, help="Path to a .stim circuit file.") + parser.add_argument( + "--dets", + type=str, + default=None, + help="String of shot dets (e.g., 'shot D0 D1 L2') to parse instead of sampling.", + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=100, + help="Number of shots to sample from Stim before selecting --shot (default: 100).", + ) + parser.add_argument( + "--shot", + type=int, + default=0, + help="Index of the sampled shot to decode (default: 0).", + ) + parser.add_argument( + "--seed", + type=int, + default=27123839530, + help="Stim sampler seed (default: 27123839530).", + ) + parser.add_argument( + "--det-beam", + type=parse_det_beam, + default=INF, + help="Beam cutoff on the residual detector count; use 'inf' for none.", + ) + parser.add_argument( + "--heuristic", + choices=[ + "plain", + "asc_deg", + "desc_plain", + "plain_sweep", + "best_of_two", + "best_of_three", + "exact_lp", + "exact_lp_plus_inactive", + ], + default="best_of_two", + help="Lower-bound heuristic to use during A* search (default: best_of_two).", + ) + parser.add_argument( + "--lazy-reinsert-heuristics", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For non-plain heuristics, seed nodes with plain detcost, refine on pop, and reinsert when the selected " + "heuristic improves the key (default: enabled)." 
+ ), + ) + parser.add_argument( + "--projection-combine-max-plain", + action=argparse.BooleanOptionalAction, + default=True, + help="When projecting parent y-prices to a child, take max(projected, plain detcost) (default: enabled).", + ) + parser.add_argument( + "--merge-errors", + action=argparse.BooleanOptionalAction, + default=True, + help="Merge indistinguishable DEM errors before decoding (default: enabled).", + ) + parser.add_argument( + "--respect-blocked-errors-in-heuristic", + action=argparse.BooleanOptionalAction, + default=True, + help="Exclude precedence-blocked errors when forming the lower bound (default: enabled).", + ) + parser.add_argument( + "--report-all-root-heuristics", + action="store_true", + help="Print all root-node heuristic values, including exact_lp and exact_lp_plus_inactive, for the selected shot.", + ) + parser.add_argument( + "--skip-decode", + action="store_true", + help="Only report root heuristics; do not run A* search.", + ) + parser.add_argument( + "--show-shot-detectors", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the selected shot's active detector IDs (default: enabled).", + ) + parser.add_argument( + "--show-error-indices", + action=argparse.BooleanOptionalAction, + default=True, + help="Print the decoded merged-error indices when decoding succeeds (default: enabled).", + ) + parser.add_argument( + "--verbose-search", + action="store_true", + help="Print per-node search diagnostics.", + ) + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_arg_parser() + args = parser.parse_args(argv) + + if args.sample_num_shots <= 0: + parser.error("--sample-num-shots must be positive.") + if args.shot < 0: + parser.error("--shot must be non-negative.") + if args.shot >= args.sample_num_shots: + parser.error("--shot must be smaller than --sample-num-shots.") + + circuit = stim.Circuit.from_file(str(args.circuit)) + dem = 
circuit.detector_error_model(decompose_errors=False) + errors = merged_errors_from_dem(dem) if args.merge_errors else list(iter_dem_errors_from_dem(dem)) + + if args.dets is not None: + shot_dets = np.zeros(dem.num_detectors, dtype=bool) + shot_obs = np.zeros(dem.num_observables, dtype=bool) + for token in args.dets.split(): + if token == "shot": + continue + if token.startswith("D") and token[1:].isdigit(): + d_idx = int(token[1:]) + if d_idx < dem.num_detectors: + shot_dets[d_idx] = True + elif token.startswith("L") and token[1:].isdigit(): + l_idx = int(token[1:]) + if l_idx < dem.num_observables: + shot_obs[l_idx] = True + else: + dets, obs = sample_detections_and_observables( + circuit, + num_shots=args.sample_num_shots, + seed=args.seed, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + ) + shot_dets = dets[args.shot] + shot_obs = obs[args.shot] + + decoder = GreedySingletonHeuristicDecoder( + errors, + num_detectors=dem.num_detectors, + num_observables=dem.num_observables, + heuristic=args.heuristic, + respect_blocked_errors_in_heuristic=args.respect_blocked_errors_in_heuristic, + lazy_reinsert_heuristics=args.lazy_reinsert_heuristics, + projection_combine_max_plain=args.projection_combine_max_plain, + verbose_search=args.verbose_search, + ) + + print(f"circuit = {args.circuit}") + print(f"heuristic = {args.heuristic}") + print(f"mode = {decoder.mode_name}") + print(f"sample_num_shots = {args.sample_num_shots}") + print(f"shot = {args.shot}") + print(f"num_errors = {decoder.num_errors}") + print(f"num_detectors = {decoder.num_detectors}") + print(f"num_observables = {decoder.num_observables}") + print(f"det_beam = {args.det_beam}") + print(f"merge_errors = {args.merge_errors}") + print(f"respect_blocked_errors_in_heuristic = {args.respect_blocked_errors_in_heuristic}") + print(f"lazy_reinsert_heuristics = {args.lazy_reinsert_heuristics}") + print(f"projection_combine_max_plain = {args.projection_combine_max_plain}") + + if 
args.show_shot_detectors: + active_dets = np.flatnonzero(shot_dets) + print("shot_detectors =", " ".join(f"D{d}" for d in active_dets)) + + if args.report_all_root_heuristics: + root_errs = np.zeros(decoder.num_errors, dtype=bool) + root_blocked = np.zeros(decoder.num_errors, dtype=bool) + report = decoder.report_root_heuristics(shot_dets, root_errs, root_blocked) + exact = next((v for k, v in report if k == "exact_lp"), None) + print("root_heuristics:") + for name, value in report: + if exact is not None and not math.isinf(exact) and exact > 0: + ratio = value / exact if not math.isinf(value) else INF + ratio_text = "INF" if math.isinf(ratio) else f"{ratio:.6f}" + else: + ratio_text = "n/a" + value_text = "INF" if math.isinf(value) else f"{value:.12f}" + print(f" {name:>24s} value={value_text} ratio_to_exact={ratio_text}") + + if args.skip_decode: + return 0 + + result = decoder.decode(shot_dets, det_beam=args.det_beam) + print(f"success = {result.success}") + print(f"nodes_pushed = {result.nodes_pushed}") + print(f"nodes_popped = {result.nodes_popped}") + print(f"max_queue_size = {result.max_queue_size}") + print(f"heuristic_calls = {result.heuristic_calls}") + print(f"plain_heuristic_calls = {result.plain_heuristic_calls}") + print(f"projection_heuristic_calls = {result.projection_heuristic_calls}") + print(f"refinement_calls = {result.refinement_calls}") + print(f"lp_calls = {result.lp_calls}") + print(f"reinserts = {result.reinserts}") + print(f"projected_nodes_generated = {result.projected_nodes_generated}") + print(f"projected_nodes_refined = {result.projected_nodes_refined}") + print(f"projected_nodes_unrefined_at_finish = {result.projected_nodes_unrefined_at_finish}") + print(f"total_refinement_gain = {result.total_refinement_gain:.6f}") + print(f"max_refinement_gain = {result.max_refinement_gain:.6f}") + print(f"elapsed_seconds = {result.elapsed_seconds:.6f}") + + if not result.success: + print("decode failed") + return 1 + + if args.show_error_indices: + 
print("decoded_error_indices =", " ".join(map(str, np.flatnonzero(result.errs).tolist()))) + + reproduced_dets = decoder.detectors_from_errs(result.errs) + if not np.array_equal(reproduced_dets, shot_dets): + raise AssertionError("Decoded errors do not reproduce the sampled detection events.") + + decoded_cost = decoder.cost_from_errs(result.errs) + predicted_obs = decoder.observables_from_errs(result.errs) + sampled_obs = np.flatnonzero(shot_obs) + + print(f"num_decoded_errors = {int(result.errs.sum())}") + print(f"decoded_cost = {decoded_cost:.12f}") + print("predicted_observables =", " ".join(f"L{o}" for o in predicted_obs.tolist())) + print("sampled_observables =", " ".join(f"L{o}" for o in sampled_obs.tolist())) + print(f"observables_match = {bool(np.array_equal(predicted_obs, sampled_obs))}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From d5dbba63a667cd5e188a5ba5d7e42436f93b4681 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 30 Mar 2026 18:19:48 -0700 Subject: [PATCH 06/25] - Removed repeated copying of component constraints into temporary row_supports / rhs vectors before simplex solves. - Reused a single local_hits buffer while building deduplicated active supports. - Removed an unnecessary sort of local_hits; the local detector order was already stable. - Reused a monotone warm-start cursor when reading detector budgets from the parent exact solution, instead of restarting the search for every component. --- src/tesseract_ftl.cc | 70 ++++++++++++++++++++------------------------ src/tesseract_ftl.h | 11 ++++--- 2 files changed, 37 insertions(+), 44 deletions(-) diff --git a/src/tesseract_ftl.cc b/src/tesseract_ftl.cc index 953e59b..c1eb5d6 100644 --- a/src/tesseract_ftl.cc +++ b/src/tesseract_ftl.cc @@ -107,8 +107,9 @@ double dot_on_support(const std::vector& values, const std::vector& // x >= 0 // where A is a 0/1 matrix given by row supports for a selected subset of rows. 
DenseSimplexResult solve_dense_primal_packing_lp( - size_t num_vars, const std::vector>& row_supports, - const std::vector& selected_rows, const std::vector& rhs, + size_t num_vars, + const std::vector& constraints, + const std::vector& selected_rows, const std::vector* entering_priorities = nullptr) { DenseSimplexResult result; result.solution.assign(num_vars, 0.0); @@ -130,13 +131,13 @@ DenseSimplexResult solve_dense_primal_packing_lp( for (size_t row = 0; row < num_rows; ++row) { size_t orig_row = (size_t)selected_rows[row]; - for (int col : row_supports[orig_row]) { + for (int col : constraints[orig_row].local_detectors) { tableau[row * width + (size_t)col] = 1.0; } tableau[row * width + num_vars + row] = 1.0; - tableau[row * width + width - 1] = rhs[orig_row]; + tableau[row * width + width - 1] = constraints[orig_row].rhs; basis[row] = num_vars + row; - if (rhs[orig_row] < -SIMPLEX_EPS) { + if (constraints[orig_row].rhs < -SIMPLEX_EPS) { throw std::runtime_error("Dense simplex received a negative RHS."); } } @@ -255,8 +256,9 @@ struct SingletonComponentSolveResult { }; SingletonComponentSolveResult solve_singleton_component_lp( - size_t num_local_detectors, const std::vector>& row_supports, - const std::vector& rhs, const std::vector& cheapest_constraint_for_local_detector, + size_t num_local_detectors, + const std::vector& constraints, + const std::vector& cheapest_constraint_for_local_detector, const std::vector& seed_budgets) { SingletonComponentSolveResult result; result.detector_budgets.assign(num_local_detectors, 0.0); @@ -265,16 +267,16 @@ SingletonComponentSolveResult solve_singleton_component_lp( result.success = true; return result; } - if (row_supports.empty()) { + if (constraints.empty()) { result.unbounded = true; return result; } const double seed_total = std::accumulate(seed_budgets.begin(), seed_budgets.end(), 0.0); - std::vector selected(row_supports.size(), 0); + std::vector selected(constraints.size(), 0); std::vector selected_indices; - 
selected_indices.reserve(std::min(row_supports.size(), num_local_detectors * 2 + 4)); + selected_indices.reserve(std::min(constraints.size(), num_local_detectors * 2 + 4)); auto add_constraint = [&](int idx) { if (idx < 0) return; @@ -284,16 +286,17 @@ SingletonComponentSolveResult solve_singleton_component_lp( } }; - for (size_t row = 0; row < row_supports.size(); ++row) { - const double slack = rhs[row] - dot_on_support(seed_budgets, row_supports[row]); - if (slack <= SEED_TIGHT_EPS * (1.0 + rhs[row])) { + for (size_t row = 0; row < constraints.size(); ++row) { + const auto& constraint = constraints[row]; + const double slack = constraint.rhs - dot_on_support(seed_budgets, constraint.local_detectors); + if (slack <= SEED_TIGHT_EPS * (1.0 + constraint.rhs)) { add_constraint((int)row); } } std::vector covered(num_local_detectors, 0); for (int idx : selected_indices) { - for (int local : row_supports[(size_t)idx]) covered[(size_t)local] = 1; + for (int local : constraints[(size_t)idx].local_detectors) covered[(size_t)local] = 1; } for (size_t local = 0; local < num_local_detectors; ++local) { if (!covered[local]) { @@ -302,7 +305,7 @@ SingletonComponentSolveResult solve_singleton_component_lp( throw std::runtime_error("Missing seed constraint for active detector."); } add_constraint(idx); - for (int touched : row_supports[(size_t)idx]) covered[(size_t)touched] = 1; + for (int touched : constraints[(size_t)idx].local_detectors) covered[(size_t)touched] = 1; } } @@ -312,12 +315,13 @@ SingletonComponentSolveResult solve_singleton_component_lp( size_t rounds = 0; while (true) { - if (++rounds > row_supports.size() + 1) { + if (++rounds > constraints.size() + 1) { throw std::runtime_error("Constraint generation exceeded the number of unique constraints."); } - DenseSimplexResult simplex = solve_dense_primal_packing_lp( - num_local_detectors, row_supports, selected_indices, rhs, &seed_budgets); + DenseSimplexResult simplex = + 
solve_dense_primal_packing_lp(num_local_detectors, constraints, selected_indices, + &seed_budgets); result.simplex_solves++; if (simplex.unbounded) { result.unbounded = true; @@ -334,14 +338,15 @@ SingletonComponentSolveResult solve_singleton_component_lp( std::vector> top_violated; top_violated.reserve(VIOLATION_BATCH_SIZE); - for (size_t row = 0; row < row_supports.size(); ++row) { + for (size_t row = 0; row < constraints.size(); ++row) { if (selected[row]) continue; - const double lhs = dot_on_support(simplex.solution, row_supports[row]); - const double violation = lhs - rhs[row]; + const auto& constraint = constraints[row]; + const double lhs = dot_on_support(simplex.solution, constraint.local_detectors); + const double violation = lhs - constraint.rhs; if (violation > max_violation) { max_violation = violation; } - if (violation <= VIOLATION_EPS * (1.0 + rhs[row])) continue; + if (violation <= VIOLATION_EPS * (1.0 + constraint.rhs)) continue; top_violated.emplace_back(violation, (int)row); std::sort(top_violated.begin(), top_violated.end(), @@ -622,13 +627,14 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c std::vector, double, IntVectorHash>> min_rhs_by_pattern( result.components.size()); + std::vector local_hits; + local_hits.reserve(16); for (size_t ei = 0; ei < num_errors; ++ei) { if (blocked_flags[ei]) continue; int component_index = -1; - std::vector local_hits; - local_hits.reserve(edets[ei].size()); + local_hits.clear(); for (int detector : edets[ei]) { const int active_pos = detector_to_active_pos[(size_t)detector]; @@ -643,7 +649,6 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c if (component_index < 0) continue; - std::sort(local_hits.begin(), local_hits.end()); auto& rhs_map = min_rhs_by_pattern[(size_t)component_index]; const double rhs = errors[ei].likelihood_cost; auto it = rhs_map.find(local_hits); @@ -722,16 +727,15 @@ TesseractFTLDecoder::ExactSubsetSolution 
TesseractFTLDecoder::solve_exact_subset const ExactSubsetSolution* warm_solution = warm_solution_idx >= 0 ? &exact_solution_arena[(size_t)warm_solution_idx] : nullptr; - solution.value = 0.0; solution.num_components = build.components.size(); std::vector> detector_budget_pairs; detector_budget_pairs.reserve(detectors.count()); + size_t warm_pos = 0; for (const auto& component : build.components) { std::vector seed_budgets(component.detectors.size(), 0.0); if (warm_solution != nullptr) { - size_t warm_pos = 0; for (size_t local = 0; local < component.detectors.size(); ++local) { int det = component.detectors[local]; while (warm_pos < warm_solution->active_detectors.size() && @@ -744,18 +748,8 @@ TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset } } } - - std::vector> row_supports; - std::vector rhs; - row_supports.reserve(component.constraints.size()); - rhs.reserve(component.constraints.size()); - for (const auto& constraint : component.constraints) { - row_supports.push_back(constraint.local_detectors); - rhs.push_back(constraint.rhs); - } - const auto component_result = solve_singleton_component_lp( - component.detectors.size(), row_supports, rhs, + component.detectors.size(), component.constraints, component.cheapest_constraint_for_local_detector, seed_budgets); stats.lp_calls += component_result.simplex_solves; diff --git a/src/tesseract_ftl.h b/src/tesseract_ftl.h index c34eb09..ec4971c 100644 --- a/src/tesseract_ftl.h +++ b/src/tesseract_ftl.h @@ -114,6 +114,11 @@ struct TesseractFTLDecoder { size_t num_detectors = 0; TesseractFTLStats stats; + struct SingletonPatternConstraint { + std::vector local_detectors; + double rhs = 0.0; + }; + private: struct ErrorCost { double likelihood_cost = 0; @@ -134,12 +139,6 @@ struct TesseractFTLDecoder { bool operator>(const FTLNode& other) const; }; - - struct SingletonPatternConstraint { - std::vector local_detectors; - double rhs = 0.0; - }; - struct SingletonLPComponent { std::vector 
detectors; std::vector constraints; From 90462b9e5dc4cbbcf610a9e543fefe75a2353ee8 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 30 Mar 2026 20:48:36 -0700 Subject: [PATCH 07/25] =?UTF-8?q?=E2=80=A2=20Optimized=20tesseract=5Fftl?= =?UTF-8?q?=20singleton-LP=20hot=20paths=20without=20changing=20the=20sear?= =?UTF-8?q?ch=20behavior=20on=20the=20full=20100-shot=20benchmark.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Kept changes: - Avoid rebuilding detector parity from the full error chain on pop by caching each node’s residual detector bitset. - Speed up exact-refinement LP setup by reusing buffers, avoiding unnecessary temporary row/rhs copies, removing an unnecessary local-hit sort, and using a monotone warm-start cursor when pulling detector budgets from the parent exact solution. - Replace full-error scans during component construction with candidate-error gathering from the active detectors’ incidence lists. - Simplify union-find component assignment by mapping roots directly to component indices. - Add detailed FTL timing/stat counters for chain replay, component build, simplex, projection, and component-build subphases to guide further optimization. 
--- src/tesseract_ftl.cc | 308 +++++++++++++++++++++++++++----------- src/tesseract_ftl.h | 19 ++- src/tesseract_ftl_main.cc | 20 +++ 3 files changed, 261 insertions(+), 86 deletions(-) diff --git a/src/tesseract_ftl.cc b/src/tesseract_ftl.cc index c1eb5d6..3c927b8 100644 --- a/src/tesseract_ftl.cc +++ b/src/tesseract_ftl.cc @@ -81,8 +81,9 @@ std::ostream& operator<<(std::ostream& os, const std::vector& vec) { return os; } +template struct IntVectorHash { - size_t operator()(const std::vector& values) const { + size_t operator()(const T& values) const { return boost::hash_range(values.begin(), values.end()); } }; @@ -95,7 +96,8 @@ struct DenseSimplexResult { std::vector solution; }; -double dot_on_support(const std::vector& values, const std::vector& support) { +template +double dot_on_support(const std::vector& values, const T& support) { double total = 0.0; for (int idx : support) total += values[(size_t)idx]; return total; @@ -401,7 +403,8 @@ std::string TesseractFTLConfig::str() { ss << "det_orders=" << det_orders << ", "; ss << "det_penalty=" << det_penalty << ", "; ss << "create_visualization=" << create_visualization << ", "; - ss << "subset_detcost_size=" << subset_detcost_size; + ss << "subset_detcost_size=" << subset_detcost_size << ", "; + ss << "ignore_blocked_errors_in_heuristic=" << ignore_blocked_errors_in_heuristic; ss << ")"; return ss.str(); } @@ -425,6 +428,17 @@ void TesseractFTLStats::accumulate(const TesseractFTLStats& other) { total_lp_refinement_gain += other.total_lp_refinement_gain; max_lp_refinement_gain = std::max(max_lp_refinement_gain, other.max_lp_refinement_gain); lp_total_seconds += other.lp_total_seconds; + chain_replay_total_seconds += other.chain_replay_total_seconds; + component_build_total_seconds += other.component_build_total_seconds; + component_candidate_total_seconds += other.component_candidate_total_seconds; + component_union_total_seconds += other.component_union_total_seconds; + component_dedup_total_seconds += 
other.component_dedup_total_seconds; + component_finalize_total_seconds += other.component_finalize_total_seconds; + simplex_total_seconds += other.simplex_total_seconds; + projection_total_seconds += other.projection_total_seconds; + component_build_calls += other.component_build_calls; + simplex_calls += other.simplex_calls; + projection_calls += other.projection_calls; } bool TesseractFTLDecoder::FTLNode::operator>(const FTLNode& other) const { @@ -513,6 +527,8 @@ void TesseractFTLDecoder::initialize_structures(size_t num_detectors_) { d2e.resize(num_detectors_); edets.resize(num_errors); error_costs.resize(num_errors); + candidate_error_marks.assign(num_errors, 0); + candidate_error_mark_epoch = 1; for (size_t ei = 0; ei < num_errors; ++ei) { edets[ei] = errors[ei].symptom.detectors; @@ -549,9 +565,26 @@ void TesseractFTLDecoder::flip_detectors_and_block_errors( } } +void block_errors_from_chain(const std::vector& error_chain_arena, + const std::vector>& d2e, int64_t error_chain_idx, + std::vector& blocked_flags) { + int64_t walker_idx = error_chain_idx; + while (walker_idx != -1) { + const auto& node = error_chain_arena[(size_t)walker_idx]; + const size_t ei = node.error_index; + const size_t min_detector = node.min_detector; + for (int oei : d2e[min_detector]) { + blocked_flags[(size_t)oei] = 1; + if ((size_t)oei == ei) break; + } + walker_idx = node.parent_idx; + } +} + TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_components( - const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags) const { + const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags) { SingletonBuildResult result; + const auto candidate_start_time = std::chrono::high_resolution_clock::now(); std::vector active_detectors; active_detectors.reserve(detectors.count()); @@ -563,13 +596,34 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } if (active_detectors.empty()) return result; + if 
(candidate_error_mark_epoch == std::numeric_limits::max()) { + std::fill(candidate_error_marks.begin(), candidate_error_marks.end(), 0); + candidate_error_mark_epoch = 1; + } + const uint64_t mark_epoch = candidate_error_mark_epoch++; + std::vector candidate_errors; + for (int detector : active_detectors) { + for (int ei : d2e[(size_t)detector]) { + if (blocked_flags[(size_t)ei]) continue; + if (candidate_error_marks[(size_t)ei] == mark_epoch) continue; + candidate_error_marks[(size_t)ei] = mark_epoch; + candidate_errors.push_back(ei); + } + } + const auto candidate_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_candidate_total_seconds += + std::chrono::duration_cast(candidate_stop_time - + candidate_start_time) + .count() / + 1e6; + + const auto union_start_time = std::chrono::high_resolution_clock::now(); UnionFind uf(active_detectors.size()); std::vector has_available(active_detectors.size(), 0); - for (size_t ei = 0; ei < num_errors; ++ei) { - if (blocked_flags[ei]) continue; + for (int ei : candidate_errors) { int first_active = -1; - for (int detector : edets[ei]) { + for (int detector : edets[(size_t)ei]) { const int active_pos = detector_to_active_pos[(size_t)detector]; if (active_pos < 0) continue; has_available[(size_t)active_pos] = 1; @@ -588,55 +642,40 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } } - std::vector> positions_by_root(active_detectors.size()); - for (int active_pos = 0; active_pos < (int)active_detectors.size(); ++active_pos) { - positions_by_root[(size_t)uf.find(active_pos)].push_back(active_pos); - } - - std::vector> component_positions; - component_positions.reserve(active_detectors.size()); - for (auto& positions : positions_by_root) { - if (positions.empty()) continue; - std::sort(positions.begin(), positions.end(), [&](int a, int b) { - return active_detectors[(size_t)a] < active_detectors[(size_t)b]; - }); - component_positions.push_back(std::move(positions)); - } - 
std::sort(component_positions.begin(), component_positions.end(), - [&](const auto& a, const auto& b) { - return active_detectors[(size_t)a[0]] < active_detectors[(size_t)b[0]]; - }); - + std::vector root_to_component_index(active_detectors.size(), -1); std::vector active_pos_to_component(active_detectors.size(), -1); std::vector active_pos_to_local(active_detectors.size(), -1); - - result.components.reserve(component_positions.size()); - for (size_t component_index = 0; component_index < component_positions.size(); - ++component_index) { - const auto& positions = component_positions[component_index]; - SingletonLPComponent component; - component.detectors.reserve(positions.size()); - for (size_t local = 0; local < positions.size(); ++local) { - const int active_pos = positions[local]; - active_pos_to_component[(size_t)active_pos] = (int)component_index; - active_pos_to_local[(size_t)active_pos] = (int)local; - component.detectors.push_back(active_detectors[(size_t)active_pos]); + result.components.reserve(active_detectors.size()); + for (int active_pos = 0; active_pos < (int)active_detectors.size(); ++active_pos) { + const int root = uf.find(active_pos); + int& component_index = root_to_component_index[(size_t)root]; + if (component_index < 0) { + component_index = (int)result.components.size(); + result.components.emplace_back(); } - result.components.push_back(std::move(component)); + auto& component = result.components[(size_t)component_index]; + active_pos_to_component[(size_t)active_pos] = component_index; + active_pos_to_local[(size_t)active_pos] = (int)component.detectors.size(); + component.detectors.push_back(active_detectors[(size_t)active_pos]); } - - std::vector, double, IntVectorHash>> min_rhs_by_pattern( + const auto union_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_union_total_seconds += + std::chrono::duration_cast(union_stop_time - union_start_time) + .count() / + 1e6; + + const auto dedup_start_time = 
std::chrono::high_resolution_clock::now(); + std::vector, double, IntVectorHash>>> + min_rhs_by_pattern( result.components.size()); std::vector local_hits; local_hits.reserve(16); - for (size_t ei = 0; ei < num_errors; ++ei) { - if (blocked_flags[ei]) continue; - + for (int ei : candidate_errors) { int component_index = -1; local_hits.clear(); - for (int detector : edets[ei]) { + for (int detector : edets[(size_t)ei]) { const int active_pos = detector_to_active_pos[(size_t)detector]; if (active_pos < 0) continue; if (component_index < 0) { @@ -648,19 +687,23 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } if (component_index < 0) continue; - + const double rhs = errors[(size_t)ei].likelihood_cost; auto& rhs_map = min_rhs_by_pattern[(size_t)component_index]; - const double rhs = errors[ei].likelihood_cost; auto it = rhs_map.find(local_hits); if (it == rhs_map.end() || rhs < it->second) { rhs_map[local_hits] = rhs; } } + const auto dedup_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_dedup_total_seconds += + std::chrono::duration_cast(dedup_stop_time - dedup_start_time) + .count() / + 1e6; + const auto finalize_start_time = std::chrono::high_resolution_clock::now(); for (size_t component_index = 0; component_index < result.components.size(); ++component_index) { auto& component = result.components[component_index]; const auto& rhs_map = min_rhs_by_pattern[component_index]; - component.constraints.reserve(rhs_map.size()); for (const auto& [local_hits, rhs] : rhs_map) { component.constraints.push_back({local_hits, rhs}); @@ -697,6 +740,11 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } } } + const auto finalize_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_finalize_total_seconds += + std::chrono::duration_cast(finalize_stop_time - finalize_start_time) + .count() / + 1e6; return result; } @@ -709,7 +757,20 @@ 
TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset const auto start_time = std::chrono::high_resolution_clock::now(); ExactSubsetSolution solution; - const auto build = build_singleton_components(detectors, blocked_flags); + std::vector ignored_blocked_flags; + const std::vector* effective_blocked_flags = &blocked_flags; + if (config.ignore_blocked_errors_in_heuristic) { + ignored_blocked_flags.assign(num_errors, 0); + effective_blocked_flags = &ignored_blocked_flags; + } + const auto build_start_time = std::chrono::high_resolution_clock::now(); + const auto build = build_singleton_components(detectors, *effective_blocked_flags); + const auto build_stop_time = std::chrono::high_resolution_clock::now(); + stats.component_build_calls++; + stats.component_build_total_seconds += + std::chrono::duration_cast(build_stop_time - build_start_time) + .count() / + 1e6; if (!build.feasible) { solution.value = INF_D; const auto stop_time = std::chrono::high_resolution_clock::now(); @@ -748,9 +809,17 @@ TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset } } } + const auto simplex_start_time = std::chrono::high_resolution_clock::now(); const auto component_result = solve_singleton_component_lp( component.detectors.size(), component.constraints, component.cheapest_constraint_for_local_detector, seed_budgets); + const auto simplex_stop_time = std::chrono::high_resolution_clock::now(); + stats.simplex_calls++; + stats.simplex_total_seconds += + std::chrono::duration_cast(simplex_stop_time - + simplex_start_time) + .count() / + 1e6; stats.lp_calls += component_result.simplex_solves; if (component_result.unbounded) { @@ -793,19 +862,33 @@ double TesseractFTLDecoder::project_from_exact_solution(const ExactSubsetSolutio const std::vector& blocked_flags) { stats.heuristic_calls++; stats.projection_heuristic_calls++; + const auto start_time = std::chrono::high_resolution_clock::now(); + stats.projection_calls++; double total = 
0.0; size_t budget_pos = 0; + const std::vector* effective_blocked_flags = &blocked_flags; + std::vector ignored_blocked_flags; + if (config.ignore_blocked_errors_in_heuristic) { + ignored_blocked_flags.assign(num_errors, 0); + effective_blocked_flags = &ignored_blocked_flags; + } for (size_t detector = detectors.find_first(); detector != boost::dynamic_bitset<>::npos; detector = detectors.find_next(detector)) { bool has_available = false; for (int ei : d2e[detector]) { - if (!blocked_flags[(size_t)ei]) { + if (!(*effective_blocked_flags)[(size_t)ei]) { has_available = true; break; } } - if (!has_available) return INF_D; + if (!has_available) { + const auto stop_time = std::chrono::high_resolution_clock::now(); + stats.projection_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / + 1e6; + return INF_D; + } while (budget_pos < solution.active_detectors.size() && solution.active_detectors[budget_pos] < (int)detector) { @@ -816,6 +899,9 @@ double TesseractFTLDecoder::project_from_exact_solution(const ExactSubsetSolutio total += solution.detector_budgets[budget_pos]; } } + const auto stop_time = std::chrono::high_resolution_clock::now(); + stats.projection_total_seconds += + std::chrono::duration_cast(stop_time - start_time).count() / 1e6; return total; } @@ -823,7 +909,9 @@ void TesseractFTLDecoder::reset_decode_state() { low_confidence_flag = false; predicted_errors_buffer.clear(); error_chain_arena.clear(); + detector_state_arena.clear(); exact_solution_arena.clear(); + exact_solution_cache.clear(); stats.clear(); } @@ -905,6 +993,7 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio if (config.pqlimit != std::numeric_limits::max()) { const size_t reserve_size = std::min(config.pqlimit, 5000000); error_chain_arena.reserve(reserve_size); + detector_state_arena.reserve(reserve_size + 1); exact_solution_arena.reserve(reserve_size / 4 + 1); } @@ -930,6 +1019,8 @@ void TesseractFTLDecoder::decode_to_errors(const 
std::vector& detectio root.num_dets = min_num_dets; root.depth = 0; root.error_chain_idx = -1; + detector_state_arena.push_back(initial_detectors); + root.detector_state_idx = 0; root.warm_solution_idx = -1; root.exact_solution_idx = -1; @@ -941,6 +1032,9 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio } exact_solution_arena.push_back(std::move(root_exact)); root.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; + if (config.ignore_blocked_errors_in_heuristic) { + exact_solution_cache.emplace(initial_detectors, root.exact_solution_idx); + } root.f_cost = exact_solution_arena.back().value; root.h_cost = exact_solution_arena.back().value; root.exact_refined = true; @@ -957,9 +1051,15 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio if (node.num_dets > max_num_dets) continue; - boost::dynamic_bitset<> detectors = initial_detectors; + boost::dynamic_bitset<> detectors = detector_state_arena[(size_t)node.detector_state_idx]; std::vector blocked_flags(num_errors, 0); - flip_detectors_and_block_errors(detector_order, node.error_chain_idx, detectors, blocked_flags); + const auto chain_start_time = std::chrono::high_resolution_clock::now(); + block_errors_from_chain(error_chain_arena, d2e, node.error_chain_idx, blocked_flags); + const auto chain_stop_time = std::chrono::high_resolution_clock::now(); + stats.chain_replay_total_seconds += + std::chrono::duration_cast(chain_stop_time - chain_start_time) + .count() / + 1e6; if (config.verbose) { const size_t projected_unrefined = @@ -1008,42 +1108,78 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio if (!node.exact_refined) { const double prev_h = node.h_cost; const FTLHeuristicSource prev_source = node.heuristic_source; - ExactSubsetSolution exact_solution = - solve_exact_subset_lp(detectors, blocked_flags, node.warm_solution_idx); + bool used_cached_exact_solution = false; + int64_t cached_exact_solution_idx = -1; + if 
(config.ignore_blocked_errors_in_heuristic) { + auto it = exact_solution_cache.find(detectors); + if (it != exact_solution_cache.end()) { + used_cached_exact_solution = true; + cached_exact_solution_idx = it->second; + } + } if (prev_source == FTLHeuristicSource::kProjected) stats.projected_nodes_refined++; - if (exact_solution.value == INF_D) { + if (used_cached_exact_solution) { + node.exact_solution_idx = cached_exact_solution_idx; + node.h_cost = exact_solution_arena[(size_t)cached_exact_solution_idx].value; + const double delta = node.h_cost - prev_h; + if (node.h_cost + 1e-7 < prev_h) { + throw std::runtime_error("Cached singleton lower bound fell below stored lower bound."); + } + stats.total_lp_refinement_gain += delta; + stats.max_lp_refinement_gain = std::max(stats.max_lp_refinement_gain, delta); + node.f_cost = node.g_cost + node.h_cost; + node.exact_refined = true; + node.heuristic_source = FTLHeuristicSource::kExact; + if (delta > HEURISTIC_EPS) { + stats.lp_reinserts++; + pq.push(node); + stats.num_pq_pushed++; + if (stats.num_pq_pushed > config.pqlimit) { + low_confidence_flag = true; + return; + } + continue; + } + } else { + ExactSubsetSolution exact_solution = + solve_exact_subset_lp(detectors, blocked_flags, node.warm_solution_idx); + if (exact_solution.value == INF_D) { + if (config.verbose) { + std::cout << " lp_refine exact_h=INF discarded=true" << std::endl; + } + continue; + } + if (exact_solution.value + 1e-7 < prev_h) { + throw std::runtime_error("Exact singleton lower bound fell below stored lower bound."); + } + const double delta = exact_solution.value - prev_h; + stats.total_lp_refinement_gain += delta; + stats.max_lp_refinement_gain = std::max(stats.max_lp_refinement_gain, delta); + exact_solution_arena.push_back(std::move(exact_solution)); + node.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; + if (config.ignore_blocked_errors_in_heuristic) { + exact_solution_cache.emplace(detectors, node.exact_solution_idx); + } + 
node.h_cost = exact_solution_arena.back().value; + node.f_cost = node.g_cost + node.h_cost; + node.exact_refined = true; + node.heuristic_source = FTLHeuristicSource::kExact; if (config.verbose) { - std::cout << " lp_refine exact_h=INF discarded=true" << std::endl; + std::cout << " lp_refine approx_h=" << prev_h << " exact_h=" << node.h_cost + << " delta=" << delta << " vars=" << exact_solution_arena.back().num_variables + << " constraints=" << exact_solution_arena.back().num_constraints + << " reinserted=" << (delta > HEURISTIC_EPS) << std::endl; } - continue; - } - if (exact_solution.value + 1e-7 < prev_h) { - throw std::runtime_error("Exact singleton lower bound fell below stored lower bound."); - } - const double delta = exact_solution.value - prev_h; - stats.total_lp_refinement_gain += delta; - stats.max_lp_refinement_gain = std::max(stats.max_lp_refinement_gain, delta); - exact_solution_arena.push_back(std::move(exact_solution)); - node.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; - node.h_cost = exact_solution_arena.back().value; - node.f_cost = node.g_cost + node.h_cost; - node.exact_refined = true; - node.heuristic_source = FTLHeuristicSource::kExact; - if (config.verbose) { - std::cout << " lp_refine approx_h=" << prev_h << " exact_h=" << node.h_cost - << " delta=" << delta << " vars=" << exact_solution_arena.back().num_variables - << " constraints=" << exact_solution_arena.back().num_constraints - << " reinserted=" << (delta > HEURISTIC_EPS) << std::endl; - } - if (delta > HEURISTIC_EPS) { - stats.lp_reinserts++; - pq.push(node); - stats.num_pq_pushed++; - if (stats.num_pq_pushed > config.pqlimit) { - low_confidence_flag = true; - return; + if (delta > HEURISTIC_EPS) { + stats.lp_reinserts++; + pq.push(node); + stats.num_pq_pushed++; + if (stats.num_pq_pushed > config.pqlimit) { + low_confidence_flag = true; + return; + } + continue; } - continue; } } @@ -1108,6 +1244,8 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& 
detectio child.num_dets = child_num_dets; child.depth = node.depth + 1; child.error_chain_idx = (int64_t)error_chain_arena.size() - 1; + detector_state_arena.push_back(std::move(child_detectors)); + child.detector_state_idx = (int64_t)detector_state_arena.size() - 1; child.warm_solution_idx = node.exact_solution_idx; child.exact_solution_idx = -1; child.exact_refined = false; diff --git a/src/tesseract_ftl.h b/src/tesseract_ftl.h index ec4971c..5e5f74a 100644 --- a/src/tesseract_ftl.h +++ b/src/tesseract_ftl.h @@ -46,6 +46,7 @@ struct TesseractFTLConfig { std::vector> det_orders; double det_penalty = 0; bool create_visualization = false; + bool ignore_blocked_errors_in_heuristic = false; // 0 = delegate to the original Tesseract detcost heuristic. // 1 = use the singleton fractional lower bound implemented in this file. @@ -72,6 +73,17 @@ struct TesseractFTLStats { double total_lp_refinement_gain = 0.0; double max_lp_refinement_gain = 0.0; double lp_total_seconds = 0.0; + double chain_replay_total_seconds = 0.0; + double component_build_total_seconds = 0.0; + double component_candidate_total_seconds = 0.0; + double component_union_total_seconds = 0.0; + double component_dedup_total_seconds = 0.0; + double component_finalize_total_seconds = 0.0; + double simplex_total_seconds = 0.0; + double projection_total_seconds = 0.0; + size_t component_build_calls = 0; + size_t simplex_calls = 0; + size_t projection_calls = 0; void clear(); void accumulate(const TesseractFTLStats& other); @@ -132,6 +144,7 @@ struct TesseractFTLDecoder { size_t num_dets = 0; size_t depth = 0; int64_t error_chain_idx = -1; + int64_t detector_state_idx = -1; int64_t warm_solution_idx = -1; int64_t exact_solution_idx = -1; bool exact_refined = false; @@ -169,7 +182,11 @@ struct TesseractFTLDecoder { size_t num_errors = 0; std::vector error_costs; std::vector error_chain_arena; + std::vector> detector_state_arena; std::vector exact_solution_arena; + std::unordered_map, int64_t, DynamicBitsetHash> 
exact_solution_cache; + mutable std::vector candidate_error_marks; + mutable uint64_t candidate_error_mark_epoch = 1; // If subset_detcost_size == 0, delegate to the original Tesseract decoder. std::unique_ptr plain_delegate; @@ -181,7 +198,7 @@ struct TesseractFTLDecoder { std::vector& blocked_flags) const; SingletonBuildResult build_singleton_components(const boost::dynamic_bitset<>& detectors, - const std::vector& blocked_flags) const; + const std::vector& blocked_flags); ExactSubsetSolution solve_exact_subset_lp(const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags, diff --git a/src/tesseract_ftl_main.cc b/src/tesseract_ftl_main.cc index bf129e1..5517948 100644 --- a/src/tesseract_ftl_main.cc +++ b/src/tesseract_ftl_main.cc @@ -65,6 +65,7 @@ struct Args { size_t pqlimit; size_t subset_detcost_size = 0; + bool ignore_blocked_errors_in_heuristic = false; bool verbose = false; bool print_stats = false; @@ -151,6 +152,7 @@ struct Args { config.merge_errors = !no_merge_errors; config.subset_detcost_size = subset_detcost_size; + config.ignore_blocked_errors_in_heuristic = ignore_blocked_errors_in_heuristic; { DetOrder order = DetOrder::DetBFS; @@ -257,6 +259,10 @@ int main(int argc, char* argv[]) { .help("0 = plain detcost delegate, 1 = singleton fractional lower bound") .default_value(size_t(0)) .store_into(args.subset_detcost_size); + program.add_argument("--ignore-blocked-errors-in-heuristic") + .help("Experimental: ignore precedence-blocked errors when computing the FTL LP heuristic") + .flag() + .store_into(args.ignore_blocked_errors_in_heuristic); program.add_argument("--num-det-orders") .help("Number of ways to orient the manifold when reordering the detectors") @@ -445,6 +451,7 @@ int main(int argc, char* argv[]) { {"num_det_orders", args.num_det_orders}, {"det_order_seed", args.det_order_seed}, {"subset_detcost_size", args.subset_detcost_size}, + {"ignore_blocked_errors_in_heuristic", args.ignore_blocked_errors_in_heuristic}, 
{"total_time_seconds", total_time_seconds}, {"num_errors", num_errors}, {"num_low_confidence", num_low_confidence}, @@ -465,6 +472,19 @@ int main(int argc, char* argv[]) { {"ftl_total_lp_refinement_gain", decoder_stats_total.total_lp_refinement_gain}, {"ftl_max_lp_refinement_gain", decoder_stats_total.max_lp_refinement_gain}, {"ftl_lp_total_seconds", decoder_stats_total.lp_total_seconds}, + {"ftl_chain_replay_total_seconds", decoder_stats_total.chain_replay_total_seconds}, + {"ftl_component_build_total_seconds", decoder_stats_total.component_build_total_seconds}, + {"ftl_component_candidate_total_seconds", + decoder_stats_total.component_candidate_total_seconds}, + {"ftl_component_union_total_seconds", decoder_stats_total.component_union_total_seconds}, + {"ftl_component_dedup_total_seconds", decoder_stats_total.component_dedup_total_seconds}, + {"ftl_component_finalize_total_seconds", + decoder_stats_total.component_finalize_total_seconds}, + {"ftl_simplex_total_seconds", decoder_stats_total.simplex_total_seconds}, + {"ftl_projection_total_seconds", decoder_stats_total.projection_total_seconds}, + {"ftl_component_build_calls", decoder_stats_total.component_build_calls}, + {"ftl_simplex_calls", decoder_stats_total.simplex_calls}, + {"ftl_projection_calls", decoder_stats_total.projection_calls}, }; if (args.stats_out_fname == "-") { From cf0b14ed47f0fecfc844d4f3618aad5fb2cc3113 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Tue, 31 Mar 2026 08:38:26 -0700 Subject: [PATCH 08/25] add several CLI flags: --num-min-dets-to-consider --detector-choice-policy order, fewest_incident_errors, --error-order-policy, --root-det-order-count, --root-det-order-depth, --exact-child-refine-count --- src/tesseract_ftl.cc | 332 +++++++++++++++++++++++++++++++------- src/tesseract_ftl.h | 42 +++++ src/tesseract_ftl_main.cc | 93 +++++++++++ 3 files changed, 407 insertions(+), 60 deletions(-) diff --git a/src/tesseract_ftl.cc b/src/tesseract_ftl.cc index 3c927b8..9b6f884 100644 --- 
a/src/tesseract_ftl.cc +++ b/src/tesseract_ftl.cc @@ -389,6 +389,30 @@ std::string heuristic_source_to_string(FTLHeuristicSource source) { return "unknown"; } +std::string detector_choice_policy_to_string(FTLDetectorChoicePolicy policy) { + switch (policy) { + case FTLDetectorChoicePolicy::kOrder: + return "order"; + case FTLDetectorChoicePolicy::kFewestIncidentErrors: + return "fewest_incident_errors"; + case FTLDetectorChoicePolicy::kLargestBudget: + return "largest_budget"; + case FTLDetectorChoicePolicy::kLargestBudgetPerIncident: + return "largest_budget_per_incident"; + } + return "unknown"; +} + +std::string error_order_policy_to_string(FTLErrorOrderPolicy policy) { + switch (policy) { + case FTLErrorOrderPolicy::kStatic: + return "static"; + case FTLErrorOrderPolicy::kReducedCost: + return "reduced_cost"; + } + return "unknown"; +} + } // namespace std::string TesseractFTLConfig::str() { @@ -404,7 +428,14 @@ std::string TesseractFTLConfig::str() { ss << "det_penalty=" << det_penalty << ", "; ss << "create_visualization=" << create_visualization << ", "; ss << "subset_detcost_size=" << subset_detcost_size << ", "; - ss << "ignore_blocked_errors_in_heuristic=" << ignore_blocked_errors_in_heuristic; + ss << "ignore_blocked_errors_in_heuristic=" << ignore_blocked_errors_in_heuristic << ", "; + ss << "num_min_dets_to_consider=" << num_min_dets_to_consider << ", "; + ss << "detector_choice_policy=" + << detector_choice_policy_to_string(detector_choice_policy) << ", "; + ss << "error_order_policy=" << error_order_policy_to_string(error_order_policy) << ", "; + ss << "root_det_order_count=" << root_det_order_count << ", "; + ss << "root_det_order_depth=" << root_det_order_depth << ", "; + ss << "exact_child_refine_count=" << exact_child_refine_count; ss << ")"; return ss.str(); } @@ -439,6 +470,20 @@ void TesseractFTLStats::accumulate(const TesseractFTLStats& other) { component_build_calls += other.component_build_calls; simplex_calls += other.simplex_calls; 
projection_calls += other.projection_calls; + detector_choice_calls += other.detector_choice_calls; + error_ordering_calls += other.error_ordering_calls; + total_active_detectors_popped += other.total_active_detectors_popped; + total_root_order_candidates += other.total_root_order_candidates; + total_min_detector_candidates += other.total_min_detector_candidates; + total_min_detectors_selected += other.total_min_detectors_selected; + total_min_detector_available_errors += other.total_min_detector_available_errors; + total_min_detector_blocked_errors += other.total_min_detector_blocked_errors; + total_child_candidates_considered += other.total_child_candidates_considered; + total_children_generated += other.total_children_generated; + total_children_beam_pruned += other.total_children_beam_pruned; + total_children_infeasible += other.total_children_infeasible; + total_selected_min_detector_budget += other.total_selected_min_detector_budget; + exact_child_pre_refinements += other.exact_child_pre_refinements; } bool TesseractFTLDecoder::FTLNode::operator>(const FTLNode& other) const { @@ -905,6 +950,148 @@ double TesseractFTLDecoder::project_from_exact_solution(const ExactSubsetSolutio return total; } +std::vector TesseractFTLDecoder::select_min_detectors( + const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags, + size_t detector_order, size_t depth, const ExactSubsetSolution& exact_solution) { + stats.detector_choice_calls++; + stats.total_active_detectors_popped += detectors.count(); + + struct CandidateDetector { + size_t detector; + size_t order_rank; + size_t available_errors; + double budget; + }; + + const size_t order_count = + depth < config.root_det_order_depth ? 
std::min(config.root_det_order_count, config.det_orders.size()) : 1; + std::vector seen(num_detectors, 0); + std::vector candidates; + candidates.reserve(detectors.count()); + + size_t discovery_rank = 0; + for (size_t order_offset = 0; order_offset < order_count; ++order_offset) { + size_t taken_from_order = 0; + const size_t order_index = (detector_order + order_offset) % config.det_orders.size(); + for (size_t offset = 0; offset < num_detectors; ++offset) { + const size_t detector = config.det_orders[order_index][offset]; + if (!detectors[detector]) continue; + if (!seen[detector]) { + seen[detector] = 1; + size_t available_errors = 0; + for (int ei : d2e[detector]) { + if (!blocked_flags[(size_t)ei]) { + available_errors++; + } + } + candidates.push_back({detector, discovery_rank++, available_errors, + lookup_detector_budget(exact_solution, (int)detector)}); + } + taken_from_order++; + if (config.detector_choice_policy == FTLDetectorChoicePolicy::kOrder && + taken_from_order >= config.num_min_dets_to_consider) { + break; + } + } + } + + stats.total_root_order_candidates += candidates.size(); + stats.total_min_detector_candidates += candidates.size(); + + if (config.detector_choice_policy != FTLDetectorChoicePolicy::kOrder) { + std::stable_sort(candidates.begin(), candidates.end(), [&](const auto& a, const auto& b) { + switch (config.detector_choice_policy) { + case FTLDetectorChoicePolicy::kOrder: + break; + case FTLDetectorChoicePolicy::kFewestIncidentErrors: + if (a.available_errors != b.available_errors) { + return a.available_errors < b.available_errors; + } + break; + case FTLDetectorChoicePolicy::kLargestBudget: + if (a.budget != b.budget) return a.budget > b.budget; + break; + case FTLDetectorChoicePolicy::kLargestBudgetPerIncident: { + const double a_score = + a.available_errors == 0 ? INF_D : a.budget / (double)a.available_errors; + const double b_score = + b.available_errors == 0 ? 
INF_D : b.budget / (double)b.available_errors; + if (a_score != b_score) return a_score > b_score; + break; + } + } + if (a.order_rank != b.order_rank) return a.order_rank < b.order_rank; + return a.detector < b.detector; + }); + } + + std::vector selected; + selected.reserve(std::min(config.num_min_dets_to_consider, candidates.size())); + for (const auto& candidate : candidates) { + selected.push_back(candidate.detector); + stats.total_min_detectors_selected++; + stats.total_min_detector_available_errors += candidate.available_errors; + stats.total_selected_min_detector_budget += candidate.budget; + if (selected.size() >= config.num_min_dets_to_consider) break; + } + return selected; +} + +std::vector TesseractFTLDecoder::order_candidate_errors( + size_t min_detector, const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, const ExactSubsetSolution& exact_solution) { + stats.error_ordering_calls++; + + std::vector ordered_errors; + ordered_errors.reserve(d2e[min_detector].size()); + + if (config.error_order_policy == FTLErrorOrderPolicy::kStatic) { + for (int ei : d2e[min_detector]) { + if (blocked_flags[(size_t)ei]) { + stats.total_min_detector_blocked_errors++; + continue; + } + ordered_errors.push_back(ei); + } + return ordered_errors; + } + + struct CandidateError { + int error_index; + size_t order_rank; + double reduced_cost; + int net_det_delta; + }; + std::vector candidates; + candidates.reserve(d2e[min_detector].size()); + size_t order_rank = 0; + for (int ei : d2e[min_detector]) { + if (blocked_flags[(size_t)ei]) { + stats.total_min_detector_blocked_errors++; + continue; + } + double covered_budget = 0.0; + int net_det_delta = 0; + for (int detector : edets[(size_t)ei]) { + if (detectors[(size_t)detector]) { + covered_budget += lookup_detector_budget(exact_solution, detector); + net_det_delta--; + } else { + net_det_delta++; + } + } + candidates.push_back( + {ei, order_rank++, errors[(size_t)ei].likelihood_cost - covered_budget, 
net_det_delta}); + } + std::stable_sort(candidates.begin(), candidates.end(), [&](const auto& a, const auto& b) { + if (a.reduced_cost != b.reduced_cost) return a.reduced_cost < b.reduced_cost; + if (a.net_det_delta != b.net_det_delta) return a.net_det_delta < b.net_det_delta; + return a.order_rank < b.order_rank; + }); + for (const auto& candidate : candidates) ordered_errors.push_back(candidate.error_index); + return ordered_errors; +} + void TesseractFTLDecoder::reset_decode_state() { low_confidence_flag = false; predicted_errors_buffer.clear(); @@ -1187,75 +1374,100 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio continue; } - size_t min_detector = std::numeric_limits::max(); - for (size_t offset = 0; offset < num_detectors; ++offset) { - const size_t detector = config.det_orders[detector_order][offset]; - if (detectors[detector]) { - min_detector = detector; - break; - } + const auto& exact_solution = exact_solution_arena[(size_t)node.exact_solution_idx]; + std::vector min_detectors = + select_min_detectors(detectors, blocked_flags, detector_order, node.depth, exact_solution); + if (min_detectors.empty()) { + throw std::runtime_error("Failed to select an active min detector for a non-terminal node."); } - std::vector prefix_blocked = blocked_flags; - size_t children_generated = 0; size_t children_projected = 0; size_t children_beam_pruned = 0; size_t children_infeasible = 0; + size_t children_exactly_refined = 0; + + for (size_t min_detector : min_detectors) { + std::vector prefix_blocked = blocked_flags; + const std::vector ordered_errors = + order_candidate_errors(min_detector, detectors, blocked_flags, exact_solution); + for (int ei : ordered_errors) { + prefix_blocked[(size_t)ei] = 1; + stats.total_child_candidates_considered++; + + boost::dynamic_bitset<> child_detectors = detectors; + size_t child_num_dets = node.num_dets; + for (int detector : edets[(size_t)ei]) { + if (detectors[(size_t)detector]) { + --child_num_dets; + } 
else { + ++child_num_dets; + } + child_detectors.flip((size_t)detector); + } + if (child_num_dets > max_num_dets) { + children_beam_pruned++; + stats.total_children_beam_pruned++; + continue; + } - for (int ei : d2e[min_detector]) { - prefix_blocked[(size_t)ei] = 1; - if (blocked_flags[(size_t)ei]) continue; - - boost::dynamic_bitset<> child_detectors = detectors; - size_t child_num_dets = node.num_dets; - for (int detector : edets[(size_t)ei]) { - if (detectors[(size_t)detector]) { - --child_num_dets; - } else { - ++child_num_dets; + double child_h = project_from_exact_solution(exact_solution, child_detectors, prefix_blocked); + stats.projected_nodes_generated++; + children_projected++; + if (child_h == INF_D) { + children_infeasible++; + stats.total_children_infeasible++; + continue; } - child_detectors.flip((size_t)detector); - } - if (child_num_dets > max_num_dets) { - children_beam_pruned++; - continue; - } - const double child_h = project_from_exact_solution( - exact_solution_arena[(size_t)node.exact_solution_idx], child_detectors, prefix_blocked); - stats.projected_nodes_generated++; - children_projected++; - if (child_h == INF_D) { - children_infeasible++; - continue; - } + error_chain_arena.emplace_back(); + auto& chain_node = error_chain_arena.back(); + chain_node.error_index = (size_t)ei; + chain_node.min_detector = min_detector; + chain_node.parent_idx = node.error_chain_idx; + + FTLNode child; + child.g_cost = node.g_cost + errors[(size_t)ei].likelihood_cost; + child.h_cost = child_h; + child.f_cost = child.g_cost + child.h_cost; + child.num_dets = child_num_dets; + child.depth = node.depth + 1; + child.error_chain_idx = (int64_t)error_chain_arena.size() - 1; + detector_state_arena.push_back(std::move(child_detectors)); + child.detector_state_idx = (int64_t)detector_state_arena.size() - 1; + child.warm_solution_idx = node.exact_solution_idx; + child.exact_solution_idx = -1; + child.exact_refined = false; + child.heuristic_source = 
FTLHeuristicSource::kProjected; + + if (config.exact_child_refine_count > 0 && + children_exactly_refined < config.exact_child_refine_count) { + ExactSubsetSolution child_exact = solve_exact_subset_lp( + detector_state_arena[(size_t)child.detector_state_idx], prefix_blocked, + child.warm_solution_idx); + if (child_exact.value == INF_D) { + children_infeasible++; + stats.total_children_infeasible++; + continue; + } + exact_solution_arena.push_back(std::move(child_exact)); + child.exact_solution_idx = (int64_t)exact_solution_arena.size() - 1; + child.h_cost = exact_solution_arena.back().value; + child.f_cost = child.g_cost + child.h_cost; + child.exact_refined = true; + child.heuristic_source = FTLHeuristicSource::kExact; + children_exactly_refined++; + stats.exact_child_pre_refinements++; + } - error_chain_arena.emplace_back(); - auto& chain_node = error_chain_arena.back(); - chain_node.error_index = (size_t)ei; - chain_node.min_detector = min_detector; - chain_node.parent_idx = node.error_chain_idx; - - FTLNode child; - child.g_cost = node.g_cost + errors[(size_t)ei].likelihood_cost; - child.h_cost = child_h; - child.f_cost = child.g_cost + child.h_cost; - child.num_dets = child_num_dets; - child.depth = node.depth + 1; - child.error_chain_idx = (int64_t)error_chain_arena.size() - 1; - detector_state_arena.push_back(std::move(child_detectors)); - child.detector_state_idx = (int64_t)detector_state_arena.size() - 1; - child.warm_solution_idx = node.exact_solution_idx; - child.exact_solution_idx = -1; - child.exact_refined = false; - child.heuristic_source = FTLHeuristicSource::kProjected; - pq.push(child); - stats.num_pq_pushed++; - children_generated++; - if (stats.num_pq_pushed > config.pqlimit) { - low_confidence_flag = true; - return; + pq.push(child); + stats.num_pq_pushed++; + children_generated++; + stats.total_children_generated++; + if (stats.num_pq_pushed > config.pqlimit) { + low_confidence_flag = true; + return; + } } } diff --git a/src/tesseract_ftl.h 
b/src/tesseract_ftl.h index 5e5f74a..6df0373 100644 --- a/src/tesseract_ftl.h +++ b/src/tesseract_ftl.h @@ -34,6 +34,18 @@ constexpr size_t DEFAULT_FTL_SUBSET_DETCOST_SIZE = 0; +enum class FTLDetectorChoicePolicy : uint8_t { + kOrder = 0, + kFewestIncidentErrors = 1, + kLargestBudget = 2, + kLargestBudgetPerIncident = 3, +}; + +enum class FTLErrorOrderPolicy : uint8_t { + kStatic = 0, + kReducedCost = 1, +}; + struct TesseractFTLConfig { stim::DetectorErrorModel dem; int det_beam = DEFAULT_DET_BEAM; @@ -47,6 +59,12 @@ struct TesseractFTLConfig { double det_penalty = 0; bool create_visualization = false; bool ignore_blocked_errors_in_heuristic = false; + size_t num_min_dets_to_consider = 1; + FTLDetectorChoicePolicy detector_choice_policy = FTLDetectorChoicePolicy::kOrder; + FTLErrorOrderPolicy error_order_policy = FTLErrorOrderPolicy::kStatic; + size_t root_det_order_count = 1; + size_t root_det_order_depth = 0; + size_t exact_child_refine_count = 0; // 0 = delegate to the original Tesseract detcost heuristic. // 1 = use the singleton fractional lower bound implemented in this file. 
@@ -84,6 +102,20 @@ struct TesseractFTLStats { size_t component_build_calls = 0; size_t simplex_calls = 0; size_t projection_calls = 0; + size_t detector_choice_calls = 0; + size_t error_ordering_calls = 0; + size_t total_active_detectors_popped = 0; + size_t total_root_order_candidates = 0; + size_t total_min_detector_candidates = 0; + size_t total_min_detectors_selected = 0; + size_t total_min_detector_available_errors = 0; + size_t total_min_detector_blocked_errors = 0; + size_t total_child_candidates_considered = 0; + size_t total_children_generated = 0; + size_t total_children_beam_pruned = 0; + size_t total_children_infeasible = 0; + double total_selected_min_detector_budget = 0.0; + size_t exact_child_pre_refinements = 0; void clear(); void accumulate(const TesseractFTLStats& other); @@ -208,6 +240,16 @@ struct TesseractFTLDecoder { const boost::dynamic_bitset<>& detectors, const std::vector& blocked_flags); + std::vector select_min_detectors(const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, + size_t detector_order, size_t depth, + const ExactSubsetSolution& exact_solution); + + std::vector order_candidate_errors(size_t min_detector, + const boost::dynamic_bitset<>& detectors, + const std::vector& blocked_flags, + const ExactSubsetSolution& exact_solution); + void reset_decode_state(); }; diff --git a/src/tesseract_ftl_main.cc b/src/tesseract_ftl_main.cc index 5517948..9ca6fc0 100644 --- a/src/tesseract_ftl_main.cc +++ b/src/tesseract_ftl_main.cc @@ -27,6 +27,26 @@ #include "tesseract_ftl.h" #include "utils.h" +namespace { + +FTLDetectorChoicePolicy parse_detector_choice_policy(const std::string& value) { + if (value == "order") return FTLDetectorChoicePolicy::kOrder; + if (value == "fewest_incident_errors") return FTLDetectorChoicePolicy::kFewestIncidentErrors; + if (value == "largest_budget") return FTLDetectorChoicePolicy::kLargestBudget; + if (value == "largest_budget_per_incident") { + return 
FTLDetectorChoicePolicy::kLargestBudgetPerIncident; + } + throw std::invalid_argument("Unknown detector choice policy: " + value); +} + +FTLErrorOrderPolicy parse_error_order_policy(const std::string& value) { + if (value == "static") return FTLErrorOrderPolicy::kStatic; + if (value == "reduced_cost") return FTLErrorOrderPolicy::kReducedCost; + throw std::invalid_argument("Unknown error order policy: " + value); +} + +} // namespace + struct Args { std::string circuit_path; std::string dem_path; @@ -66,6 +86,12 @@ struct Args { size_t subset_detcost_size = 0; bool ignore_blocked_errors_in_heuristic = false; + size_t num_min_dets_to_consider = 1; + std::string detector_choice_policy = "order"; + std::string error_order_policy = "static"; + size_t root_det_order_count = 1; + size_t root_det_order_depth = 0; + size_t exact_child_refine_count = 0; bool verbose = false; bool print_stats = false; @@ -123,6 +149,14 @@ struct Args { if (subset_detcost_size > 1) { throw std::invalid_argument("This prototype currently supports --subset-detcost-size <= 1"); } + if (num_min_dets_to_consider == 0) { + throw std::invalid_argument("--num-min-dets-to-consider must be at least 1"); + } + if (root_det_order_count == 0) { + throw std::invalid_argument("--root-det-order-count must be at least 1"); + } + parse_detector_choice_policy(detector_choice_policy); + parse_error_order_policy(error_order_policy); } void extract(TesseractFTLConfig& config, std::vector& shots, @@ -153,6 +187,12 @@ struct Args { config.merge_errors = !no_merge_errors; config.subset_detcost_size = subset_detcost_size; config.ignore_blocked_errors_in_heuristic = ignore_blocked_errors_in_heuristic; + config.num_min_dets_to_consider = num_min_dets_to_consider; + config.detector_choice_policy = parse_detector_choice_policy(detector_choice_policy); + config.error_order_policy = parse_error_order_policy(error_order_policy); + config.root_det_order_count = root_det_order_count; + config.root_det_order_depth = 
root_det_order_depth; + config.exact_child_refine_count = exact_child_refine_count; { DetOrder order = DetOrder::DetBFS; @@ -263,6 +303,33 @@ int main(int argc, char* argv[]) { .help("Experimental: ignore precedence-blocked errors when computing the FTL LP heuristic") .flag() .store_into(args.ignore_blocked_errors_in_heuristic); + program.add_argument("--num-min-dets-to-consider") + .help("Experimental: when expanding a node, branch on the first N active detectors in the " + "selected detector order.") + .default_value(size_t(1)) + .store_into(args.num_min_dets_to_consider); + program.add_argument("--detector-choice-policy") + .help("Experimental detector pivot policy: order, fewest_incident_errors, " + "largest_budget, or largest_budget_per_incident.") + .default_value(std::string("order")) + .store_into(args.detector_choice_policy); + program.add_argument("--error-order-policy") + .help("Experimental sibling ordering policy: static or reduced_cost.") + .default_value(std::string("static")) + .store_into(args.error_order_policy); + program.add_argument("--root-det-order-count") + .help("Experimental: at shallow depths, union candidates from the first N detector orders.") + .default_value(size_t(1)) + .store_into(args.root_det_order_count); + program.add_argument("--root-det-order-depth") + .help("Experimental: use root-det-order-count while node depth is less than this value.") + .default_value(size_t(0)) + .store_into(args.root_det_order_depth); + program.add_argument("--exact-child-refine-count") + .help("Experimental exact mode: immediately LP-refine the first N generated children per " + "expanded node.") + .default_value(size_t(0)) + .store_into(args.exact_child_refine_count); program.add_argument("--num-det-orders") .help("Number of ways to orient the manifold when reordering the detectors") @@ -452,6 +519,12 @@ int main(int argc, char* argv[]) { {"det_order_seed", args.det_order_seed}, {"subset_detcost_size", args.subset_detcost_size}, 
{"ignore_blocked_errors_in_heuristic", args.ignore_blocked_errors_in_heuristic}, + {"num_min_dets_to_consider", args.num_min_dets_to_consider}, + {"detector_choice_policy", args.detector_choice_policy}, + {"error_order_policy", args.error_order_policy}, + {"root_det_order_count", args.root_det_order_count}, + {"root_det_order_depth", args.root_det_order_depth}, + {"exact_child_refine_count", args.exact_child_refine_count}, {"total_time_seconds", total_time_seconds}, {"num_errors", num_errors}, {"num_low_confidence", num_low_confidence}, @@ -485,6 +558,24 @@ int main(int argc, char* argv[]) { {"ftl_component_build_calls", decoder_stats_total.component_build_calls}, {"ftl_simplex_calls", decoder_stats_total.simplex_calls}, {"ftl_projection_calls", decoder_stats_total.projection_calls}, + {"ftl_detector_choice_calls", decoder_stats_total.detector_choice_calls}, + {"ftl_error_ordering_calls", decoder_stats_total.error_ordering_calls}, + {"ftl_total_active_detectors_popped", decoder_stats_total.total_active_detectors_popped}, + {"ftl_total_root_order_candidates", decoder_stats_total.total_root_order_candidates}, + {"ftl_total_min_detector_candidates", decoder_stats_total.total_min_detector_candidates}, + {"ftl_total_min_detectors_selected", decoder_stats_total.total_min_detectors_selected}, + {"ftl_total_min_detector_available_errors", + decoder_stats_total.total_min_detector_available_errors}, + {"ftl_total_min_detector_blocked_errors", + decoder_stats_total.total_min_detector_blocked_errors}, + {"ftl_total_child_candidates_considered", + decoder_stats_total.total_child_candidates_considered}, + {"ftl_total_children_generated", decoder_stats_total.total_children_generated}, + {"ftl_total_children_beam_pruned", decoder_stats_total.total_children_beam_pruned}, + {"ftl_total_children_infeasible", decoder_stats_total.total_children_infeasible}, + {"ftl_total_selected_min_detector_budget", + decoder_stats_total.total_selected_min_detector_budget}, + 
{"ftl_exact_child_pre_refinements", decoder_stats_total.exact_child_pre_refinements}, }; if (args.stats_out_fname == "-") { @@ -506,6 +597,8 @@ int main(int argc, char* argv[]) { std::cout << " lp_reinserts = " << decoder_stats_total.lp_reinserts; std::cout << " projected_nodes_generated = " << decoder_stats_total.projected_nodes_generated; std::cout << " projected_nodes_refined = " << decoder_stats_total.projected_nodes_refined; + std::cout << " child_candidates = " << decoder_stats_total.total_child_candidates_considered; + std::cout << " children_generated = " << decoder_stats_total.total_children_generated; } std::cout << std::endl; } From 9e98359743dda80f48b0ffea4d246051eb5d76f3 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Fri, 3 Apr 2026 20:29:00 +0800 Subject: [PATCH 09/25] add cool prototype trellis beam decoder --- src/py/astar/trellis_beam.py | 151 +++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 src/py/astar/trellis_beam.py diff --git a/src/py/astar/trellis_beam.py b/src/py/astar/trellis_beam.py new file mode 100644 index 0000000..b4bbed1 --- /dev/null +++ b/src/py/astar/trellis_beam.py @@ -0,0 +1,151 @@ +import stim +from collections import defaultdict +import sys + + +def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> bool | None: + """ + Decodes a syndrome using a dynamic programming sweep with a Top-L beam cutoff. + """ + # 1. Extract the Detector Error Model (flattened, decompose_errors=False) + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + # 2. 
Parse the DEM into a list of faults + faults = [] + all_possible_dets_mask = 0 + + for inst in dem: + if inst.type != "error": + continue + + p = inst.args_copy()[0] + det_mask = 0 + flip_l0 = 0 + + for t in inst.targets_copy(): + if t.is_separator(): + continue + if t.is_relative_detector_id(): + det_mask ^= (1 << t.val) + elif t.is_logical_observable_id() and t.val == 0: + flip_l0 ^= 1 + + faults.append((p, det_mask, flip_l0)) + all_possible_dets_mask |= det_mask + + # 3. Convert observed syndrome set to an integer bitmask + actual_dets_mask = 0 + for d in actual_dets: + actual_dets_mask ^= (1 << d) + + # If the quantum computer triggered a detector that our error model says + # is mathematically impossible to trigger, the syndrome is invalid. + if (actual_dets_mask & ~all_possible_dets_mask) != 0: + return None + + # 4. Pre-calculate retirement schedules + # retiring_masks[i] stores the bits of detectors that see their final fault at index i. + retiring_masks = [0] * len(faults) + last_seen_index = {} + + for idx, (_, det_mask, _) in enumerate(faults): + temp = det_mask + d_id = 0 + # Extract which bits are set in the mask to find the latest index for each detector + while temp > 0: + if temp & 1: + last_seen_index[d_id] = idx + temp >>= 1 + d_id += 1 + + for d_id, idx in last_seen_index.items(): + retiring_masks[idx] |= (1 << d_id) + + # 5. The Beam Search Sweep + state_probs = {0: [1.0, 0.0]} # active_syndrome_mask -> [P(L0), P(L1)] + + for i, (p, det_mask, flip_l0) in enumerate(faults): + q = 1.0 - p + next_probs = defaultdict(lambda: [0.0, 0.0]) + + # A. Expand the beam + for s, (p0, p1) in state_probs.items(): + # Fault absent + next_probs[s][0] += p0 * q + next_probs[s][1] += p1 * q + + # Fault present + t = s ^ det_mask + if flip_l0: + next_probs[t][0] += p1 * p + next_probs[t][1] += p0 * p + else: + next_probs[t][0] += p0 * p + next_probs[t][1] += p1 * p + + # B. 
Enforce Reality & Collapse the State Space + retiring_mask = retiring_masks[i] + collapsed_probs = defaultdict(lambda: [0.0, 0.0]) + + for s, (p0, p1) in next_probs.items(): + if retiring_mask != 0: + # If the retiring bits don't match our actual observation, kill the state + if (s & retiring_mask) != (actual_dets_mask & retiring_mask): + continue + + # Zero out the retired bits so states merge properly in the dictionary + shrunk_s = s & ~retiring_mask + collapsed_probs[shrunk_s][0] += p0 + collapsed_probs[shrunk_s][1] += p1 + + # C. Truncate the Beam (Top L Cutoff) + if len(collapsed_probs) > L: + # Sort by total marginal probability: P(L0) + P(L1) + sorted_states = sorted( + collapsed_probs.items(), + key=lambda kv: kv[1][0] + kv[1][1], + reverse=True + ) + state_probs = dict(sorted_states[:L]) + else: + state_probs = dict(collapsed_probs) + + # 6. Final Likelihood Comparison + # Since all bits are retired, the only surviving state mask should be exactly 0. + p0, p1 = state_probs.get(0, (0.0, 0.0)) + + if p0 == p1: + return None # Tie or beam missed the correct path entirely + return p1 > p0 + + +def run_experiment(circuit_fname: str, L: int): + """ + Generates a surface code, samples an error, and decodes it using the beam search. 
+ """ + # print(f"--- Running Distance {d}, Rounds {r}, Beam Size {L} ---") + print(f'Running on circuit {circuit_fname}') + + circuit = stim.Circuit.from_file(circuit_fname) + + sampler = circuit.compile_detector_sampler() + syndromes, logicals = sampler.sample(shots=1, separate_observables=True) + + actual_dets = set(i for i, triggered in enumerate(syndromes[0]) if triggered) + actual_logical = logicals[0][0] + + predicted_logical = decode_beam_search(circuit, actual_dets, L) + + print(f"Total Detectors: {circuit.num_detectors}") + print(f"Triggered Detectors: {len(actual_dets)}") + print(f"Predicted Logical: {predicted_logical}") + print(f"Actual Logical: {bool(actual_logical)}") + + if predicted_logical is None: + print("Result: DECODE FAILED (Tie or Beam too narrow)") + else: + print(f"Result: {'SUCCESS' if predicted_logical == actual_logical else 'LOGICAL ERROR'}") + print() + +if __name__ == '__main__': + run_experiment(sys.argv[1], L=1000) From c3a72fa3c7275e8b7402cee64c4bdf3d5e00dda4 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Fri, 3 Apr 2026 20:46:44 +0800 Subject: [PATCH 10/25] optimized trellis_beam.py --- src/py/astar/trellis_beam.py | 108 ++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 39 deletions(-) diff --git a/src/py/astar/trellis_beam.py b/src/py/astar/trellis_beam.py index b4bbed1..11df2f8 100644 --- a/src/py/astar/trellis_beam.py +++ b/src/py/astar/trellis_beam.py @@ -1,6 +1,8 @@ -import stim -from collections import defaultdict +import heapq import sys +from operator import itemgetter + +import stim def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> bool | None: @@ -30,7 +32,9 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> elif t.is_logical_observable_id() and t.val == 0: flip_l0 ^= 1 - faults.append((p, det_mask, flip_l0)) + q = 1.0 - p + delta_scale = -p if flip_l0 else p + faults.append((q, p, delta_scale, det_mask)) all_possible_dets_mask |= 
det_mask # 3. Convert observed syndrome set to an integer bitmask @@ -48,7 +52,7 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> retiring_masks = [0] * len(faults) last_seen_index = {} - for idx, (_, det_mask, _) in enumerate(faults): + for idx, (_, _, _, det_mask) in enumerate(faults): temp = det_mask d_id = 0 # Extract which bits are set in the mask to find the latest index for each detector @@ -62,61 +66,87 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> retiring_masks[idx] |= (1 << d_id) # 5. The Beam Search Sweep - state_probs = {0: [1.0, 0.0]} # active_syndrome_mask -> [P(L0), P(L1)] + # Each beam entry is (active_syndrome_mask, total_probability, logical_bias), + # where logical_bias = P(L0) - P(L1). Total probability is enough for beam + # ranking, and the bias preserves the final logical comparison. + beam = [(0, 1.0, 1.0)] - for i, (p, det_mask, flip_l0) in enumerate(faults): - q = 1.0 - p - next_probs = defaultdict(lambda: [0.0, 0.0]) + for i, (q, p, delta_scale, det_mask) in enumerate(faults): + next_probs: dict[int, list[float]] = {} # A. 
Expand the beam - for s, (p0, p1) in state_probs.items(): + for s, total, delta in beam: # Fault absent - next_probs[s][0] += p0 * q - next_probs[s][1] += p1 * q + entry = next_probs.get(s) + absent_total = total * q + absent_delta = delta * q + if entry is None: + next_probs[s] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta # Fault present t = s ^ det_mask - if flip_l0: - next_probs[t][0] += p1 * p - next_probs[t][1] += p0 * p + present_total = total * p + present_delta = delta * delta_scale + if t == s: + entry = next_probs[s] + entry[0] += present_total + entry[1] += present_delta else: - next_probs[t][0] += p0 * p - next_probs[t][1] += p1 * p + entry = next_probs.get(t) + if entry is None: + next_probs[t] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta # B. Enforce Reality & Collapse the State Space retiring_mask = retiring_masks[i] - collapsed_probs = defaultdict(lambda: [0.0, 0.0]) - - for s, (p0, p1) in next_probs.items(): - if retiring_mask != 0: - # If the retiring bits don't match our actual observation, kill the state - if (s & retiring_mask) != (actual_dets_mask & retiring_mask): - continue - - # Zero out the retired bits so states merge properly in the dictionary - shrunk_s = s & ~retiring_mask - collapsed_probs[shrunk_s][0] += p0 - collapsed_probs[shrunk_s][1] += p1 + if retiring_mask != 0: + collapsed_probs: dict[int, list[float]] = {} + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + + for s, (total, delta) in next_probs.items(): + # If the retiring bits don't match our actual observation, kill the state. + if (s & retiring_mask) != expected_bits: + continue + + shrunk_s = s & keep_mask + entry = collapsed_probs.get(shrunk_s) + if entry is None: + collapsed_probs[shrunk_s] = [total, delta] + else: + entry[0] += total + entry[1] += delta + else: + collapsed_probs = next_probs # C. 
Truncate the Beam (Top L Cutoff) if len(collapsed_probs) > L: - # Sort by total marginal probability: P(L0) + P(L1) - sorted_states = sorted( - collapsed_probs.items(), - key=lambda kv: kv[1][0] + kv[1][1], - reverse=True + beam = heapq.nlargest( + L, + ( + (state, total, delta) + for state, (total, delta) in collapsed_probs.items() + ), + key=itemgetter(1), ) - state_probs = dict(sorted_states[:L]) else: - state_probs = dict(collapsed_probs) + beam = [ + (state, total, delta) + for state, (total, delta) in collapsed_probs.items() + ] # 6. Final Likelihood Comparison # Since all bits are retired, the only surviving state mask should be exactly 0. - p0, p1 = state_probs.get(0, (0.0, 0.0)) - - if p0 == p1: + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + + if final_delta == 0.0: return None # Tie or beam missed the correct path entirely - return p1 > p0 + return final_delta < 0.0 def run_experiment(circuit_fname: str, L: int): From 249a433aac2bb71ac2339af60e871cebb7e42669 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 6 Apr 2026 11:54:26 +0800 Subject: [PATCH 11/25] add beamed trellis prototype --- src/BUILD | 27 ++ src/tesseract_trellis.cc | 679 ++++++++++++++++++++++++++++++++++ src/tesseract_trellis.h | 67 ++++ src/tesseract_trellis_main.cc | 405 ++++++++++++++++++++ 4 files changed, 1178 insertions(+) create mode 100644 src/tesseract_trellis.cc create mode 100644 src/tesseract_trellis.h create mode 100644 src/tesseract_trellis_main.cc diff --git a/src/BUILD b/src/BUILD index 95a6794..713f1ae 100644 --- a/src/BUILD +++ b/src/BUILD @@ -140,6 +140,20 @@ cc_library( ], ) +cc_library( + name = "libtesseract_trellis", + srcs = ["tesseract_trellis.cc"], + hdrs = ["tesseract_trellis.h"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libcommon", + ":libutils", + "@boost//:dynamic_bitset", + "@stim//:stim_lib", + ], +) + cc_library( name = "libtesseract_ftl", @@ -167,6 +181,19 @@ cc_binary( ], ) +cc_binary( + 
name = "tesseract_trellis", + srcs = ["tesseract_trellis_main.cc"], + copts = OPT_COPTS, + linkopts = OPT_LINKOPTS, + deps = [ + ":libtesseract_trellis", + "@argparse", + "@nlohmann_json//:json", + "@stim//:stim_lib", + ], +) + cc_binary( name = "tesseract_ftl", diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc new file mode 100644 index 0000000..d64bfc0 --- /dev/null +++ b/src/tesseract_trellis.cc @@ -0,0 +1,679 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tesseract_trellis.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +namespace std { +template <> +struct hash> { + size_t operator()(const boost::dynamic_bitset<>& bs) const { + return boost::hash_value(bs); + } +}; +} // namespace std + +namespace { + +struct Fault { + size_t error_index; + double log_q; + double log_p; + uint64_t obs_mask; + std::vector detectors; +}; + +struct LayerFault { + size_t error_index; + double log_q; + double log_p; + uint64_t obs_mask; + boost::dynamic_bitset<> local_det_mask; + boost::dynamic_bitset<> retiring_mask; + boost::dynamic_bitset<> expected_retiring_bits; + std::vector surviving_local_indices; +}; + +struct SmallLayerFault { + size_t error_index; + double q; + double p; + uint64_t obs_flip_bit; + uint64_t local_det_mask; + uint64_t retiring_mask; + uint64_t expected_retiring_bits; + std::vector surviving_local_indices; +}; + +struct PackedMass { + uint64_t key; + double mass; +}; + +struct StateMass { + uint64_t state; + double mass; +}; + +struct ObsAggregate { + double log_mass = -INF; +}; + +struct FrontierAggregate { + double total_log_mass = -INF; + std::unordered_map obs_entries; +}; + +double logsumexp2(double a, double b) { + if (a == -INF) return b; + if (b == -INF) return a; + if (a < b) std::swap(a, b); + return a + std::log1p(std::exp(b - a)); +} + +void add_obs_mass(FrontierAggregate& aggregate, uint64_t obs_mask, double log_mass) { + aggregate.total_log_mass = logsumexp2(aggregate.total_log_mass, log_mass); + auto& obs = aggregate.obs_entries[obs_mask]; + obs.log_mass = logsumexp2(obs.log_mass, log_mass); +} + +std::vector parse_faults(const std::vector& errors, size_t num_observables) { + std::vector faults; + faults.reserve(errors.size()); + for (size_t error_index = 0; error_index < errors.size(); ++error_index) { + const auto& error = errors[error_index]; + const double p = error.get_probability(); + if (p <= 0) continue; + 
Fault fault; + fault.error_index = error_index; + fault.log_q = std::log1p(-p); + fault.log_p = std::log(p); + fault.obs_mask = 0; + for (int obs : error.symptom.observables) { + if (obs >= 64) { + throw std::invalid_argument("tesseract_trellis currently supports at most 64 observables"); + } + if (size_t(obs) >= num_observables) { + throw std::invalid_argument("Observable index out of range in DEM"); + } + fault.obs_mask ^= uint64_t{1} << obs; + } + fault.detectors = error.symptom.detectors; + faults.push_back(std::move(fault)); + } + return faults; +} + +boost::dynamic_bitset<> project_state(const boost::dynamic_bitset<>& state, + const std::vector& surviving_local_indices) { + boost::dynamic_bitset<> out(surviving_local_indices.size()); + for (size_t k = 0; k < surviving_local_indices.size(); ++k) { + out[k] = state[surviving_local_indices[k]]; + } + return out; +} + +std::vector build_layer_faults(const std::vector& faults, size_t num_detectors, + const std::vector& detections, + size_t* max_frontier_width_seen) { + std::vector last_seen(num_detectors, std::numeric_limits::max()); + for (size_t i = 0; i < faults.size(); ++i) { + for (int d : faults[i].detectors) { + last_seen[d] = i; + } + } + + boost::dynamic_bitset<> actual_dets(num_detectors); + for (uint64_t d : detections) { + if (d >= num_detectors) { + throw std::runtime_error("Detector index out of range."); + } + actual_dets.flip(d); + } + + std::vector active_detectors; + active_detectors.reserve(num_detectors); + std::vector global_to_local(num_detectors, -1); + std::vector layers; + layers.reserve(faults.size()); + *max_frontier_width_seen = 0; + + for (size_t i = 0; i < faults.size(); ++i) { + for (int d : faults[i].detectors) { + if (global_to_local[d] == -1) { + global_to_local[d] = active_detectors.size(); + active_detectors.push_back(d); + } + } + + *max_frontier_width_seen = std::max(*max_frontier_width_seen, active_detectors.size()); + LayerFault layer{ + .error_index = faults[i].error_index, 
+ .log_q = faults[i].log_q, + .log_p = faults[i].log_p, + .obs_mask = faults[i].obs_mask, + .local_det_mask = boost::dynamic_bitset<>(active_detectors.size()), + .retiring_mask = boost::dynamic_bitset<>(active_detectors.size()), + .expected_retiring_bits = boost::dynamic_bitset<>(active_detectors.size()), + .surviving_local_indices = {}, + }; + + for (int d : faults[i].detectors) { + layer.local_det_mask.set(global_to_local[d]); + } + + for (size_t local = 0; local < active_detectors.size(); ++local) { + const int d = active_detectors[local]; + if (last_seen[d] == i) { + layer.retiring_mask.set(local); + layer.expected_retiring_bits[local] = actual_dets[d]; + } else { + layer.surviving_local_indices.push_back(local); + } + } + + std::vector next_active; + next_active.reserve(layer.surviving_local_indices.size()); + std::fill(global_to_local.begin(), global_to_local.end(), -1); + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + int d = active_detectors[layer.surviving_local_indices[next_local]]; + global_to_local[d] = next_local; + next_active.push_back(d); + } + active_detectors = std::move(next_active); + layers.push_back(std::move(layer)); + } + return layers; +} + +bool try_build_small_layer_faults(const std::vector& faults, size_t num_detectors, + const std::vector& detections, + std::vector* layers, + size_t* max_frontier_width_seen) { + std::vector last_seen(num_detectors, std::numeric_limits::max()); + for (size_t i = 0; i < faults.size(); ++i) { + for (int d : faults[i].detectors) { + last_seen[d] = i; + } + } + + boost::dynamic_bitset<> actual_dets(num_detectors); + for (uint64_t d : detections) { + if (d >= num_detectors) { + throw std::runtime_error("Detector index out of range."); + } + actual_dets.flip(d); + } + + std::vector active_detectors; + active_detectors.reserve(num_detectors); + std::vector global_to_local(num_detectors, -1); + layers->clear(); + layers->reserve(faults.size()); + 
*max_frontier_width_seen = 0; + + for (size_t i = 0; i < faults.size(); ++i) { + for (int d : faults[i].detectors) { + if (global_to_local[d] == -1) { + global_to_local[d] = active_detectors.size(); + active_detectors.push_back(d); + } + } + + *max_frontier_width_seen = std::max(*max_frontier_width_seen, active_detectors.size()); + if (*max_frontier_width_seen > 63) { + return false; + } + + SmallLayerFault layer{ + .error_index = faults[i].error_index, + .q = std::exp(faults[i].log_q), + .p = std::exp(faults[i].log_p), + .obs_flip_bit = faults[i].obs_mask & 1, + .local_det_mask = 0, + .retiring_mask = 0, + .expected_retiring_bits = 0, + .surviving_local_indices = {}, + }; + for (int d : faults[i].detectors) { + layer.local_det_mask ^= uint64_t{1} << global_to_local[d]; + } + for (size_t local = 0; local < active_detectors.size(); ++local) { + const int d = active_detectors[local]; + if (last_seen[d] == i) { + layer.retiring_mask ^= uint64_t{1} << local; + if (actual_dets[d]) { + layer.expected_retiring_bits ^= uint64_t{1} << local; + } + } else { + layer.surviving_local_indices.push_back(local); + } + } + + std::vector next_active; + next_active.reserve(layer.surviving_local_indices.size()); + std::fill(global_to_local.begin(), global_to_local.end(), -1); + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + int d = active_detectors[layer.surviving_local_indices[next_local]]; + global_to_local[d] = next_local; + next_active.push_back(d); + } + active_detectors = std::move(next_active); + layers->push_back(std::move(layer)); + } + + return true; +} + +uint64_t project_small_state(uint64_t state, const std::vector& surviving_local_indices) { + uint64_t out = 0; + for (size_t k = 0; k < surviving_local_indices.size(); ++k) { + out |= ((state >> surviving_local_indices[k]) & 1ULL) << k; + } + return out; +} + +uint64_t pack_small_key(uint64_t state, uint64_t obs_flip_bit) { + return (state << 1) | (obs_flip_bit & 1ULL); +} 
+
// ---------------------------------------------------------------------------
// Helpers for the packed "small" beam representation.
// A beam entry packs (detector-state bits, single observable bit) into one
// uint64: bit 0 is the observable, the remaining bits are the local detector
// state.
// NOTE(review): template arguments inside std::vector<...> etc. appear to
// have been stripped by patch extraction; element types are presumably
// PackedMass / StateMass — confirm against the original commit.
// ---------------------------------------------------------------------------

// Recover the detector-state bits from a packed (state, obs) key.
uint64_t unpack_small_state(uint64_t packed_key) {
  return packed_key >> 1;
}

// Recover the observable bit (bit 0) from a packed key.
uint64_t unpack_small_obs(uint64_t packed_key) {
  return packed_key & 1ULL;
}

// Rescale all masses so they sum to 1; clears the list if total mass is 0.
void normalize_items(std::vector& items) {
  double total_mass = 0.0;
  for (const auto& item : items) {
    total_mass += item.mass;
  }
  if (total_mass == 0.0) {
    items.clear();
    return;
  }
  double inv = 1.0 / total_mass;
  for (auto& item : items) {
    item.mass *= inv;
  }
}

// Sort entries by key and sum masses of duplicates into a single entry each.
// Note: mutates `items` (sorts it in place) and returns the merged copy.
std::vector merge_equal_keys(std::vector& items) {
  if (items.empty()) {
    return {};
  }
  std::sort(items.begin(), items.end(), [](const PackedMass& a, const PackedMass& b) {
    return a.key < b.key;
  });
  std::vector merged;
  merged.reserve(items.size());
  uint64_t cur_key = items[0].key;
  double cur_mass = items[0].mass;
  for (size_t i = 1; i < items.size(); ++i) {
    if (items[i].key == cur_key) {
      cur_mass += items[i].mass;
    } else {
      merged.push_back({cur_key, cur_mass});
      cur_key = items[i].key;
      cur_mass = items[i].mass;
    }
  }
  merged.push_back({cur_key, cur_mass});
  return merged;
}

// Sum masses per detector state (ignoring the observable bit).
// Precondition: entries are sorted by key, so equal states are adjacent
// (callers pass output of merge_equal_keys).
std::vector accumulate_state_masses_from_entries(const std::vector& entries) {
  std::vector totals;
  if (entries.empty()) {
    return totals;
  }
  totals.reserve(entries.size());
  uint64_t cur_state = unpack_small_state(entries[0].key);
  double cur_mass = entries[0].mass;
  for (size_t i = 1; i < entries.size(); ++i) {
    uint64_t s = unpack_small_state(entries[i].key);
    if (s == cur_state) {
      cur_mass += entries[i].mass;
    } else {
      totals.push_back({cur_state, cur_mass});
      cur_state = s;
      cur_mass = entries[i].mass;
    }
  }
  totals.push_back({cur_state, cur_mass});
  return totals;
}

// Keep only entries whose *aggregate state mass* ranks in the top beam_width
// states; both observable branches of a kept state survive.
void keep_top_states(std::vector& entries, size_t beam_width) {
  if (entries.empty()) {
    return;
  }
  auto totals = accumulate_state_masses_from_entries(entries);
  if (totals.size() <= beam_width) {
    return;
  }
  // Partial selection of the heaviest beam_width states...
  std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(),
                   [](const StateMass& a,
const StateMass& b) { return a.mass > b.mass; });
  totals.resize(beam_width);
  // ...then re-sort by state id so the keep-filter below can walk both
  // sorted sequences in lockstep.
  std::sort(totals.begin(), totals.end(), [](const StateMass& a, const StateMass& b) {
    return a.state < b.state;
  });

  std::vector kept;
  kept.reserve(entries.size());
  size_t ti = 0;
  for (const auto& item : entries) {
    uint64_t s = unpack_small_state(item.key);
    while (ti < totals.size() && totals[ti].state < s) {
      ++ti;
    }
    if (ti < totals.size() && totals[ti].state == s) {
      kept.push_back(item);
    }
  }
  entries = std::move(kept);
}

// Keep the beam_width heaviest individual (state, obs) entries, regardless of
// per-state aggregation. Order of survivors is unspecified (nth_element).
void keep_top_branch_entries(std::vector& entries, size_t beam_width) {
  if (entries.size() <= beam_width) {
    return;
  }
  std::nth_element(entries.begin(), entries.begin() + beam_width, entries.end(),
                   [](const PackedMass& a, const PackedMass& b) { return a.mass > b.mass; });
  entries.resize(beam_width);
}

} // namespace

// Builds the decoder from a flattened DEM: identity error map, extracted
// error list, and detector/observable counts.
TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_)
    : config(std::move(config_)) {
  std::vector dem_error_map(config.dem.flattened().count_errors());
  std::iota(dem_error_map.begin(), dem_error_map.end(), 0);
  dem_error_to_error = std::move(dem_error_map);
  error_to_dem_error = common::invert_error_map(dem_error_to_error, config.dem.count_errors());
  errors = get_errors_from_dem(config.dem.flattened());
  num_detectors = config.dem.count_detectors();
  num_observables = config.dem.count_observables();
}

// Decodes one shot, storing results in the decoder's member fields
// (predicted_obs_mask, total_mass_obs0/1, low_confidence_flag, timing stats).
void TesseractTrellisDecoder::decode_shot(const std::vector& detections) {
  // Reset all per-shot outputs and counters.
  low_confidence_flag = false;
  num_states_expanded = 0;
  num_states_merged = 0;
  max_beam_size_seen = 0;
  max_frontier_width_seen = 0;
  time_expand_seconds = 0;
  time_collapse_seconds = 0;
  time_truncate_seconds = 0;
  time_reconstruct_seconds = 0;
  predicted_obs_mask = 0;
  total_mass_obs0 = 0;
  total_mass_obs1 = 0;

  auto faults = parse_faults(errors, num_observables);

  // A detection on a detector no error can trigger is undecodable: bail out
  // with the low-confidence flag instead of searching.
  std::unordered_set all_possible_dets;
  for (const auto& error : errors) {
    for (int d :
error.symptom.detectors) {
      all_possible_dets.insert(d);
    }
  }
  for (uint64_t d : detections) {
    if (!all_possible_dets.contains(int(d))) {
      low_confidence_flag = true;
      return;
    }
  }

  // Fast path: with at most one observable and a small enough frontier, each
  // beam entry fits in a packed uint64 (see pack_small_key). Falls through to
  // the dynamic_bitset path below when try_build_small_layer_faults declines.
  std::vector small_layers;
  if (num_observables <= 1 &&
      try_build_small_layer_faults(faults, num_detectors, detections, &small_layers,
                                   &max_frontier_width_seen)) {
    std::vector beam_entries;
    beam_entries.push_back({pack_small_key(0, 0), 1.0});
    max_beam_size_seen = 1;

    for (size_t layer_index = 0; layer_index < small_layers.size(); ++layer_index) {
      const auto& layer = small_layers[layer_index];
      auto t0 = std::chrono::high_resolution_clock::now();
      // Phase 1 (expand): each entry branches into fault-absent (mass * q)
      // and fault-present (mass * p, detectors/observable flipped).
      std::vector expanded_entries;
      expanded_entries.reserve(beam_entries.size() * 2);
      for (const auto& item : beam_entries) {
        ++num_states_expanded;
        uint64_t base_state = unpack_small_state(item.key);
        expanded_entries.push_back({pack_small_key(base_state, unpack_small_obs(item.key)),
                                    item.mass * layer.q});
        uint64_t present_key =
            pack_small_key(base_state ^ layer.local_det_mask,
                           unpack_small_obs(item.key) ^ layer.obs_flip_bit);
        expanded_entries.push_back({present_key, item.mass * layer.p});
      }
      auto t1 = std::chrono::high_resolution_clock::now();
      time_expand_seconds +=
          std::chrono::duration_cast(t1 - t0).count() / 1e6;

      // Phase 2 (collapse): drop entries whose retiring detector bits don't
      // match the observed syndrome, then project out the retired bits.
      std::vector next_entries;
      next_entries.reserve(expanded_entries.size());
      for (const auto& item : expanded_entries) {
        uint64_t state = unpack_small_state(item.key);
        if (((state ^ layer.expected_retiring_bits) & layer.retiring_mask) != 0) {
          continue;
        }
        uint64_t projected_state = project_small_state(state, layer.surviving_local_indices);
        uint64_t projected_key = pack_small_key(projected_state, unpack_small_obs(item.key));
        next_entries.push_back({projected_key, item.mass});
      }
      auto t1b = std::chrono::high_resolution_clock::now();
      time_collapse_seconds +=
          std::chrono::duration_cast(t1b - t1).count() / 1e6;

      beam_entries = std::move(next_entries);
      // Phase 3 (merge/prune/normalize) only runs every merge_interval layers
      // and on the final layer; in between the beam grows unboundedly (2x per
      // layer) and is merely tracked.
      bool at_checkpoint = ((layer_index + 1) % config.merge_interval == 0) ||
                           (layer_index + 1 == small_layers.size());
      if (!at_checkpoint) {
        max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size());
        if (beam_entries.empty()) {
          low_confidence_flag = true;
          return;
        }
        continue;
      }

      auto t2a = std::chrono::high_resolution_clock::now();
      if (config.prune_mode != TesseractTrellisPruneMode::kNoMerge) {
        beam_entries = merge_equal_keys(beam_entries);
      }
      auto t2 = std::chrono::high_resolution_clock::now();
      time_collapse_seconds +=
          std::chrono::duration_cast(t2 - t2a).count() / 1e6;

      // Pruning strategy per configured mode, then renormalize masses so
      // later comparisons stay in a sane floating-point range.
      if (config.prune_mode == TesseractTrellisPruneMode::kMergedStates) {
        keep_top_states(beam_entries, config.beam_width);
      } else if (config.prune_mode == TesseractTrellisPruneMode::kBranchEntries ||
                 config.prune_mode == TesseractTrellisPruneMode::kNoMerge) {
        keep_top_branch_entries(beam_entries, config.beam_width);
      }
      normalize_items(beam_entries);
      if (beam_entries.empty()) {
        low_confidence_flag = true;
        return;
      }
      if (config.prune_mode == TesseractTrellisPruneMode::kNoMerge) {
        num_states_merged += beam_entries.size();
        max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size());
      } else {
        // NOTE(review): accumulate_state_masses_from_entries assumes sorted
        // input; after keep_top_states/keep_top_branch_entries the entries may
        // no longer be key-sorted — stats here could over-count. Confirm.
        auto post_totals = accumulate_state_masses_from_entries(beam_entries);
        num_states_merged += post_totals.size();
        max_beam_size_seen = std::max(max_beam_size_seen, post_totals.size());
      }
      auto t3 = std::chrono::high_resolution_clock::now();
      time_truncate_seconds +=
          std::chrono::duration_cast(t3 - t2).count() / 1e6;
    }

    // Final tally: only state 0 (all detectors explained) contributes; split
    // its mass by the observable bit.
    auto tr0 = std::chrono::high_resolution_clock::now();
    for (const auto& [packed_key, mass] : beam_entries) {
      if (unpack_small_state(packed_key) != 0) {
        continue;
      }
      if (unpack_small_obs(packed_key) == 0) {
        total_mass_obs0 += mass;
      } else {
        total_mass_obs1 += mass;
      }
    }
    if (total_mass_obs0 == 0.0 && total_mass_obs1 == 0.0) {
      low_confidence_flag = true;
      return;
    }
    predicted_obs_mask =
total_mass_obs1 > total_mass_obs0 ? 1 : 0;
    auto tr1 = std::chrono::high_resolution_clock::now();
    time_reconstruct_seconds +=
        std::chrono::duration_cast(tr1 - tr0).count() / 1e6;
  } else {
    // General path: frontier states are boost::dynamic_bitset keys and each
    // state owns a FrontierAggregate of per-observable-mask log masses.
    auto layers = build_layer_faults(faults, num_detectors, detections, &max_frontier_width_seen);
    std::unordered_map, FrontierAggregate> beam;
    FrontierAggregate init;
    add_obs_mass(init, 0, 0.0);
    beam.emplace(boost::dynamic_bitset<>(0), std::move(init));
    max_beam_size_seen = 1;

    for (const auto& layer : layers) {
      auto t0 = std::chrono::high_resolution_clock::now();
      // Expand: every (state, obs) pair branches into fault-absent (log_q)
      // and fault-present (log_p, detector/observable masks XORed in).
      std::unordered_map, FrontierAggregate> expanded;
      expanded.reserve(beam.size() * 2 + 1);

      for (const auto& [state, aggregate] : beam) {
        ++num_states_expanded;
        boost::dynamic_bitset<> base_state = state;
        // Widen to this layer's frontier width before XORing masks.
        base_state.resize(layer.local_det_mask.size());
        for (const auto& [obs_mask, obs] : aggregate.obs_entries) {
          auto& absent_bucket = expanded[base_state];
          add_obs_mass(absent_bucket, obs_mask, obs.log_mass + layer.log_q);

          boost::dynamic_bitset<> present_state = base_state ^ layer.local_det_mask;
          auto& present_bucket = expanded[present_state];
          add_obs_mass(present_bucket, obs_mask ^ layer.obs_mask, obs.log_mass + layer.log_p);
        }
      }
      auto t1 = std::chrono::high_resolution_clock::now();
      time_expand_seconds +=
          std::chrono::duration_cast(t1 - t0).count() / 1e6;

      // Collapse: discard states whose retiring bits contradict the observed
      // syndrome, project away retired bits, and merge colliding states.
      std::unordered_map, FrontierAggregate> collapsed;
      collapsed.reserve(expanded.size());
      for (auto& [state, aggregate] : expanded) {
        if (((state & layer.retiring_mask) ^ layer.expected_retiring_bits).any()) {
          continue;
        }
        boost::dynamic_bitset<> projected = project_state(state, layer.surviving_local_indices);
        auto& out = collapsed[projected];
        for (const auto& [obs_mask, obs] : aggregate.obs_entries) {
          add_obs_mass(out, obs_mask, obs.log_mass);
        }
      }
      auto t2 = std::chrono::high_resolution_clock::now();
      time_collapse_seconds +=
          std::chrono::duration_cast(t2 - t1).count() / 1e6;

      num_states_merged += collapsed.size();
      if (collapsed.empty()) {
        low_confidence_flag = true;
        return;
      }

      // Truncate: keep the beam_width states with the largest total log mass.
      std::vector, FrontierAggregate>> next_beam;
      next_beam.reserve(collapsed.size());
      for (auto& item : collapsed) {
        next_beam.push_back(std::move(item));
      }
      if (next_beam.size() > config.beam_width) {
        std::nth_element(next_beam.begin(), next_beam.begin() + config.beam_width, next_beam.end(),
                         [](const auto& a, const auto& b) {
                           return a.second.total_log_mass > b.second.total_log_mass;
                         });
        next_beam.resize(config.beam_width);
      }
      max_beam_size_seen = std::max(max_beam_size_seen, next_beam.size());

      beam.clear();
      beam.reserve(next_beam.size());
      for (auto& item : next_beam) {
        beam.emplace(std::move(item.first), std::move(item.second));
      }
      auto t3 = std::chrono::high_resolution_clock::now();
      time_truncate_seconds +=
          std::chrono::duration_cast(t3 - t2).count() / 1e6;
    }

    // After all layers retire, the empty state is the only valid survivor.
    auto it = beam.find(boost::dynamic_bitset<>(0));
    if (it == beam.end() || it->second.obs_entries.empty()) {
      low_confidence_flag = true;
      return;
    }

    const auto& final_entry = it->second;
    auto tr0 = std::chrono::high_resolution_clock::now();
    // NOTE(review): this empty() check is redundant — already tested above.
    if (final_entry.obs_entries.empty()) {
      low_confidence_flag = true;
      return;
    }
    // Only obs masks 0 and 1 are compared; masks >= 2 (more than one
    // observable) would be silently ignored here — presumably unreachable on
    // the circuits this targets, but confirm.
    auto it0 = final_entry.obs_entries.find(0);
    auto it1 = final_entry.obs_entries.find(1);
    total_mass_obs0 = it0 == final_entry.obs_entries.end() ? 0.0 : std::exp(it0->second.log_mass);
    total_mass_obs1 = it1 == final_entry.obs_entries.end() ? 0.0 : std::exp(it1->second.log_mass);
    if (total_mass_obs0 == 0.0 && total_mass_obs1 == 0.0) {
      low_confidence_flag = true;
      return;
    }
    predicted_obs_mask = total_mass_obs1 > total_mass_obs0 ?
1 : 0; + auto tr1 = std::chrono::high_resolution_clock::now(); + time_reconstruct_seconds += + std::chrono::duration_cast(tr1 - tr0).count() / 1e6; + } + + if (config.verbose) { + std::cout << "trellis beam_width=" << config.beam_width + << " frontier_width=" << max_frontier_width_seen + << " states_expanded=" << num_states_expanded + << " states_merged=" << num_states_merged + << " max_beam=" << max_beam_size_seen << std::endl; + } +} + +std::vector TesseractTrellisDecoder::decode(const std::vector& detections) { + decode_shot(detections); + return predicted_obs_mask ? std::vector{0} : std::vector{}; +} + +void TesseractTrellisDecoder::decode_shots(std::vector& shots, + std::vector>& obs_predicted) { + obs_predicted.resize(shots.size()); + for (size_t i = 0; i < shots.size(); ++i) { + obs_predicted[i] = decode(shots[i].hits); + } +} diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h new file mode 100644 index 0000000..491d74f --- /dev/null +++ b/src/tesseract_trellis.h @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef TESSERACT_TRELLIS_DECODER_H +#define TESSERACT_TRELLIS_DECODER_H + +#include +#include + +#include "common.h" +#include "stim.h" + +enum class TesseractTrellisPruneMode { + kMergedStates, + kBranchEntries, + kNoMerge, +}; + +struct TesseractTrellisConfig { + stim::DetectorErrorModel dem; + size_t beam_width = 1024; + size_t merge_interval = 1; + bool verbose = false; + TesseractTrellisPruneMode prune_mode = TesseractTrellisPruneMode::kMergedStates; +}; + +struct TesseractTrellisDecoder { + explicit TesseractTrellisDecoder(TesseractTrellisConfig config); + + void decode_shot(const std::vector& detections); + std::vector decode(const std::vector& detections); + void decode_shots(std::vector& shots, + std::vector>& obs_predicted); + + TesseractTrellisConfig config; + bool low_confidence_flag = false; + size_t num_states_expanded = 0; + size_t num_states_merged = 0; + size_t max_beam_size_seen = 0; + size_t max_frontier_width_seen = 0; + double time_expand_seconds = 0; + double time_collapse_seconds = 0; + double time_truncate_seconds = 0; + double time_reconstruct_seconds = 0; + uint64_t predicted_obs_mask = 0; + double total_mass_obs0 = 0; + double total_mass_obs1 = 0; + + std::vector dem_error_to_error; + std::vector error_to_dem_error; + std::vector errors; + size_t num_observables = 0; + size_t num_detectors = 0; +}; + +#endif // TESSERACT_TRELLIS_DECODER_H diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc new file mode 100644 index 0000000..b252e29 --- /dev/null +++ b/src/tesseract_trellis_main.cc @@ -0,0 +1,405 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "common.h"
+#include "stim.h"
+#include "tesseract_trellis.h"
+#include "utils.h"
+
+namespace {
+
// Maps the --prune-mode CLI string to the enum; throws on anything other
// than "merged" / "branch" / "none".
+TesseractTrellisPruneMode parse_prune_mode(const std::string& value) {
+  if (value == "merged") return TesseractTrellisPruneMode::kMergedStates;
+  if (value == "branch") return TesseractTrellisPruneMode::kBranchEntries;
+  if (value == "none") return TesseractTrellisPruneMode::kNoMerge;
+  throw std::invalid_argument("Unknown trellis prune mode: " + value);
+}
+
+} // namespace
+
// Parsed command-line options for the trellis CLI, plus validation and the
// heavy-lifting extract() that loads the DEM and shot data.
+struct Args {
+  std::string circuit_path;
+  std::string dem_path;
+
+  size_t sample_num_shots = 0;
+  size_t max_errors = SIZE_MAX;
+  uint64_t sample_seed;
+
+  size_t shot_range_begin = 0;
+  size_t shot_range_end = 0;
+
+  std::string in_fname = "";
+  std::string in_format = "";
+  std::string obs_in_fname = "";
+  std::string obs_in_format = "";
+  bool append_observables = false;
+  std::string out_fname = "";
+  std::string out_format = "";
+
+  std::string dem_out_fname = "";
+  std::string stats_out_fname = "";
+
+  size_t num_threads = 1;
+  size_t beam_width = 1024;
+  size_t merge_interval = 1;
+  std::string prune_mode = "merged";
+
+  bool verbose = false;
+  bool print_stats = false;
+
// True when ground-truth observables are available to score against.
+  bool has_observables() {
+    return append_observables || !obs_in_fname.empty() || (sample_num_shots > 0);
+  }
+
// Validates option combinations; throws std::invalid_argument on misuse.
// Exactly one shot source (sampling or --in) must be configured.
+  void validate() {
+    if (circuit_path.empty() && dem_path.empty()) {
+      throw std::invalid_argument("Must provide at least one of --circuit or --dem");
+    }
+    int num_data_sources = int(sample_num_shots > 0) +
int(!in_fname.empty());
+    if (num_data_sources != 1) {
+      throw std::invalid_argument("Requires exactly 1 source of shots.");
+    }
+    if (!in_fname.empty() && in_format.empty()) {
+      throw std::invalid_argument("If --in is provided, must also specify --in-format.");
+    }
+    if (!out_fname.empty() && out_format.empty()) {
+      throw std::invalid_argument("If --out is provided, must also specify --out-format.");
+    }
+    if (!in_format.empty() && !stim::format_name_to_enum_map().contains(in_format)) {
+      throw std::invalid_argument("Invalid format: " + in_format);
+    }
+    if (!obs_in_format.empty() && !stim::format_name_to_enum_map().contains(obs_in_format)) {
+      throw std::invalid_argument("Invalid format: " + obs_in_format);
+    }
+    if (!out_format.empty() && !stim::format_name_to_enum_map().contains(out_format)) {
+      throw std::invalid_argument("Invalid format: " + out_format);
+    }
+    if (!obs_in_fname.empty() && in_fname.empty()) {
+      throw std::invalid_argument(
+          "Cannot load observable flips without a corresponding detection event data file.");
+    }
+    if (num_threads == 0) {
+      throw std::invalid_argument("--threads must be at least 1.");
+    }
+    if (num_threads > 1000) {
+      throw std::invalid_argument("There is a maximum limit of 1000 threads.");
+    }
+    if ((shot_range_begin || shot_range_end) && shot_range_end < shot_range_begin) {
+      throw std::invalid_argument("Provided shot range must have end >= begin.");
+    }
+    if (sample_num_shots > 0 && circuit_path.empty()) {
+      throw std::invalid_argument("Cannot sample shots without a circuit.");
+    }
+    if (beam_width == 0) {
+      throw std::invalid_argument("--beam must be at least 1.");
+    }
+    if (merge_interval == 0) {
+      throw std::invalid_argument("--merge-interval must be at least 1.");
+    }
// Called for its throw-on-invalid side effect only.
+    parse_prune_mode(prune_mode);
+  }
+
// Fills `config`, `shots`, and `writer` from the parsed options: loads or
// derives the DEM, samples or reads shots, attaches observables, applies
// the shot range, and opens the predictions writer.
+  void extract(TesseractTrellisConfig& config, std::vector& shots,
+               std::unique_ptr& writer) {
+    stim::Circuit circuit;
+    if (!circuit_path.empty()) {
+      FILE* file = fopen(circuit_path.c_str(), "r");
+      if (!file) {
throw std::invalid_argument("Could not open the file: " + circuit_path);
+      }
+      circuit = stim::Circuit::from_file(file);
+      fclose(file);
+    }
+
// Prefer an explicit --dem; otherwise derive the DEM from the circuit.
+    if (!dem_path.empty()) {
+      FILE* file = fopen(dem_path.c_str(), "r");
+      if (!file) {
+        throw std::invalid_argument("Could not open the file: " + dem_path);
+      }
+      config.dem = stim::DetectorErrorModel::from_file(file);
+      fclose(file);
+    } else {
+      assert(!circuit_path.empty());
+      config.dem = stim::ErrorAnalyzer::circuit_to_detector_error_model(
+          circuit, /*decompose_errors=*/false, /*fold_loops=*/true,
+          /*allow_gauge_detectors=*/true,
+          /*approximate_disjoint_errors_threshold=*/1,
+          /*ignore_decomposition_failures=*/false,
+          /*block_decomposition_from_introducing_remnant_edges=*/false);
+    }
+
+    config.beam_width = beam_width;
+    config.merge_interval = merge_interval;
+    config.verbose = verbose;
+    config.prune_mode = parse_prune_mode(prune_mode);
+
// Shot source 1: sample fresh detection events from the circuit.
+    if (sample_num_shots > 0) {
+      assert(!circuit_path.empty());
+      std::mt19937_64 rng(sample_seed);
+      size_t num_detectors = circuit.count_detectors();
+      const auto [dets, obs] =
+          stim::sample_batch_detection_events<64>(circuit, sample_num_shots, rng);
+      stim::simd_bit_table<64> obs_T = obs.transposed();
+      shots.resize(sample_num_shots);
+      for (size_t k = 0; k < sample_num_shots; k++) {
+        shots[k].obs_mask = obs_T[k];
+        for (size_t d = 0; d < num_detectors; d++) {
+          if (dets[d][k]) {
+            shots[k].hits.push_back(d);
+          }
+        }
+      }
+    }
+
// Shot source 2: read detection events from --in.
+    if (!in_fname.empty()) {
+      FILE* shots_file = fopen(in_fname.c_str(), "r");
+      if (!shots_file) {
+        throw std::invalid_argument("Could not open the file: " + in_fname);
+      }
+      stim::FileFormatData shots_in_format = stim::format_name_to_enum_map().at(in_format);
+      auto reader = stim::MeasureRecordReader::make(
+          shots_file, shots_in_format.id, 0, config.dem.count_detectors(),
+          append_observables * config.dem.count_observables());
+      stim::SparseShot sparse_shot;
+      sparse_shot.clear();
+      while (reader->start_and_read_entire_record(sparse_shot)) {
shots.push_back(sparse_shot);
+        sparse_shot.clear();
+      }
+      fclose(shots_file);
+    }
+
// Optional separate observable-flip data; record counts must match --in.
+    if (!obs_in_fname.empty()) {
+      FILE* obs_file = fopen(obs_in_fname.c_str(), "r");
+      if (!obs_file) {
+        throw std::invalid_argument("Could not open the file: " + obs_in_fname);
+      }
+      stim::FileFormatData obs_format = stim::format_name_to_enum_map().at(obs_in_format);
+      auto obs_reader = stim::MeasureRecordReader::make(
+          obs_file, obs_format.id, 0, 0, config.dem.count_observables());
+      stim::SparseShot sparse_shot;
+      sparse_shot.clear();
+      size_t num_obs_shots = 0;
+      while (obs_reader->start_and_read_entire_record(sparse_shot)) {
+        if (num_obs_shots >= shots.size()) {
+          throw std::invalid_argument("Shot data ended before obs data.");
+        }
+        shots[num_obs_shots].obs_mask = sparse_shot.obs_mask;
+        sparse_shot.clear();
+        ++num_obs_shots;
+      }
+      if (num_obs_shots != shots.size()) {
+        throw std::invalid_argument("Obs data ended before shot data ended.");
+      }
+      fclose(obs_file);
+    }
+
// Restrict to [--shot-range-begin, --shot-range-end) if either is set.
+    if (shot_range_begin || shot_range_end) {
+      if (shot_range_end > shots.size()) {
+        throw std::invalid_argument("Shot range end is past end of shots array.");
+      }
+      std::vector shots_in_range(shots.begin() + shot_range_begin,
+                                 shots.begin() + shot_range_end);
+      std::swap(shots_in_range, shots);
+    }
+
+    if (!out_fname.empty()) {
+      stim::FileFormatData predictions_out_format = stim::format_name_to_enum_map().at(out_format);
+      FILE* predictions_file = stdout;
+      if (out_fname != "-") {
+        predictions_file = fopen(out_fname.c_str(), "w");
// NOTE(review): fopen result is not null-checked here (unlike the reads
// above), and predictions_file is never fclose'd in visible code —
// presumably the writer flushes on destruction, but confirm.
+      }
+      writer = stim::MeasureRecordWriter::make(predictions_file, predictions_out_format.id);
+      writer->begin_result_type('L');
+    }
+  }
+};
+
// CLI entry point: parse/validate arguments, load data, decode all shots in
// parallel, then emit per-shot and final statistics.
+int main(int argc, char* argv[]) {
+  std::cout.precision(16);
+  argparse::ArgumentParser program("tesseract_trellis");
+  Args args;
+  program.add_argument("--circuit").help("Stim circuit file path").store_into(args.circuit_path);
+  program.add_argument("--dem").help("Stim dem file path").store_into(args.dem_path);
program.add_argument("--sample-num-shots").store_into(args.sample_num_shots); + program.add_argument("--max-errors").store_into(args.max_errors); + program.add_argument("--sample-seed") + .default_value(static_cast(std::random_device()())) + .store_into(args.sample_seed); + program.add_argument("--shot-range-begin").default_value(size_t(0)).store_into(args.shot_range_begin); + program.add_argument("--shot-range-end").default_value(size_t(0)).store_into(args.shot_range_end); + program.add_argument("--in").default_value(std::string("")).store_into(args.in_fname); + program.add_argument("--in-format", "--in_format") + .default_value(std::string("")) + .store_into(args.in_format); + program.add_argument("--in-includes-appended-observables", "--in_includes_appended_observables") + .default_value(false) + .store_into(args.append_observables) + .flag(); + program.add_argument("--obs_in", "--obs-in").default_value(std::string("")).store_into(args.obs_in_fname); + program.add_argument("--obs-in-format", "--obs_in_format") + .default_value(std::string("")) + .store_into(args.obs_in_format); + program.add_argument("--out").default_value(std::string("")).store_into(args.out_fname); + program.add_argument("--out-format").default_value(std::string("")).store_into(args.out_format); + program.add_argument("--dem-out").default_value(std::string("")).store_into(args.dem_out_fname); + program.add_argument("--stats-out").default_value(std::string("")).store_into(args.stats_out_fname); + program.add_argument("--threads") + .default_value(size_t( + std::thread::hardware_concurrency() == 0 ? 
1 : std::thread::hardware_concurrency())) + .store_into(args.num_threads); + program.add_argument("--beam").default_value(size_t(1024)).store_into(args.beam_width); + program.add_argument("--merge-interval").default_value(size_t(1)).store_into(args.merge_interval); + program.add_argument("--prune-mode") + .help("Trellis pruning mode: merged, branch, or none") + .default_value(std::string("merged")) + .store_into(args.prune_mode); + program.add_argument("--verbose").flag().store_into(args.verbose); + program.add_argument("--print-stats").flag().store_into(args.print_stats); + + try { + program.parse_args(argc, argv); + } catch (const std::exception& err) { + std::cerr << err.what() << std::endl; + std::cerr << program; + return EXIT_FAILURE; + } + + args.validate(); + TesseractTrellisConfig config; + std::vector shots; + std::unique_ptr writer; + args.extract(config, shots, writer); + + std::vector obs_predicted(shots.size()); + std::vector mass0_predicted(shots.size()); + std::vector mass1_predicted(shots.size()); + std::vector decoding_time_seconds(shots.size()); + std::vector num_states_expanded_per_shot(shots.size()); + std::vector num_states_merged_per_shot(shots.size()); + std::vector max_beam_size_per_shot(shots.size()); + std::vector max_frontier_width_per_shot(shots.size()); + std::vector time_expand_per_shot(shots.size()); + std::vector time_collapse_per_shot(shots.size()); + std::vector time_truncate_per_shot(shots.size()); + std::vector time_reconstruct_per_shot(shots.size()); + std::vector> low_confidence(shots.size()); + const stim::DetectorErrorModel original_dem = config.dem.flattened(); + std::vector> decoders(args.num_threads); + + bool has_obs = args.has_observables(); + size_t num_errors = 0; + size_t num_low_confidence = 0; + double total_time_seconds = 0; + size_t num_observables = config.dem.count_observables(); + + size_t shot = parallel_for_shots_in_order( + shots.size(), args.num_threads, + [&](size_t thread_index, size_t shot_index) { + if 
 (!decoders[thread_index]) {
+          decoders[thread_index] = std::make_unique(config);
+        }
+        auto& decoder = *decoders[thread_index];
+        auto start_time = std::chrono::high_resolution_clock::now();
+        decoder.decode_shot(shots[shot_index].hits);
+        auto stop_time = std::chrono::high_resolution_clock::now();
+        decoding_time_seconds[shot_index] =
+            std::chrono::duration_cast(stop_time - start_time).count() /
+            1e6;
// Copy all per-shot outputs out of the decoder before it is reused.
+        obs_predicted[shot_index] = decoder.predicted_obs_mask;
+        low_confidence[shot_index] = decoder.low_confidence_flag;
+        mass0_predicted[shot_index] = decoder.total_mass_obs0;
+        mass1_predicted[shot_index] = decoder.total_mass_obs1;
+        num_states_expanded_per_shot[shot_index] = decoder.num_states_expanded;
+        num_states_merged_per_shot[shot_index] = decoder.num_states_merged;
+        max_beam_size_per_shot[shot_index] = decoder.max_beam_size_seen;
+        max_frontier_width_per_shot[shot_index] = decoder.max_frontier_width_seen;
+        time_expand_per_shot[shot_index] = decoder.time_expand_seconds;
+        time_collapse_per_shot[shot_index] = decoder.time_collapse_seconds;
+        time_truncate_per_shot[shot_index] = decoder.time_truncate_seconds;
+        time_reconstruct_per_shot[shot_index] = decoder.time_reconstruct_seconds;
+      },
+      [&](size_t shot_index) {
+        if (writer) {
+          writer->write_bits((uint8_t*)&obs_predicted[shot_index], num_observables);
+          writer->write_end();
+        }
// Low-confidence shots are excluded from the logical error count.
+        if (low_confidence[shot_index]) {
+          ++num_low_confidence;
+        } else if (obs_predicted[shot_index] != shots[shot_index].obs_mask_as_u64()) {
+          ++num_errors;
+        }
+        total_time_seconds += decoding_time_seconds[shot_index];
+        if (args.print_stats) {
+          std::cout << "num_shots = " << (shot_index + 1)
+                    << " num_low_confidence = " << num_low_confidence
+                    << " num_errors = " << num_errors
+                    << " states_expanded = " << num_states_expanded_per_shot[shot_index]
+                    << " states_merged = " << num_states_merged_per_shot[shot_index]
+                    << " max_beam = " << max_beam_size_per_shot[shot_index]
+                    << " frontier_width = " <<
 max_frontier_width_per_shot[shot_index]
+                    << " total_time_seconds = " << total_time_seconds << std::endl;
+          std::cout << "branch_masses"
+                    << " obs0=" << mass0_predicted[shot_index]
+                    << " obs1=" << mass1_predicted[shot_index] << std::endl;
+          std::cout << "phase_times_seconds"
+                    << " expand=" << time_expand_per_shot[shot_index]
+                    << " collapse=" << time_collapse_per_shot[shot_index]
+                    << " truncate=" << time_truncate_per_shot[shot_index]
+                    << " reconstruct=" << time_reconstruct_per_shot[shot_index] << std::endl;
+          std::cout.flush();
+        }
// Returning false stops the run once --max-errors is reached.
+        return num_errors < args.max_errors;
+      });
+
// NOTE(review): this check runs only after all shots have been decoded;
// validating --dem-out up front (in Args::validate) would fail faster.
+  if (!args.dem_out_fname.empty()) {
+    throw std::invalid_argument("--dem-out is not supported by tesseract_trellis without path reconstruction.");
+  }
+
+  bool print_final_stats = true;
+  if (!args.stats_out_fname.empty()) {
+    nlohmann::json stats_json = {{"circuit_path", args.circuit_path},
+                                 {"dem_path", args.dem_path},
+                                 {"beam_width", args.beam_width},
+                                 {"sample_seed", args.sample_seed},
+                                 {"sample_num_shots", args.sample_num_shots},
+                                 {"num_threads", args.num_threads},
+                                 {"num_errors", num_errors},
+                                 {"num_low_confidence", num_low_confidence},
+                                 {"num_shots", shot},
+                                 {"total_time_seconds", total_time_seconds}};
+    if (args.stats_out_fname == "-") {
+      std::cout << stats_json << std::endl;
+      print_final_stats = false;
+    } else {
+      std::ofstream out(args.stats_out_fname, std::ofstream::out);
+      out << stats_json << std::endl;
+    }
+  }
+
+  if (print_final_stats) {
+    std::cout << "num_shots = " << shot << " num_low_confidence = " << num_low_confidence;
+    if (has_obs) {
+      std::cout << " num_errors = " << num_errors;
+    }
+    std::cout << " total_time_seconds = " << total_time_seconds << std::endl;
+  }
+}
From b61dfdca8fd94c32031f29f5e8d70e8f3711d73d Mon Sep 17 00:00:00 2001
From: Noah Shutty 
Date: Wed, 8 Apr 2026 01:19:54 +0800
Subject: [PATCH 12/25] add some trellis beam updates

---
 src/py/astar/trellis_beam.py                  | 139 +++++++---
 .../trellis_beam_optimized_suspicious.py      | 237
++++++++++++++++++ 2 files changed, 341 insertions(+), 35 deletions(-) create mode 100644 src/py/astar/trellis_beam_optimized_suspicious.py diff --git a/src/py/astar/trellis_beam.py b/src/py/astar/trellis_beam.py index 11df2f8..1b69a96 100644 --- a/src/py/astar/trellis_beam.py +++ b/src/py/astar/trellis_beam.py @@ -1,21 +1,35 @@ +import argparse import heapq -import sys +import time +from dataclasses import dataclass from operator import itemgetter import stim -def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> bool | None: +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + + +def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> BeamDecodeResult: """ Decodes a syndrome using a dynamic programming sweep with a Top-L beam cutoff. """ + start_time = time.perf_counter() + # 1. Extract the Detector Error Model (flattened, decompose_errors=False) dem = circuit.detector_error_model(decompose_errors=False).flattened() # 2. Parse the DEM into a list of faults faults = [] all_possible_dets_mask = 0 - + for inst in dem: if inst.type != "error": continue @@ -42,41 +56,50 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> for d in actual_dets: actual_dets_mask ^= (1 << d) - # If the quantum computer triggered a detector that our error model says + # If the quantum computer triggered a detector that our error model says # is mathematically impossible to trigger, the syndrome is invalid. if (actual_dets_mask & ~all_possible_dets_mask) != 0: - return None + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=0, + elapsed_seconds=time.perf_counter() - start_time, + ) # 4. Pre-calculate retirement schedules - # retiring_masks[i] stores the bits of detectors that see their final fault at index i. 
retiring_masks = [0] * len(faults) last_seen_index = {} - + for idx, (_, _, _, det_mask) in enumerate(faults): temp = det_mask d_id = 0 - # Extract which bits are set in the mask to find the latest index for each detector while temp > 0: if temp & 1: last_seen_index[d_id] = idx temp >>= 1 d_id += 1 - + for d_id, idx in last_seen_index.items(): retiring_masks[idx] |= (1 << d_id) + active_mask = 0 + max_width = 0 + for i, (_, _, _, det_mask) in enumerate(faults): + active_mask |= det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + # 5. The Beam Search Sweep - # Each beam entry is (active_syndrome_mask, total_probability, logical_bias), - # where logical_bias = P(L0) - P(L1). Total probability is enough for beam - # ranking, and the bias preserves the final logical comparison. beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 for i, (q, p, delta_scale, det_mask) in enumerate(faults): next_probs: dict[int, list[float]] = {} # A. Expand the beam for s, total, delta in beam: - # Fault absent entry = next_probs.get(s) absent_total = total * q absent_delta = delta * q @@ -86,7 +109,6 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> entry[0] += absent_total entry[1] += absent_delta - # Fault present t = s ^ det_mask present_total = total * p present_delta = delta * delta_scale @@ -110,7 +132,6 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> keep_mask = ~retiring_mask for s, (total, delta) in next_probs.items(): - # If the retiring bits don't match our actual observation, kill the state. if (s & retiring_mask) != expected_bits: continue @@ -125,6 +146,18 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> collapsed_probs = next_probs # C. 
Truncate the Beam (Top L Cutoff) + total_mass = sum(total for total, _ in collapsed_probs.values()) + if total_mass == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + dropped_mass = 0.0 if len(collapsed_probs) > L: beam = heapq.nlargest( L, @@ -134,48 +167,84 @@ def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> ), key=itemgetter(1), ) + kept_mass = sum(total for _, total, _ in beam) + dropped_mass = total_mass - kept_mass else: beam = [ (state, total, delta) for state, (total, delta) in collapsed_probs.items() ] + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + # 6. Final Likelihood Comparison - # Since all bits are retired, the only surviving state mask should be exactly 0. _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass if final_delta == 0.0: - return None # Tie or beam missed the correct path entirely - return final_delta < 0.0 - + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + +def run_experiment(circuit_fname: str, L: int, seed=None): + print(f"Running on circuit {circuit_fname}") -def run_experiment(circuit_fname: str, L: int): - """ - Generates a surface code, samples an error, and decodes it using the beam search. 
- """ - # print(f"--- Running Distance {d}, Rounds {r}, Beam Size {L} ---") - print(f'Running on circuit {circuit_fname}') - circuit = stim.Circuit.from_file(circuit_fname) - sampler = circuit.compile_detector_sampler() + sampler = circuit.compile_detector_sampler(seed=seed) syndromes, logicals = sampler.sample(shots=1, separate_observables=True) actual_dets = set(i for i, triggered in enumerate(syndromes[0]) if triggered) actual_logical = logicals[0][0] - predicted_logical = decode_beam_search(circuit, actual_dets, L) + result = decode_beam_search(circuit, actual_dets, L) print(f"Total Detectors: {circuit.num_detectors}") - print(f"Triggered Detectors: {len(actual_dets)}") - print(f"Predicted Logical: {predicted_logical}") + print(f"Seed: {seed}") + print(f"Triggered Detectors: {len(actual_dets)}") + print(f"Width: {result.max_width}") + print(f"Predicted Logical: {result.predicted_logical}") print(f"Actual Logical: {bool(actual_logical)}") - - if predicted_logical is None: + print(f"Certified: {result.certified}") + print(f"Margin: {result.margin:.6e}") + print(f"Discarded Mass: {result.discarded_mass:.6e}") + print(f"Elapsed Seconds: {result.elapsed_seconds:.6f}") + + if result.predicted_logical is None: print("Result: DECODE FAILED (Tie or Beam too narrow)") else: - print(f"Result: {'SUCCESS' if predicted_logical == actual_logical else 'LOGICAL ERROR'}") + print(f"Result: {'SUCCESS' if result.predicted_logical == actual_logical else 'LOGICAL ERROR'}") print() -if __name__ == '__main__': - run_experiment(sys.argv[1], L=1000) + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run one-shot trellis beam decoding on a Stim circuit.") + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument("--seed", type=int, default=None, help="Sampler seed.") + return parser.parse_args() + + +if __name__ 
== "__main__": + args = _parse_args() + run_experiment(args.circuit, L=args.beam, seed=args.seed) diff --git a/src/py/astar/trellis_beam_optimized_suspicious.py b/src/py/astar/trellis_beam_optimized_suspicious.py new file mode 100644 index 0000000..c1f10e7 --- /dev/null +++ b/src/py/astar/trellis_beam_optimized_suspicious.py @@ -0,0 +1,237 @@ +import argparse +import time +from dataclasses import dataclass + +import stim + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + + +def decode_beam_search(circuit: stim.Circuit, actual_dets: set[int], L: int) -> BeamDecodeResult: + """ + Decodes a syndrome using a dynamic programming sweep with a Top-L beam cutoff. + """ + start_time = time.perf_counter() + + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults = [] + all_possible_dets_mask = 0 + + for inst in dem: + if inst.type != "error": + continue + + p = inst.args_copy()[0] + det_mask = 0 + flip_l0 = 0 + + for t in inst.targets_copy(): + if t.is_separator(): + continue + if t.is_relative_detector_id(): + det_mask ^= (1 << t.val) + elif t.is_logical_observable_id() and t.val == 0: + flip_l0 ^= 1 + + q = 1.0 - p + delta_scale = -p if flip_l0 else p + faults.append((q, p, delta_scale, det_mask)) + all_possible_dets_mask |= det_mask + + actual_dets_mask = 0 + for d in actual_dets: + actual_dets_mask ^= (1 << d) + + if (actual_dets_mask & ~all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=0, + elapsed_seconds=time.perf_counter() - start_time, + ) + + retiring_masks = [0] * len(faults) + last_seen_index = {} + + for idx, (_, _, _, det_mask) in enumerate(faults): + temp = det_mask + d_id = 0 + while temp > 0: + if temp & 1: + last_seen_index[d_id] = idx + temp >>= 1 + d_id += 1 + + for d_id, idx in 
last_seen_index.items(): + retiring_masks[idx] |= (1 << d_id) + + active_mask = 0 + max_width = 0 + for i, (_, _, _, det_mask) in enumerate(faults): + active_mask |= det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + for i, (q, p, delta_scale, det_mask) in enumerate(faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = retiring_masks[i] + + if retiring_mask == 0: + for s, total, delta in beam: + absent_total = total * q + absent_delta = delta * q + total_mass += absent_total + entry = collapsed_probs.get(s) + if entry is None: + collapsed_probs[s] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + t = s ^ det_mask + present_total = total * p + present_delta = delta * delta_scale + total_mass += present_total + entry = collapsed_probs.get(t) + if entry is None: + collapsed_probs[t] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for s, total, delta in beam: + absent_total = total * q + absent_delta = delta * q + if (s & retiring_mask) == expected_bits: + shrunk_s = s & keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk_s) + if entry is None: + collapsed_probs[shrunk_s] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + t = s ^ det_mask + present_total = total * p + present_delta = delta * delta_scale + if (t & retiring_mask) == expected_bits: + shrunk_t = t & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk_t) + if entry is None: + collapsed_probs[shrunk_t] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + ranked_states = [(total, state, delta) for state, (total, delta) in collapsed_probs.items()] + 
if total_mass == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + dropped_mass = 0.0 + if len(ranked_states) > L: + ranked_states.sort(reverse=True) + kept = ranked_states[:L] + beam = [(state, total, delta) for total, state, delta in kept] + kept_mass = sum(total for total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + beam = [(state, total, delta) for total, state, delta in ranked_states] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + +def run_experiment(circuit_fname: str, L: int, seed=None): + print(f"Running on circuit {circuit_fname}") + + circuit = stim.Circuit.from_file(circuit_fname) + + sampler = circuit.compile_detector_sampler(seed=seed) + syndromes, logicals = sampler.sample(shots=1, separate_observables=True) + + actual_dets = set(i for i, triggered in enumerate(syndromes[0]) if triggered) + actual_logical = logicals[0][0] + + result = decode_beam_search(circuit, actual_dets, L) + + print(f"Total Detectors: {circuit.num_detectors}") + print(f"Seed: {seed}") + print(f"Triggered Detectors: {len(actual_dets)}") + 
print(f"Width: {result.max_width}") + print(f"Predicted Logical: {result.predicted_logical}") + print(f"Actual Logical: {bool(actual_logical)}") + print(f"Certified: {result.certified}") + print(f"Margin: {result.margin:.6e}") + print(f"Discarded Mass: {result.discarded_mass:.6e}") + print(f"Elapsed Seconds: {result.elapsed_seconds:.6f}") + + if result.predicted_logical is None: + print("Result: DECODE FAILED (Tie or Beam too narrow)") + else: + print(f"Result: {'SUCCESS' if result.predicted_logical == actual_logical else 'LOGICAL ERROR'}") + print() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run one-shot trellis beam decoding on a Stim circuit.") + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument("--seed", type=int, default=None, help="Sampler seed.") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + run_experiment(args.circuit, L=args.beam, seed=args.seed) From d214008a80dcd033a67a7a02e1fcac50f00882c9 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Wed, 8 Apr 2026 01:20:20 +0800 Subject: [PATCH 13/25] add trellis beam prototype using detcost for pruning --- src/py/astar/trellis_beam_detcost_ranked.py | 390 ++++++++++++++++++++ 1 file changed, 390 insertions(+) create mode 100644 src/py/astar/trellis_beam_detcost_ranked.py diff --git a/src/py/astar/trellis_beam_detcost_ranked.py b/src/py/astar/trellis_beam_detcost_ranked.py new file mode 100644 index 0000000..7315a4d --- /dev/null +++ b/src/py/astar/trellis_beam_detcost_ranked.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 + +import argparse +import math +import time +from dataclasses import dataclass + +import stim + + +@dataclass(frozen=True) +class Fault: + q: float + p: float + delta_scale: float + det_mask: int + likelihood_cost: float + + +@dataclass(frozen=True) +class 
BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + + +@dataclass(frozen=True) +class SampledShot: + det_mask: int + actual_logical: bool + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _detectors_from_mask(mask: int) -> list[int]: + detectors: list[int] = [] + while mask: + low_bit = mask & -mask + detectors.append(low_bit.bit_length() - 1) + mask ^= low_bit + return detectors + + +def _parse_circuit(circuit: stim.Circuit) -> tuple[list[Fault], list[int], list[int], int]: + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + last_seen_index: dict[int, int] = {} + + for inst in dem: + if inst.type != "error": + continue + + p = float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + faults.append( + Fault( + q=1.0 - p, + p=p, + delta_scale=(-p if flip_l0 else p), + det_mask=det_mask, + likelihood_cost=_likelihood_cost(p), + ) + ) + all_possible_dets_mask |= det_mask + + for det_id in _detectors_from_mask(det_mask): + last_seen_index[det_id] = len(faults) - 1 + + retiring_masks = [0] * len(faults) + for det_id, index in last_seen_index.items(): + retiring_masks[index] |= 1 << det_id + + live_masks_after = [0] * (len(faults) + 1) + active_mask = 0 + max_width = 0 + for i, fault in enumerate(faults): + active_mask |= fault.det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + live_masks_after[i + 1] = active_mask + + return faults, retiring_masks, live_masks_after, max_width 
+ + +def _future_detcost_by_detector(faults: list[Fault], num_detectors: int) -> list[list[float]]: + future_detcost = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return future_detcost + + +def _detcost_penalty(mismatch_mask: int, future_detcost: list[float]) -> float: + total = 0.0 + pending = mismatch_mask + + while pending: + low_bit = pending & -pending + detector = low_bit.bit_length() - 1 + pending ^= low_bit + + best = future_detcost[detector] + if best == math.inf: + return math.inf + total += best + + return total + + +def sample_shots(circuit: stim.Circuit, shots: int, seed: int | None) -> list[SampledShot]: + sampler = circuit.compile_detector_sampler(seed=seed) + syndromes, logicals = sampler.sample(shots=shots, separate_observables=True) + out: list[SampledShot] = [] + for shot_index in range(shots): + det_mask = 0 + for detector, fired in enumerate(syndromes[shot_index]): + if fired: + det_mask ^= 1 << detector + out.append(SampledShot(det_mask=det_mask, actual_logical=bool(logicals[shot_index][0]))) + return out + + +def decode_beam_search_detcost_ranked( + circuit: stim.Circuit, + actual_dets: set[int], + L: int, +) -> BeamDecodeResult: + start_time = time.perf_counter() + + faults, retiring_masks, live_masks_after, max_width = _parse_circuit(circuit) + + actual_dets_mask = 0 + for detector in actual_dets: + actual_dets_mask ^= 1 << detector + + all_possible_dets_mask = 0 + for fault in faults: + all_possible_dets_mask |= fault.det_mask + if (actual_dets_mask & ~all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, 
+ certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=0, + elapsed_seconds=time.perf_counter() - start_time, + ) + + future_detcost = _future_detcost_by_detector(faults, circuit.num_detectors) + + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + for i, fault in enumerate(faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = retiring_masks[i] + + if retiring_mask == 0: + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + total_mass += absent_total + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + total_mass += present_total + entry = collapsed_probs.get(toggled) + if entry is None: + collapsed_probs[toggled] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + if (state & retiring_mask) == expected_bits: + shrunk = state & keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + if (toggled & retiring_mask) == expected_bits: + shrunk = toggled & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + if total_mass == 0.0: + return BeamDecodeResult( + 
predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + ranked_states: list[tuple[float, float, int, float]] = [] + live_target_mask = actual_dets_mask & live_masks_after[i + 1] + next_future_detcost = future_detcost[i + 1] + for state, (total, delta) in collapsed_probs.items(): + mismatch_mask = state ^ live_target_mask + penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost) + if penalty == math.inf: + rank_score = -math.inf + else: + rank_score = math.log(total) - penalty + ranked_states.append((rank_score, total, state, delta)) + + dropped_mass = 0.0 + if len(ranked_states) > L: + ranked_states.sort(reverse=True) + kept = ranked_states[:L] + beam = [(state, total, delta) for _, total, state, delta in kept] + kept_mass = sum(total for _, total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + beam = [(state, total, delta) for _, total, state, delta in ranked_states] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + return BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=max_width, + elapsed_seconds=time.perf_counter() - start_time, + ) + + +def run_experiment( + circuit_fname: str, + L: int, + sample_num_shots: int, + sample_seed: int | None = None, + print_per_shot: 
bool = False, +) -> None: + circuit = stim.Circuit.from_file(circuit_fname) + shots = sample_shots(circuit, shots=sample_num_shots, seed=sample_seed) + + print(f"Running on circuit {circuit_fname}") + print(f"Total Detectors: {circuit.num_detectors}") + print(f"Sample Seed: {sample_seed}") + print(f"Num Shots: {len(shots)}") + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + + for shot_index, shot in enumerate(shots): + actual_dets = set(_detectors_from_mask(shot.det_mask)) + result = decode_beam_search_detcost_ranked(circuit, actual_dets, L) + success = result.predicted_logical == shot.actual_logical if result.predicted_logical is not None else False + low_confidence = result.predicted_logical is None + + if low_confidence: + num_low_confidence += 1 + elif not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + total_triggered += len(actual_dets) + max_width_seen = max(max_width_seen, result.max_width) + + shots_done = shot_index + 1 + resolved_shots = shots_done - num_low_confidence + error_rate_so_far = num_errors / resolved_shots if resolved_shots else 0.0 + print( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} error_rate_so_far={error_rate_so_far:.6f} " + f"elapsed_total_seconds={total_elapsed:.6f}" + ) + + if print_per_shot: + print( + f"shot={shot_index} triggered_detectors={len(actual_dets)} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}" + ) + + print(f"Beam: {L}") + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}") + print(f"Max Width: {max_width_seen}") + print(f"Certified Shots: {num_certified}") + 
print(f"Low Confidence: {num_low_confidence}") + print(f"Logical Errors: {num_errors}") + print(f"Total Seconds: {total_elapsed:.6f}") + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}") + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run trellis beam decoding ranked by mass minus a detcost-style future penalty." + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument("--sample-num-shots", type=int, default=1, help="Number of sampled shots.") + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument("--print-per-shot", action="store_true", help="Print a detailed line per decoded shot.") + return parser.parse_args() + + +if __name__ == "__main__": + args = _parse_args() + if args.sample_num_shots <= 0: + raise ValueError("--sample-num-shots must be positive.") + run_experiment( + args.circuit, + L=args.beam, + sample_num_shots=args.sample_num_shots, + sample_seed=args.sample_seed, + print_per_shot=args.print_per_shot, + ) From da4a0a3ba1cd0002ccafc2fe6a4083d4c9d416bb Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 13 Apr 2026 14:07:56 +0800 Subject: [PATCH 14/25] add very weak admissible A* heuristic dual LP penalty to src/tesseract_trellis --- src/tesseract_trellis.cc | 366 +++++++++++++++++++++++++++------- src/tesseract_trellis.h | 38 +++- src/tesseract_trellis_main.cc | 19 +- 3 files changed, 340 insertions(+), 83 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index d64bfc0..e2b3055 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -15,6 +15,8 @@ #include "tesseract_trellis.h" #include +#include +#include #include #include #include @@ -39,6 +41,7 @@ namespace { struct Fault { size_t error_index; + double likelihood_cost; double log_q; double 
log_p; uint64_t obs_mask; @@ -70,11 +73,13 @@ struct SmallLayerFault { struct PackedMass { uint64_t key; double mass; + double penalty; }; struct StateMass { uint64_t state; double mass; + double penalty; }; struct ObsAggregate { @@ -108,6 +113,7 @@ std::vector parse_faults(const std::vector& errors, size_t if (p <= 0) continue; Fault fault; fault.error_index = error_index; + fault.likelihood_cost = error.likelihood_cost; fault.log_q = std::log1p(-p); fault.log_p = std::log(p); fault.obs_mask = 0; @@ -208,10 +214,9 @@ std::vector build_layer_faults(const std::vector& faults, siz return layers; } -bool try_build_small_layer_faults(const std::vector& faults, size_t num_detectors, - const std::vector& detections, - std::vector* layers, - size_t* max_frontier_width_seen) { +bool build_small_layer_templates(const std::vector& faults, size_t num_detectors, + std::vector* layers, + size_t* max_frontier_width_seen) { std::vector last_seen(num_detectors, std::numeric_limits::max()); for (size_t i = 0; i < faults.size(); ++i) { for (int d : faults[i].detectors) { @@ -219,14 +224,6 @@ bool try_build_small_layer_faults(const std::vector& faults, size_t num_d } } - boost::dynamic_bitset<> actual_dets(num_detectors); - for (uint64_t d : detections) { - if (d >= num_detectors) { - throw std::runtime_error("Detector index out of range."); - } - actual_dets.flip(d); - } - std::vector active_detectors; active_detectors.reserve(num_detectors); std::vector global_to_local(num_detectors, -1); @@ -235,6 +232,7 @@ bool try_build_small_layer_faults(const std::vector& faults, size_t num_d *max_frontier_width_seen = 0; for (size_t i = 0; i < faults.size(); ++i) { + const size_t previous_width = active_detectors.size(); for (int d : faults[i].detectors) { if (global_to_local[d] == -1) { global_to_local[d] = active_detectors.size(); @@ -247,15 +245,17 @@ bool try_build_small_layer_faults(const std::vector& faults, size_t num_d return false; } - SmallLayerFault layer{ - .error_index = 
faults[i].error_index, + TesseractTrellisSmallLayerTemplate layer{ .q = std::exp(faults[i].log_q), .p = std::exp(faults[i].log_p), .obs_flip_bit = faults[i].obs_mask & 1, .local_det_mask = 0, .retiring_mask = 0, - .expected_retiring_bits = 0, + .previous_width = previous_width, .surviving_local_indices = {}, + .current_active_detectors = active_detectors, + .next_frontier_costs = {}, + .detcost_transition = {}, }; for (int d : faults[i].detectors) { layer.local_det_mask ^= uint64_t{1} << global_to_local[d]; @@ -264,11 +264,8 @@ bool try_build_small_layer_faults(const std::vector& faults, size_t num_d const int d = active_detectors[local]; if (last_seen[d] == i) { layer.retiring_mask ^= uint64_t{1} << local; - if (actual_dets[d]) { - layer.expected_retiring_bits ^= uint64_t{1} << local; - } } else { - layer.surviving_local_indices.push_back(local); + layer.surviving_local_indices.push_back((uint8_t)local); } } @@ -295,6 +292,144 @@ uint64_t project_small_state(uint64_t state, const std::vector& survivi return out; } +uint64_t compute_target_bits(const std::vector& active_detectors, + const boost::dynamic_bitset<>& actual_dets) { + uint64_t target_bits = 0; + for (size_t local = 0; local < active_detectors.size(); ++local) { + if (actual_dets[(size_t)active_detectors[local]]) { + target_bits |= uint64_t{1} << local; + } + } + return target_bits; +} + +double compute_penalty_from_scratch(uint64_t mismatch_mask, + const std::vector& aligned_future_costs) { + double total = 0.0; + while (mismatch_mask) { + uint64_t low_bit = mismatch_mask & -mismatch_mask; + int detector = std::countr_zero(low_bit); + mismatch_mask ^= low_bit; + double best = aligned_future_costs[(size_t)detector]; + if (best == INF) { + return INF; + } + total += best; + } + return total; +} + +double advance_penalty_row(double current_penalty, uint64_t current_mismatch, + const TesseractTrellisDetcostTransition& transition) { + if (current_penalty == INF) { + return INF; + } + double total = 
current_penalty; + for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) { + const uint64_t local_bit = uint64_t{1} << transition.fault_local_indices[k]; + if ((current_mismatch & local_bit) == 0) { + continue; + } + double current_cost = transition.current_costs[k]; + double next_cost = transition.next_costs[k]; + if (next_cost == INF) { + return INF; + } + total += next_cost - current_cost; + } + return total; +} + +double adjust_penalty_for_branch(double parent_penalty_next_row, uint64_t base_state, + uint64_t current_target_bits, uint64_t next_target_bits, + bool present_branch, uint64_t projected_state, + const TesseractTrellisSmallLayerTemplate& layer) { + if (parent_penalty_next_row == INF) { + return compute_penalty_from_scratch(projected_state ^ next_target_bits, layer.next_frontier_costs); + } + + double total = parent_penalty_next_row; + for (size_t k = 0; k < layer.detcost_transition.fault_local_indices.size(); ++k) { + uint8_t local = layer.detcost_transition.fault_local_indices[k]; + int8_t next_local = layer.detcost_transition.next_local_indices[k]; + if (next_local < 0) { + continue; + } + + const uint64_t state_bit = + local < layer.previous_width ? ((base_state >> local) & 1ULL) : 0ULL; + const uint64_t prev_mismatch = + local < layer.previous_width ? (state_bit ^ ((current_target_bits >> local) & 1ULL)) : 0ULL; + const uint64_t child_bit = state_bit ^ (present_branch ? 
1ULL : 0ULL); + const uint64_t child_mismatch = child_bit ^ ((next_target_bits >> next_local) & 1ULL); + if (prev_mismatch == child_mismatch) { + continue; + } + + double next_cost = layer.detcost_transition.next_costs[k]; + if (child_mismatch) { + if (next_cost == INF) { + return INF; + } + total += next_cost; + } else { + total -= next_cost; + } + } + return total; +} + +void build_future_detcost_transitions(const std::vector& faults, size_t num_detectors, + std::vector* layers, + std::vector* initial_future_detcost) { + std::vector current_row(num_detectors, INF); + for (size_t fault_index = faults.size(); fault_index-- > 0;) { + auto& layer = (*layers)[fault_index]; + const auto& fault = faults[fault_index]; + + layer.next_frontier_costs.resize(layer.surviving_local_indices.size(), INF); + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + int global_detector = layer.current_active_detectors[layer.surviving_local_indices[next_local]]; + layer.next_frontier_costs[next_local] = current_row[(size_t)global_detector]; + } + + std::array current_to_next; + current_to_next.fill(-1); + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + current_to_next[layer.surviving_local_indices[next_local]] = (int8_t)next_local; + } + + layer.detcost_transition.fault_local_indices.clear(); + layer.detcost_transition.next_local_indices.clear(); + layer.detcost_transition.current_costs.clear(); + layer.detcost_transition.next_costs.clear(); + layer.detcost_transition.fault_local_indices.reserve(fault.detectors.size()); + layer.detcost_transition.next_local_indices.reserve(fault.detectors.size()); + layer.detcost_transition.current_costs.reserve(fault.detectors.size()); + layer.detcost_transition.next_costs.reserve(fault.detectors.size()); + + if (!fault.detectors.empty()) { + double ecost = fault.likelihood_cost / fault.detectors.size(); + for (int detector : fault.detectors) { + auto it = 
std::find(layer.current_active_detectors.begin(), layer.current_active_detectors.end(), + detector); + if (it == layer.current_active_detectors.end()) { + throw std::runtime_error("Missing detector in active frontier while preparing detcost."); + } + uint8_t local = (uint8_t)std::distance(layer.current_active_detectors.begin(), it); + double next_cost = current_row[(size_t)detector]; + double current_cost = std::min(ecost, next_cost); + layer.detcost_transition.fault_local_indices.push_back(local); + layer.detcost_transition.next_local_indices.push_back(current_to_next[local]); + layer.detcost_transition.current_costs.push_back(current_cost); + layer.detcost_transition.next_costs.push_back(next_cost); + current_row[(size_t)detector] = current_cost; + } + } + } + *initial_future_detcost = std::move(current_row); +} + uint64_t pack_small_key(uint64_t state, uint64_t obs_flip_bit) { return (state << 1) | (obs_flip_bit & 1ULL); } @@ -333,16 +468,18 @@ std::vector merge_equal_keys(std::vector& items) { merged.reserve(items.size()); uint64_t cur_key = items[0].key; double cur_mass = items[0].mass; + double cur_penalty = items[0].penalty; for (size_t i = 1; i < items.size(); ++i) { if (items[i].key == cur_key) { cur_mass += items[i].mass; } else { - merged.push_back({cur_key, cur_mass}); + merged.push_back({cur_key, cur_mass, cur_penalty}); cur_key = items[i].key; cur_mass = items[i].mass; + cur_penalty = items[i].penalty; } } - merged.push_back({cur_key, cur_mass}); + merged.push_back({cur_key, cur_mass, cur_penalty}); return merged; } @@ -354,21 +491,44 @@ std::vector accumulate_state_masses_from_entries(const std::vector& entries, size_t beam_width) { +double branch_score(const PackedMass& item, TesseractTrellisRankingMode ranking_mode) { + if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { + return item.mass; + } + if (item.penalty == INF || item.mass == 0.0) { + return -INF; + } + return std::log(item.mass) - item.penalty; +} + +double state_score(const 
StateMass& item, TesseractTrellisRankingMode ranking_mode) { + if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { + return item.mass; + } + if (item.penalty == INF || item.mass == 0.0) { + return -INF; + } + return std::log(item.mass) - item.penalty; +} + +void keep_top_states(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { if (entries.empty()) { return; } @@ -377,7 +537,9 @@ void keep_top_states(std::vector& entries, size_t beam_width) { return; } std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(), - [](const StateMass& a, const StateMass& b) { return a.mass > b.mass; }); + [ranking_mode](const StateMass& a, const StateMass& b) { + return state_score(a, ranking_mode) > state_score(b, ranking_mode); + }); totals.resize(beam_width); std::sort(totals.begin(), totals.end(), [](const StateMass& a, const StateMass& b) { return a.state < b.state; @@ -398,12 +560,15 @@ void keep_top_states(std::vector& entries, size_t beam_width) { entries = std::move(kept); } -void keep_top_branch_entries(std::vector& entries, size_t beam_width) { +void keep_top_branch_entries(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { if (entries.size() <= beam_width) { return; } std::nth_element(entries.begin(), entries.begin() + beam_width, entries.end(), - [](const PackedMass& a, const PackedMass& b) { return a.mass > b.mass; }); + [ranking_mode](const PackedMass& a, const PackedMass& b) { + return branch_score(a, ranking_mode) > branch_score(b, ranking_mode); + }); entries.resize(beam_width); } @@ -418,6 +583,26 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) errors = get_errors_from_dem(config.dem.flattened()); num_detectors = config.dem.count_detectors(); num_observables = config.dem.count_observables(); + + all_possible_detectors = boost::dynamic_bitset<>(num_detectors); + for (const auto& error : errors) { + for (int d : error.symptom.detectors) 
{ + all_possible_detectors[(size_t)d] = true; + } + } + + auto faults = parse_faults(errors, num_observables); + size_t small_frontier_width = 0; + has_small_layer_templates = + num_observables <= 1 && + build_small_layer_templates(faults, num_detectors, &small_layer_templates, &small_frontier_width); + if (has_small_layer_templates) { + build_future_detcost_transitions(faults, num_detectors, &small_layer_templates, + &initial_future_detcost); + } else if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { + throw std::invalid_argument( + "future-detcost ranking is currently implemented only for the packed small trellis path"); + } } void TesseractTrellisDecoder::decode_shot(const std::vector& detections) { @@ -434,66 +619,93 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection total_mass_obs0 = 0; total_mass_obs1 = 0; - auto faults = parse_faults(errors, num_observables); - - std::unordered_set all_possible_dets; - for (const auto& error : errors) { - for (int d : error.symptom.detectors) { - all_possible_dets.insert(d); - } - } + boost::dynamic_bitset<> actual_dets(num_detectors); for (uint64_t d : detections) { - if (!all_possible_dets.contains(int(d))) { + if (d >= num_detectors || !all_possible_detectors[d]) { low_confidence_flag = true; return; } + actual_dets.flip((size_t)d); } - std::vector small_layers; - if (num_observables <= 1 && - try_build_small_layer_faults(faults, num_detectors, detections, &small_layers, - &max_frontier_width_seen)) { + if (has_small_layer_templates) { + max_frontier_width_seen = 0; + std::vector current_target_bits_per_layer(small_layer_templates.size()); + std::vector next_target_bits_per_layer(small_layer_templates.size()); + std::vector expected_retiring_bits_per_layer(small_layer_templates.size()); + for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { + const auto& layer = small_layer_templates[layer_index]; + max_frontier_width_seen = 
std::max(max_frontier_width_seen, layer.current_active_detectors.size()); + uint64_t current_target_bits = compute_target_bits(layer.current_active_detectors, actual_dets); + current_target_bits_per_layer[layer_index] = current_target_bits; + expected_retiring_bits_per_layer[layer_index] = current_target_bits & layer.retiring_mask; + + uint64_t next_target_bits = 0; + for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { + uint8_t current_local = layer.surviving_local_indices[next_local]; + next_target_bits |= ((current_target_bits >> current_local) & 1ULL) << next_local; + } + next_target_bits_per_layer[layer_index] = next_target_bits; + } + std::vector beam_entries; - beam_entries.push_back({pack_small_key(0, 0), 1.0}); + double initial_penalty = 0.0; + if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { + initial_penalty = compute_penalty_from_scratch( + current_target_bits_per_layer.empty() ? 0 : current_target_bits_per_layer.front(), + initial_future_detcost); + } + beam_entries.push_back({pack_small_key(0, 0), 1.0, initial_penalty}); max_beam_size_seen = 1; - for (size_t layer_index = 0; layer_index < small_layers.size(); ++layer_index) { - const auto& layer = small_layers[layer_index]; + for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { + const auto& layer = small_layer_templates[layer_index]; + const uint64_t current_target_bits = current_target_bits_per_layer[layer_index]; + const uint64_t next_target_bits = next_target_bits_per_layer[layer_index]; + const uint64_t expected_retiring_bits = expected_retiring_bits_per_layer[layer_index]; auto t0 = std::chrono::high_resolution_clock::now(); - std::vector expanded_entries; - expanded_entries.reserve(beam_entries.size() * 2); + std::vector next_entries; + next_entries.reserve(beam_entries.size() * 2); for (const auto& item : beam_entries) { ++num_states_expanded; - uint64_t base_state = 
unpack_small_state(item.key); - expanded_entries.push_back({pack_small_key(base_state, unpack_small_obs(item.key)), - item.mass * layer.q}); - uint64_t present_key = - pack_small_key(base_state ^ layer.local_det_mask, - unpack_small_obs(item.key) ^ layer.obs_flip_bit); - expanded_entries.push_back({present_key, item.mass * layer.p}); - } - auto t1 = std::chrono::high_resolution_clock::now(); - time_expand_seconds += - std::chrono::duration_cast(t1 - t0).count() / 1e6; + const uint64_t base_state = unpack_small_state(item.key); + const uint64_t base_obs = unpack_small_obs(item.key); + const double parent_penalty_next_row = + config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked + ? advance_penalty_row(item.penalty, base_state ^ current_target_bits, + layer.detcost_transition) + : 0.0; + + if (((base_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) { + uint64_t projected_state = project_small_state(base_state, layer.surviving_local_indices); + double penalty = + config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked + ? adjust_penalty_for_branch(parent_penalty_next_row, base_state, current_target_bits, + next_target_bits, false, projected_state, layer) + : 0.0; + next_entries.push_back( + {pack_small_key(projected_state, base_obs), item.mass * layer.q, penalty}); + } - std::vector next_entries; - next_entries.reserve(expanded_entries.size()); - for (const auto& item : expanded_entries) { - uint64_t state = unpack_small_state(item.key); - if (((state ^ layer.expected_retiring_bits) & layer.retiring_mask) != 0) { - continue; + uint64_t toggled_state = base_state ^ layer.local_det_mask; + if (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) { + uint64_t projected_state = project_small_state(toggled_state, layer.surviving_local_indices); + double penalty = + config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked + ? 
adjust_penalty_for_branch(parent_penalty_next_row, base_state, current_target_bits, + next_target_bits, true, projected_state, layer) + : 0.0; + next_entries.push_back({pack_small_key(projected_state, base_obs ^ layer.obs_flip_bit), + item.mass * layer.p, penalty}); } - uint64_t projected_state = project_small_state(state, layer.surviving_local_indices); - uint64_t projected_key = pack_small_key(projected_state, unpack_small_obs(item.key)); - next_entries.push_back({projected_key, item.mass}); } - auto t1b = std::chrono::high_resolution_clock::now(); + auto t1 = std::chrono::high_resolution_clock::now(); time_collapse_seconds += - std::chrono::duration_cast(t1b - t1).count() / 1e6; + std::chrono::duration_cast(t1 - t0).count() / 1e6; beam_entries = std::move(next_entries); bool at_checkpoint = ((layer_index + 1) % config.merge_interval == 0) || - (layer_index + 1 == small_layers.size()); + (layer_index + 1 == small_layer_templates.size()); if (!at_checkpoint) { max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); if (beam_entries.empty()) { @@ -504,25 +716,25 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection } auto t2a = std::chrono::high_resolution_clock::now(); - if (config.prune_mode != TesseractTrellisPruneMode::kNoMerge) { + if (config.prune_mode != TesseractTrellisPruneMode::NoMerge) { beam_entries = merge_equal_keys(beam_entries); } auto t2 = std::chrono::high_resolution_clock::now(); time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; - if (config.prune_mode == TesseractTrellisPruneMode::kMergedStates) { - keep_top_states(beam_entries, config.beam_width); - } else if (config.prune_mode == TesseractTrellisPruneMode::kBranchEntries || - config.prune_mode == TesseractTrellisPruneMode::kNoMerge) { - keep_top_branch_entries(beam_entries, config.beam_width); + if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { + keep_top_states(beam_entries, config.beam_width, 
config.ranking_mode); + } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || + config.prune_mode == TesseractTrellisPruneMode::NoMerge) { + keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); } normalize_items(beam_entries); if (beam_entries.empty()) { low_confidence_flag = true; return; } - if (config.prune_mode == TesseractTrellisPruneMode::kNoMerge) { + if (config.prune_mode == TesseractTrellisPruneMode::NoMerge) { num_states_merged += beam_entries.size(); max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); } else { @@ -536,7 +748,8 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection } auto tr0 = std::chrono::high_resolution_clock::now(); - for (const auto& [packed_key, mass] : beam_entries) { + for (const auto& [packed_key, mass, penalty] : beam_entries) { + (void)penalty; if (unpack_small_state(packed_key) != 0) { continue; } @@ -555,6 +768,7 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection time_reconstruct_seconds += std::chrono::duration_cast(tr1 - tr0).count() / 1e6; } else { + auto faults = parse_faults(errors, num_observables); auto layers = build_layer_faults(faults, num_detectors, detections, &max_frontier_width_seen); std::unordered_map, FrontierAggregate> beam; FrontierAggregate init; diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h index 491d74f..7715b98 100644 --- a/src/tesseract_trellis.h +++ b/src/tesseract_trellis.h @@ -22,9 +22,34 @@ #include "stim.h" enum class TesseractTrellisPruneMode { - kMergedStates, - kBranchEntries, - kNoMerge, + MergedStates, + BranchEntries, + NoMerge, +}; + +enum class TesseractTrellisRankingMode { + MassOnly, + FutureDetcostRanked, +}; + +struct TesseractTrellisDetcostTransition { + std::vector fault_local_indices; + std::vector next_local_indices; + std::vector current_costs; + std::vector next_costs; +}; + +struct TesseractTrellisSmallLayerTemplate { + double q = 0; + double p = 0; + 
uint64_t obs_flip_bit = 0; + uint64_t local_det_mask = 0; + uint64_t retiring_mask = 0; + size_t previous_width = 0; + std::vector surviving_local_indices; + std::vector current_active_detectors; + std::vector next_frontier_costs; + TesseractTrellisDetcostTransition detcost_transition; }; struct TesseractTrellisConfig { @@ -32,7 +57,8 @@ struct TesseractTrellisConfig { size_t beam_width = 1024; size_t merge_interval = 1; bool verbose = false; - TesseractTrellisPruneMode prune_mode = TesseractTrellisPruneMode::kMergedStates; + TesseractTrellisPruneMode prune_mode = TesseractTrellisPruneMode::MergedStates; + TesseractTrellisRankingMode ranking_mode = TesseractTrellisRankingMode::MassOnly; }; struct TesseractTrellisDecoder { @@ -62,6 +88,10 @@ struct TesseractTrellisDecoder { std::vector errors; size_t num_observables = 0; size_t num_detectors = 0; + boost::dynamic_bitset<> all_possible_detectors; + bool has_small_layer_templates = false; + std::vector small_layer_templates; + std::vector initial_future_detcost; }; #endif // TESSERACT_TRELLIS_DECODER_H diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc index b252e29..d4362aa 100644 --- a/src/tesseract_trellis_main.cc +++ b/src/tesseract_trellis_main.cc @@ -27,12 +27,18 @@ namespace { TesseractTrellisPruneMode parse_prune_mode(const std::string& value) { - if (value == "merged") return TesseractTrellisPruneMode::kMergedStates; - if (value == "branch") return TesseractTrellisPruneMode::kBranchEntries; - if (value == "none") return TesseractTrellisPruneMode::kNoMerge; + if (value == "merged") return TesseractTrellisPruneMode::MergedStates; + if (value == "branch") return TesseractTrellisPruneMode::BranchEntries; + if (value == "none") return TesseractTrellisPruneMode::NoMerge; throw std::invalid_argument("Unknown trellis prune mode: " + value); } +TesseractTrellisRankingMode parse_ranking_mode(const std::string& value) { + if (value == "mass") return TesseractTrellisRankingMode::MassOnly; + if 
(value == "future-detcost") return TesseractTrellisRankingMode::FutureDetcostRanked; + throw std::invalid_argument("Unknown trellis ranking mode: " + value); +} + } // namespace struct Args { @@ -61,6 +67,7 @@ struct Args { size_t beam_width = 1024; size_t merge_interval = 1; std::string prune_mode = "merged"; + std::string ranking_mode = "mass"; bool verbose = false; bool print_stats = false; @@ -115,6 +122,7 @@ struct Args { throw std::invalid_argument("--merge-interval must be at least 1."); } parse_prune_mode(prune_mode); + parse_ranking_mode(ranking_mode); } void extract(TesseractTrellisConfig& config, std::vector& shots, @@ -150,6 +158,7 @@ struct Args { config.merge_interval = merge_interval; config.verbose = verbose; config.prune_mode = parse_prune_mode(prune_mode); + config.ranking_mode = parse_ranking_mode(ranking_mode); if (sample_num_shots > 0) { assert(!circuit_path.empty()); @@ -272,6 +281,10 @@ int main(int argc, char* argv[]) { .help("Trellis pruning mode: merged, branch, or none") .default_value(std::string("merged")) .store_into(args.prune_mode); + program.add_argument("--ranking-mode") + .help("Trellis ranking mode: mass or future-detcost") + .default_value(std::string("mass")) + .store_into(args.ranking_mode); program.add_argument("--verbose").flag().store_into(args.verbose); program.add_argument("--print-stats").flag().store_into(args.print_stats); From 53ca8da900bb1148b57b142c79a9f6f26862af4c Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Mon, 13 Apr 2026 21:28:30 -0700 Subject: [PATCH 15/25] some optimizations to tesseract trellis --- src/BUILD | 1 - src/tesseract_trellis.cc | 101 ++++++++------------------------------- 2 files changed, 20 insertions(+), 82 deletions(-) diff --git a/src/BUILD b/src/BUILD index 713f1ae..09d559e 100644 --- a/src/BUILD +++ b/src/BUILD @@ -194,7 +194,6 @@ cc_binary( ], ) - cc_binary( name = "tesseract_ftl", srcs = ["tesseract_ftl_main.cc"], diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 
e2b3055..0d6e728 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -319,66 +319,6 @@ double compute_penalty_from_scratch(uint64_t mismatch_mask, return total; } -double advance_penalty_row(double current_penalty, uint64_t current_mismatch, - const TesseractTrellisDetcostTransition& transition) { - if (current_penalty == INF) { - return INF; - } - double total = current_penalty; - for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) { - const uint64_t local_bit = uint64_t{1} << transition.fault_local_indices[k]; - if ((current_mismatch & local_bit) == 0) { - continue; - } - double current_cost = transition.current_costs[k]; - double next_cost = transition.next_costs[k]; - if (next_cost == INF) { - return INF; - } - total += next_cost - current_cost; - } - return total; -} - -double adjust_penalty_for_branch(double parent_penalty_next_row, uint64_t base_state, - uint64_t current_target_bits, uint64_t next_target_bits, - bool present_branch, uint64_t projected_state, - const TesseractTrellisSmallLayerTemplate& layer) { - if (parent_penalty_next_row == INF) { - return compute_penalty_from_scratch(projected_state ^ next_target_bits, layer.next_frontier_costs); - } - - double total = parent_penalty_next_row; - for (size_t k = 0; k < layer.detcost_transition.fault_local_indices.size(); ++k) { - uint8_t local = layer.detcost_transition.fault_local_indices[k]; - int8_t next_local = layer.detcost_transition.next_local_indices[k]; - if (next_local < 0) { - continue; - } - - const uint64_t state_bit = - local < layer.previous_width ? ((base_state >> local) & 1ULL) : 0ULL; - const uint64_t prev_mismatch = - local < layer.previous_width ? (state_bit ^ ((current_target_bits >> local) & 1ULL)) : 0ULL; - const uint64_t child_bit = state_bit ^ (present_branch ? 
1ULL : 0ULL); - const uint64_t child_mismatch = child_bit ^ ((next_target_bits >> next_local) & 1ULL); - if (prev_mismatch == child_mismatch) { - continue; - } - - double next_cost = layer.detcost_transition.next_costs[k]; - if (child_mismatch) { - if (next_cost == INF) { - return INF; - } - total += next_cost; - } else { - total -= next_cost; - } - } - return total; -} - void build_future_detcost_transitions(const std::vector& faults, size_t num_detectors, std::vector* layers, std::vector* initial_future_detcost) { @@ -651,16 +591,24 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection std::vector beam_entries; double initial_penalty = 0.0; if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { + std::vector initial_frontier_costs; + if (!small_layer_templates.empty()) { + const auto& first_layer = small_layer_templates.front(); + initial_frontier_costs.resize(first_layer.current_active_detectors.size(), INF); + for (size_t local = 0; local < first_layer.current_active_detectors.size(); ++local) { + initial_frontier_costs[local] = + initial_future_detcost[(size_t)first_layer.current_active_detectors[local]]; + } + } initial_penalty = compute_penalty_from_scratch( current_target_bits_per_layer.empty() ? 
0 : current_target_bits_per_layer.front(), - initial_future_detcost); + initial_frontier_costs); } beam_entries.push_back({pack_small_key(0, 0), 1.0, initial_penalty}); max_beam_size_seen = 1; for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { const auto& layer = small_layer_templates[layer_index]; - const uint64_t current_target_bits = current_target_bits_per_layer[layer_index]; const uint64_t next_target_bits = next_target_bits_per_layer[layer_index]; const uint64_t expected_retiring_bits = expected_retiring_bits_per_layer[layer_index]; auto t0 = std::chrono::high_resolution_clock::now(); @@ -670,33 +618,17 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection ++num_states_expanded; const uint64_t base_state = unpack_small_state(item.key); const uint64_t base_obs = unpack_small_obs(item.key); - const double parent_penalty_next_row = - config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked - ? advance_penalty_row(item.penalty, base_state ^ current_target_bits, - layer.detcost_transition) - : 0.0; if (((base_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) { uint64_t projected_state = project_small_state(base_state, layer.surviving_local_indices); - double penalty = - config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked - ? 
adjust_penalty_for_branch(parent_penalty_next_row, base_state, current_target_bits, - next_target_bits, false, projected_state, layer) - : 0.0; - next_entries.push_back( - {pack_small_key(projected_state, base_obs), item.mass * layer.q, penalty}); + next_entries.push_back({pack_small_key(projected_state, base_obs), item.mass * layer.q, 0.0}); } uint64_t toggled_state = base_state ^ layer.local_det_mask; if (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) { uint64_t projected_state = project_small_state(toggled_state, layer.surviving_local_indices); - double penalty = - config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked - ? adjust_penalty_for_branch(parent_penalty_next_row, base_state, current_target_bits, - next_target_bits, true, projected_state, layer) - : 0.0; - next_entries.push_back({pack_small_key(projected_state, base_obs ^ layer.obs_flip_bit), - item.mass * layer.p, penalty}); + next_entries.push_back( + {pack_small_key(projected_state, base_obs ^ layer.obs_flip_bit), item.mass * layer.p, 0.0}); } } auto t1 = std::chrono::high_resolution_clock::now(); @@ -723,6 +655,13 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; + if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { + for (auto& item : beam_entries) { + item.penalty = compute_penalty_from_scratch(unpack_small_state(item.key) ^ next_target_bits, + layer.next_frontier_costs); + } + } + if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { keep_top_states(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || From 5929c5829a379e0aa7824c4e28297c285789c4e0 Mon Sep 17 00:00:00 2001 From: noajshu Date: Sat, 18 Apr 2026 18:20:59 +0000 Subject: [PATCH 16/25] updates to python prototypes and optimize the cc trellis prototype --- 
src/py/astar/multipass_beam_decoder.py | 1189 ++++++++++++++++ src/py/astar/plot_log.py | 139 ++ src/py/astar/trellis_beam_detcost_ranked.py | 639 +++++++-- ...trellis_beam_iterative_forward_backward.py | 1249 +++++++++++++++++ .../trellis_beam_opt_singleton_lp_ranked.py | 1218 ++++++++++++++++ src/tesseract_ftl.cc | 52 +- src/tesseract_ftl_main.cc | 15 +- src/tesseract_trellis.cc | 1117 ++++++++++----- src/tesseract_trellis.h | 29 +- src/tesseract_trellis_main.cc | 32 +- 10 files changed, 5165 insertions(+), 514 deletions(-) create mode 100644 src/py/astar/multipass_beam_decoder.py create mode 100644 src/py/astar/plot_log.py create mode 100644 src/py/astar/trellis_beam_iterative_forward_backward.py create mode 100644 src/py/astar/trellis_beam_opt_singleton_lp_ranked.py diff --git a/src/py/astar/multipass_beam_decoder.py b/src/py/astar/multipass_beam_decoder.py new file mode 100644 index 0000000..adb514c --- /dev/null +++ b/src/py/astar/multipass_beam_decoder.py @@ -0,0 +1,1189 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import heapq +import math +import shutil +import sys +import tempfile +import time +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import stim + + +STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets") +STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS) + + +@dataclass(frozen=True) +class Fault: + q: float + p: float + delta_scale: float + det_mask: int + likelihood_cost: float + + +@dataclass(frozen=True) +class DecoderModel: + faults: tuple[Fault, ...] + retiring_masks: tuple[int, ...] + live_masks_after: tuple[int, ...] + future_detcost: tuple[tuple[float, ...], ...] 
+ all_possible_dets_mask: int + max_width: int + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + selected_pass: int = 1 + diagnostic_lines: tuple[str, ...] = () + + +@dataclass(frozen=True) +class DecodingShot: + det_mask: int + actual_logical: bool | None + + +@dataclass(frozen=True) +class ExperimentSummary: + predictions: list[bool | None] + num_certified: int + num_low_confidence: int + num_errors: int + num_truth_shots: int + num_scored_shots: int + total_elapsed: float + total_triggered: int + max_width_seen: int + + +HeuristicTables = tuple[dict[int, float], ...] +CandidateStates = tuple[tuple[int, ...], ...] + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _detectors_from_mask(mask: int) -> list[int]: + detectors: list[int] = [] + while mask: + low_bit = mask & -mask + detectors.append(low_bit.bit_length() - 1) + mask ^= low_bit + return detectors + + +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row in future_detcost) + + +def 
_build_decoder_model(circuit: stim.Circuit) -> DecoderModel: + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + last_seen_index: dict[int, int] = {} + + for inst in dem: + if inst.type != "error": + continue + + p = float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + faults.append( + Fault( + q=1.0 - p, + p=p, + delta_scale=(-p if flip_l0 else p), + det_mask=det_mask, + likelihood_cost=_likelihood_cost(p), + ) + ) + all_possible_dets_mask |= det_mask + + for det_id in _detectors_from_mask(det_mask): + last_seen_index[det_id] = len(faults) - 1 + + retiring_masks = [0] * len(faults) + for det_id, index in last_seen_index.items(): + retiring_masks[index] |= 1 << det_id + + live_masks_after = [0] * (len(faults) + 1) + active_mask = 0 + max_width = 0 + for i, fault in enumerate(faults): + active_mask |= fault.det_mask + max_width = max(max_width, active_mask.bit_count()) + active_mask &= ~retiring_masks[i] + live_masks_after[i + 1] = active_mask + + frozen_faults = tuple(faults) + return DecoderModel( + faults=frozen_faults, + retiring_masks=tuple(retiring_masks), + live_masks_after=tuple(live_masks_after), + future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors), + all_possible_dets_mask=all_possible_dets_mask, + max_width=max_width, + ) + + +def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float: + total = 0.0 + pending = mismatch_mask + + while pending: + low_bit = pending & -pending + detector = low_bit.bit_length() - 1 + pending ^= low_bit + + best = future_detcost[detector] + if best == math.inf: + return math.inf + total += best + + return total + + +def _candidate_state_limit(beam: int) -> int: + # 
One pass feeds a slightly wider neighborhood to the next pass, while still + # capping memory for long circuits. + return max(1, min(128, 2 * beam)) + + +def _top_ranked_entries( + entries: list[tuple[float, float, int, float]], + limit: int, +) -> list[tuple[float, float, int, float]]: + if limit <= 0: + return [] + if len(entries) <= limit: + return sorted(entries, reverse=True) + return heapq.nlargest(limit, entries) + + +def _base_penalty_at_layer( + *, + model: DecoderModel, + live_target_masks: tuple[int, ...], + layer: int, + state: int, +) -> float: + mismatch_mask = state ^ live_target_masks[layer] + return _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=model.future_detcost[layer]) + + +def _lookup_existing_penalty( + *, + model: DecoderModel, + live_target_masks: tuple[int, ...], + layer: int, + state: int, + heuristic_tables: HeuristicTables | None, +) -> float: + penalty = _base_penalty_at_layer( + model=model, + live_target_masks=live_target_masks, + layer=layer, + state=state, + ) + if heuristic_tables is not None: + refined = heuristic_tables[layer].get(state) + if refined is not None and refined > penalty: + penalty = refined + return penalty + + +def _forward_beam_pass( + *, + model: DecoderModel, + actual_dets_mask: int, + live_target_masks: tuple[int, ...], + beam_width: int, + heuristic_tables: HeuristicTables | None, + collect_candidates: bool, + selected_pass: int, +) -> tuple[BeamDecodeResult, CandidateStates, dict[str, float]]: + beam = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + candidate_limit = _candidate_state_limit(beam_width) if collect_candidates else 0 + + candidate_states_list: list[tuple[int, ...]] = [tuple() for _ in range(len(model.faults) + 1)] + candidate_states_list[0] = (0,) + + stats: dict[str, float] = { + "ranked_states_total": 0.0, + "candidate_states_total": 1.0, + "layers_pruned": 0.0, + "peak_ranked_states": 0.0, + "states_using_refined_lb": 0.0, + "states_blocked_by_refined": 0.0, + 
"finite_penalty_gain_hits": 0.0, + "total_penalty_uplift": 0.0, + "max_penalty_uplift": 0.0, + } + + for i, fault in enumerate(model.faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = model.retiring_masks[i] + + if retiring_mask == 0: + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + total_mass += absent_total + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + total_mass += present_total + entry = collapsed_probs.get(toggled) + if entry is None: + collapsed_probs[toggled] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + if (state & retiring_mask) == expected_bits: + shrunk = state & keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + if (toggled & retiring_mask) == expected_bits: + shrunk = toggled & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + if total_mass == 0.0: + return ( + BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=0.0, + 
selected_pass=selected_pass, + ), + tuple(candidate_states_list), + stats, + ) + + ranked_states: list[tuple[float, float, int, float]] = [] + next_live_target_mask = live_target_masks[i + 1] + next_future_detcost = model.future_detcost[i + 1] + next_heuristics = None if heuristic_tables is None else heuristic_tables[i + 1] + + ranked_count = len(collapsed_probs) + stats["ranked_states_total"] += ranked_count + stats["peak_ranked_states"] = max(stats["peak_ranked_states"], float(ranked_count)) + + for state, (total, delta) in collapsed_probs.items(): + mismatch_mask = state ^ next_live_target_mask + base_penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost) + penalty = base_penalty + + if next_heuristics is not None: + refined_penalty = next_heuristics.get(state) + if refined_penalty is not None and refined_penalty > penalty: + penalty = refined_penalty + stats["states_using_refined_lb"] += 1.0 + if refined_penalty == math.inf and base_penalty != math.inf: + stats["states_blocked_by_refined"] += 1.0 + elif refined_penalty != math.inf and base_penalty != math.inf: + uplift = refined_penalty - base_penalty + stats["finite_penalty_gain_hits"] += 1.0 + stats["total_penalty_uplift"] += uplift + stats["max_penalty_uplift"] = max(stats["max_penalty_uplift"], uplift) + + if penalty == math.inf: + rank_score = -math.inf + else: + rank_score = math.log(total) - penalty + ranked_states.append((rank_score, total, state, delta)) + + top_needed = max(beam_width, candidate_limit) + top_entries = _top_ranked_entries(ranked_states, top_needed) + + if collect_candidates: + candidate_slice = top_entries[:candidate_limit] + candidate_states_list[i + 1] = tuple(state for _, _, state, _ in candidate_slice) + stats["candidate_states_total"] += len(candidate_slice) + + dropped_mass = 0.0 + if len(ranked_states) > beam_width: + stats["layers_pruned"] += 1.0 + kept = top_entries[:beam_width] + kept_mass = sum(total for _, total, _, _ in kept) + 
dropped_mass = total_mass - kept_mass + else: + kept = top_entries + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for _, total, state, delta in kept + ] + + candidate_states_list[-1] = (0,) + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + margin = abs(final_delta) + certified = margin > discarded_mass + + if final_delta == 0.0: + return ( + BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=margin, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=0.0, + selected_pass=selected_pass, + ), + tuple(candidate_states_list), + stats, + ) + + return ( + BeamDecodeResult( + predicted_logical=final_delta < 0.0, + certified=certified, + margin=margin, + discarded_mass=discarded_mass, + max_width=model.max_width, + elapsed_seconds=0.0, + selected_pass=selected_pass, + ), + tuple(candidate_states_list), + stats, + ) + + +def _build_refined_lower_bounds( + *, + model: DecoderModel, + actual_dets_mask: int, + live_target_masks: tuple[int, ...], + candidate_states: CandidateStates, + existing_tables: HeuristicTables | None, +) -> tuple[HeuristicTables, dict[str, float]]: + tables_list: list[dict[int, float]] = [dict() for _ in range(len(model.faults) + 1)] + tables_list[-1][0] = 0.0 + + stats: dict[str, float] = { + "candidate_states_total": float(sum(len(layer) for layer in candidate_states)), + "states_evaluated": 1.0, + "layers_with_candidates": 1.0, + "exact_successor_hits": 0.0, + "prior_successor_hits": 0.0, + "base_successor_hits": 0.0, + "states_improved": 0.0, + "states_ruled_out": 0.0, + "finite_gain_hits": 0.0, + "total_lb_gain": 0.0, + "max_lb_gain": 0.0, + } + + for i in range(len(model.faults) - 1, -1, -1): + states_here = candidate_states[i] + if not states_here: + continue + + stats["layers_with_candidates"] += 1.0 + fault = model.faults[i] + 
retiring_mask = model.retiring_masks[i]
+        expected_bits = actual_dets_mask & retiring_mask
+        keep_mask = ~retiring_mask
+        next_refined = tables_list[i + 1]
+        current_refined = tables_list[i]
+
+        # Bellman backup over layer i: for each candidate state take the cheaper of
+        # skipping fault i (cost 0) or applying it (cost fault.likelihood_cost).
+        # Either branch is only considered when the detector bits retiring at this
+        # layer match the observed syndrome bits (expected_bits).
+        for state in states_here:
+            best = math.inf
+
+            if (state & retiring_mask) == expected_bits:
+                next_state = state & keep_mask
+                successor_penalty = _lookup_existing_penalty(
+                    model=model,
+                    live_target_masks=live_target_masks,
+                    layer=i + 1,
+                    state=next_state,
+                    heuristic_tables=existing_tables,
+                )
+                # Prefer an exact bound already computed for the successor in this
+                # backward sweep when it is tighter than the prior/base estimate.
+                exact_successor = next_refined.get(next_state)
+                if exact_successor is not None and exact_successor > successor_penalty:
+                    successor_penalty = exact_successor
+                    stats["exact_successor_hits"] += 1.0
+                elif existing_tables is not None and next_state in existing_tables[i + 1]:
+                    stats["prior_successor_hits"] += 1.0
+                else:
+                    stats["base_successor_hits"] += 1.0
+                best = min(best, successor_penalty)
+
+            # Branch where fault i fires: its detectors toggle before the retiring check.
+            toggled = state ^ fault.det_mask
+            if (toggled & retiring_mask) == expected_bits:
+                next_state = toggled & keep_mask
+                successor_penalty = _lookup_existing_penalty(
+                    model=model,
+                    live_target_masks=live_target_masks,
+                    layer=i + 1,
+                    state=next_state,
+                    heuristic_tables=existing_tables,
+                )
+                exact_successor = next_refined.get(next_state)
+                if exact_successor is not None and exact_successor > successor_penalty:
+                    successor_penalty = exact_successor
+                    stats["exact_successor_hits"] += 1.0
+                elif existing_tables is not None and next_state in existing_tables[i + 1]:
+                    stats["prior_successor_hits"] += 1.0
+                else:
+                    stats["base_successor_hits"] += 1.0
+                best = min(best, fault.likelihood_cost + successor_penalty)
+
+            old_penalty = _lookup_existing_penalty(
+                model=model,
+                live_target_masks=live_target_masks,
+                layer=i,
+                state=state,
+                heuristic_tables=existing_tables,
+            )
+            # Take the max of the old and newly backed-up bound so refinement
+            # never loosens an existing admissible lower bound.
+            new_penalty = best if best > old_penalty else old_penalty
+            current_refined[state] = new_penalty
+            stats["states_evaluated"] += 1.0
+
+            if new_penalty > old_penalty:
+                stats["states_improved"] += 1.0
+                if 
new_penalty == math.inf: + stats["states_ruled_out"] += 1.0 + elif old_penalty != math.inf: + gain = new_penalty - old_penalty + stats["finite_gain_hits"] += 1.0 + stats["total_lb_gain"] += gain + stats["max_lb_gain"] = max(stats["max_lb_gain"], gain) + + return tuple(tables_list), stats + + +def _result_confidence_key(result: BeamDecodeResult) -> tuple[float, ...]: + return ( + float(int(result.certified)), + float(int(result.predicted_logical is not None)), + result.margin - result.discarded_mass, + result.margin, + -result.discarded_mass, + float(result.selected_pass), + ) + + +def _format_forward_summary( + *, + shot_index: int | None, + pass_index: int, + num_passes: int, + beam_width: int, + candidate_limit: int, + stats: dict[str, float], + result: BeamDecodeResult, +) -> str: + refined_hits = int(stats["states_using_refined_lb"]) + finite_gain_hits = int(stats["finite_penalty_gain_hits"]) + avg_gain = stats["total_penalty_uplift"] / max(1, finite_gain_hits) + return ( + f"multipass shot={shot_index} pass={pass_index}/{num_passes} phase=forward " + f"beam={beam_width} candidate_limit={candidate_limit} ranked_states={int(stats['ranked_states_total'])} " + f"peak_layer_states={int(stats['peak_ranked_states'])} layers_pruned={int(stats['layers_pruned'])} " + f"refined_hits={refined_hits} refined_blocks={int(stats['states_blocked_by_refined'])} " + f"avg_penalty_gain={avg_gain:.6f} max_penalty_gain={stats['max_penalty_uplift']:.6f} " + f"prediction={result.predicted_logical} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e}" + ) + + +def _format_backward_summary( + *, + shot_index: int | None, + pass_index: int, + num_passes: int, + stats: dict[str, float], +) -> str: + finite_gain_hits = int(stats["finite_gain_hits"]) + avg_gain = stats["total_lb_gain"] / max(1, finite_gain_hits) + return ( + f"multipass shot={shot_index} pass={pass_index}/{num_passes} phase=backward " + 
f"candidate_states={int(stats['candidate_states_total'])} states_evaluated={int(stats['states_evaluated'])} " + f"layers_with_candidates={int(stats['layers_with_candidates'])} improved_states={int(stats['states_improved'])} " + f"ruled_out={int(stats['states_ruled_out'])} avg_lb_gain={avg_gain:.6f} " + f"max_lb_gain={stats['max_lb_gain']:.6f} successor_hits=(exact:{int(stats['exact_successor_hits'])}," + f"prior:{int(stats['prior_successor_hits'])},base:{int(stats['base_successor_hits'])})" + ) + + +def _format_selection_summary( + *, + shot_index: int | None, + chosen: BeamDecodeResult, + num_passes: int, +) -> str: + confidence_gap = chosen.margin - chosen.discarded_mass + return ( + f"multipass shot={shot_index} selection chosen_pass={chosen.selected_pass}/{num_passes} " + f"prediction={chosen.predicted_logical} certified={chosen.certified} " + f"confidence_gap={confidence_gap:.6e} margin={chosen.margin:.6e} " + f"discarded_mass={chosen.discarded_mass:.6e}" + ) + + +def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray: + arr = np.asarray(data) + if arr.ndim != 2: + raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." 
+ ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) + + +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + flat = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + expected_cols = num_detectors + num_observables + flat = _as_bool_2d( + flat, + expected_cols=expected_cols, + description="combined detector/observable input data", + ) + if num_observables: + return flat[:, :num_detectors], flat[:, num_detectors:] + return flat, None + + +def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, + num_observables=num_observables, + ) + return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data") + + +def _apply_shot_range( + dets: np.ndarray, + obs: np.ndarray | None, + *, + shot_range_begin: int, + shot_range_end: int, +) -> tuple[np.ndarray, np.ndarray | None]: + if not (shot_range_begin or shot_range_end): + return dets, obs + + if shot_range_end < shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if shot_range_end > len(dets): + raise ValueError( + f"Shot range end {shot_range_end} is past the end of the shot 
data (size {len(dets)})."
+        )
+
+    dets = dets[shot_range_begin:shot_range_end]
+    if obs is not None:
+        obs = obs[shot_range_begin:shot_range_end]
+    return dets, obs
+
+
+def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]:
+    """Convert boolean detector/observable row arrays into one DecodingShot per row.
+
+    Only observable column 0 is read; actual_logical stays None when obs is None.
+    """
+    shots: list[DecodingShot] = []
+    for shot_index in range(dets.shape[0]):
+        actual_logical = None if obs is None else bool(obs[shot_index, 0])
+        shots.append(
+            DecodingShot(
+                det_mask=_mask_from_bool_row(dets[shot_index]),
+                actual_logical=actual_logical,
+            )
+        )
+    return shots
+
+
+def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str:
+    """Return *path* unchanged, or spool stdin to a temp file and return that path when path == "-"."""
+    if path != "-":
+        return path
+    temp_path = str(Path(temp_dir) / f"{stem}.bin")
+    with open(temp_path, "wb") as f:
+        f.write(sys.stdin.buffer.read())
+    return temp_path
+
+
+def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]:
+    """Return (path, False), or (temp path, True) when output must later be copied to stdout."""
+    if path != "-":
+        return path, False
+    return str(Path(temp_dir) / f"{stem}.bin"), True
+
+
+def _copy_file_to_stdout(path: str) -> None:
+    """Stream the file at *path* to binary stdout; flush text stdout first to keep ordering."""
+    sys.stdout.flush()
+    with open(path, "rb") as f:
+        shutil.copyfileobj(f, sys.stdout.buffer)
+    sys.stdout.buffer.flush()
+
+
+def _load_shots(
+    circuit: stim.Circuit,
+    args: argparse.Namespace,
+    *,
+    temp_dir: str,
+) -> list[DecodingShot]:
+    """Load shots from --in (with optional --obs-in or appended observables) or by sampling."""
+    if args.in_file:
+        in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in")
+        appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0
+        dets, obs = _read_detector_shot_arrays(
+            path=in_path,
+            fmt=args.in_format,
+            num_detectors=circuit.num_detectors,
+            num_observables=appended_obs_count,
+        )
+
+        if args.obs_in_file:
+            obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in")
+            obs = _read_observable_shot_array(
+                path=obs_in_path,
+                fmt=args.obs_in_format,
+                num_observables=circuit.num_observables,
+            )
+            if len(obs) != len(dets):
+                raise ValueError("Observable input ended before, or after, the 
detector shot data.") + else: + dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed) + + dets, obs = _apply_shot_range( + dets, + obs, + shot_range_begin=args.shot_range_begin, + shot_range_end=args.shot_range_end, + ) + return _shots_from_arrays(dets, obs) + + +def decode_beam_search_detcost_ranked( + model: DecoderModel, + actual_dets_mask: int, + L: int, + *, + num_passes: int = 1, + shot_index: int | None = None, +) -> BeamDecodeResult: + start_time = time.perf_counter() + + if (actual_dets_mask & ~model.all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + selected_pass=1, + ) + + live_target_masks = tuple(actual_dets_mask & mask for mask in model.live_masks_after) + candidate_limit = _candidate_state_limit(L) + current_tables: HeuristicTables | None = None + per_pass_results: list[BeamDecodeResult] = [] + diagnostic_lines: list[str] = [] + + for pass_index in range(1, num_passes + 1): + collect_candidates = pass_index < num_passes + pass_result, candidate_states, forward_stats = _forward_beam_pass( + model=model, + actual_dets_mask=actual_dets_mask, + live_target_masks=live_target_masks, + beam_width=L, + heuristic_tables=current_tables, + collect_candidates=collect_candidates, + selected_pass=pass_index, + ) + per_pass_results.append(pass_result) + + if num_passes > 1: + diagnostic_lines.append( + _format_forward_summary( + shot_index=shot_index, + pass_index=pass_index, + num_passes=num_passes, + beam_width=L, + candidate_limit=candidate_limit, + stats=forward_stats, + result=pass_result, + ) + ) + + if collect_candidates: + current_tables, backward_stats = _build_refined_lower_bounds( + model=model, + actual_dets_mask=actual_dets_mask, + live_target_masks=live_target_masks, + candidate_states=candidate_states, + existing_tables=current_tables, + ) + if 
num_passes > 1: + diagnostic_lines.append( + _format_backward_summary( + shot_index=shot_index, + pass_index=pass_index, + num_passes=num_passes, + stats=backward_stats, + ) + ) + + chosen = max(per_pass_results, key=_result_confidence_key) + if num_passes > 1: + diagnostic_lines.append( + _format_selection_summary( + shot_index=shot_index, + chosen=chosen, + num_passes=num_passes, + ) + ) + + return BeamDecodeResult( + predicted_logical=chosen.predicted_logical, + certified=chosen.certified, + margin=chosen.margin, + discarded_mass=chosen.discarded_mass, + max_width=chosen.max_width, + elapsed_seconds=time.perf_counter() - start_time, + selected_pass=chosen.selected_pass, + diagnostic_lines=tuple(diagnostic_lines), + ) + + +def _print_run_header( + *, + circuit: stim.Circuit, + args: argparse.Namespace, + num_shots: int, + log_stream, +) -> None: + print(f"Running on circuit {args.circuit}", file=log_stream) + print(f"Total Detectors: {circuit.num_detectors}", file=log_stream) + print(f"Total Observables: {circuit.num_observables}", file=log_stream) + if args.in_file: + print(f"Shot Input: {args.in_file}", file=log_stream) + print(f"Shot Input Format: {args.in_format}", file=log_stream) + if args.in_includes_appended_observables: + print("Observable Input: appended to --in", file=log_stream) + elif args.obs_in_file: + print(f"Observable Input: {args.obs_in_file}", file=log_stream) + print(f"Observable Format: {args.obs_in_format}", file=log_stream) + else: + print("Observable Input: none", file=log_stream) + else: + print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: + print( + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Beam: {args.beam}", file=log_stream) + print(f"Num Passes: {args.num_passes}", file=log_stream) + if args.num_passes > 1: + print(f"Pass Candidate Limit: 
{_candidate_state_limit(args.beam)}", file=log_stream) + print( + "Pass Logic: forward beam -> candidate residual states -> backward Bellman lower bounds -> choose best-confidence pass", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." + ) + + model = _build_decoder_model(circuit) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header(circuit=circuit, args=args, num_shots=len(shots), log_stream=log_stream) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + selected_pass_counts = [0] * (args.num_passes + 1) + + detailed_multipass_for_all = args.print_per_shot or len(shots) <= 10 + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_detcost_ranked( + model, + shot.det_mask, + args.beam, + num_passes=args.num_passes, + shot_index=shot_index, + ) + predictions.append(result.predicted_logical) + selected_pass_counts[result.selected_pass] += 1 + + if result.diagnostic_lines: + if detailed_multipass_for_all or shot_index == 0: + for line in result.diagnostic_lines: + print(line, file=log_stream) + else: + print(result.diagnostic_lines[-1], file=log_stream) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if 
shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, result.max_width) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 + progress_line = ( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}" + ) + if args.num_passes > 1: + progress_line += f" selected_pass={result.selected_pass}" + print(progress_line, file=log_stream) + + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} selected_pass={result.selected_pass} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False + + if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + path=output_path, + format=args.out_format, + 
num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: {max_width_seen}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if args.num_passes > 1: + selected_summary = " ".join( + f"P{pass_index}={count}" + for pass_index, count in enumerate(selected_pass_counts[1:], start=1) + ) + print(f"Selected Passes: {selected_summary}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run trellis beam decoding ranked by mass minus a detcost-style future penalty, " + "optionally refined by multi-pass candidate-state Bellman backups, " + "with Stim-compatible shot-data I/O options." 
+ ), + allow_abbrev=False, + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument( + "--num-passes", + type=int, + default=1, + help=( + "Number of forward/backward refinement passes. 1 reproduces the original single-pass beam search. " + "Larger values reuse beam states from one pass to sharpen the remaining-cost estimates of the next." + ), + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." 
+    ),
+    # dest is "in_file" because "in" is a Python keyword (args.in would be invalid).
+    parser.add_argument(
+        "--in",
+        dest="in_file",
+        default="",
+        help="File to read detection events from (use - for stdin).",
+    )
+    # Several flags below register both dashed and underscored spellings as aliases.
+    parser.add_argument(
+        "--in-format",
+        "--in_format",
+        dest="in_format",
+        choices=STIM_RESULT_FORMATS,
+        default="01",
+        help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).",
+    )
+    parser.add_argument(
+        "--in-includes-appended-observables",
+        "--in_includes_appended_observables",
+        dest="in_includes_appended_observables",
+        action="store_true",
+        help="Assume the observable flips are appended to each shot in --in.",
+    )
+    parser.add_argument(
+        "--obs-in",
+        "--obs_in",
+        dest="obs_in_file",
+        default="",
+        help="File to read observable flips from (use - for stdin).",
+    )
+    parser.add_argument(
+        "--obs-in-format",
+        "--obs_in_format",
+        dest="obs_in_format",
+        choices=STIM_RESULT_FORMATS,
+        default="01",
+        help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).",
+    )
+    parser.add_argument(
+        "--out",
+        dest="out_file",
+        default="",
+        help="File to write predicted observable flips to (use - for stdout).",
+    )
+    parser.add_argument(
+        "--out-format",
+        "--out_format",
+        dest="out_format",
+        choices=STIM_RESULT_FORMATS,
+        default="01",
+        help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).",
+    )
+    parser.add_argument(
+        "--print-per-shot",
+        action="store_true",
+        help="Print a detailed line per decoded shot.",
+    )
+    args = parser.parse_args()
+
+    if args.sample_num_shots is None:
+        # Preserve the original script's one-shot default while still allowing
+        # file input without requiring --sample-num-shots 0. 
+ args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.num_passes <= 0: + raise ValueError("--num-passes must be positive.") + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either --in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/py/astar/plot_log.py b/src/py/astar/plot_log.py new file mode 100644 index 0000000..2ffd9ad --- /dev/null +++ b/src/py/astar/plot_log.py @@ -0,0 +1,139 @@ +import sys +import os +import matplotlib.pyplot as plt +import numpy as np + +def analyze_log(filename): + min_masses = [] + errors = [] + + current_errors = 0 + pending_error_diff = None + + # Parse the log file line by line + with open(filename, 'r') as f: + for line in f: + parts = line.split() + if not parts: + continue + + if parts[0] == "num_shots": + # Find 'num_errors' and grab the value two 
indices over (past the '=') + idx = parts.index("num_errors") + errs = int(parts[idx + 2]) + + # Calculate errors for this specific shot and store it as pending + pending_error_diff = errs - current_errors + current_errors = errs + + elif parts[0] == "branch_masses": + # Parse out the values from obs0=... and obs1=... + obs0 = float(parts[1].split("=")[1]) + obs1 = float(parts[2].split("=")[1]) + norm = obs0 + obs1 + if norm == 0: + obs0 = 0.5 + obs1 = 0.5 + else: + obs0 /= norm + obs1 /= norm + + # Only append if we just successfully parsed a num_shots line + if pending_error_diff is not None: + min_masses.append(min(obs0, obs1)) + errors.append(pending_error_diff) + + # Reset pending diff to ensure we don't double-count + pending_error_diff = None + + min_masses = np.array(min_masses) + errors = np.array(errors) + + if len(min_masses) == 0: + print("No valid shot data found in the file.") + return + + # To calculate how error rates change based on our cutoff, + # we sort the shots from most certain (lowest min_mass) to least certain. + sorted_idx = np.argsort(min_masses) + sorted_masses = min_masses[sorted_idx] + sorted_errors = errors[sorted_idx] + + N = len(sorted_masses) + + # K represents the number of shots we *accept* (1 to N) + K_arr = np.arange(1, N + 1) + + # Cumulative errors in the accepted subset of shots + accepted_errors = np.cumsum(sorted_errors) + + # Error rate = (errors in accepted subset) / (number of accepted shots) + error_rates = accepted_errors / K_arr + + # Rejection rate = (number of rejected shots) / (total shots) + rejection_rates = (N - K_arr) / N + + # ------------------ + # Pre-process for Log Scale Histogram + # ------------------ + # Find the smallest non-zero mass. 
If everything is 0, default to 1e-10 + if np.any(min_masses > 0): + min_nonzero = np.min(min_masses[min_masses > 0]) + # Set exact 0s to half the minimum non-zero value so they fall in the leftmost bin + epsilon = min_nonzero / 2.0 + else: + epsilon = 1e-10 + + # Replace 0s with epsilon + masses_for_hist = np.where(min_masses == 0, epsilon, min_masses) + + # Safely get max mass to define bin edges + max_mass = np.max(masses_for_hist) + if max_mass == epsilon: + max_mass = epsilon * 10 # Fallback in case all values were 0 + + # Generate 50 logarithmically spaced bins + log_bins = np.logspace(np.log10(epsilon), np.log10(max_mass), 50) + + # ------------------ + # Create the Figures + # ------------------ + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + # Plot 1: Distribution of min masses (Log Scale X) + axes[0].hist(masses_for_hist, bins=log_bins, color='skyblue', edgecolor='black') + axes[0].set_xscale('log') + axes[0].set_xlabel('Min Mass (Log Scale, 0s in leftmost bin)') + axes[0].set_ylabel('Frequency') + axes[0].set_title('Distribution of Min Masses') + + # Plot 2: Logical error rate vs Min Mass Cutoff + axes[1].plot(sorted_masses, error_rates, color='purple', lw=2) + axes[1].set_xlabel('Min Mass Cutoff (Threshold)') + axes[1].set_ylabel('Logical Error Rate (Accepted Shots)') + axes[1].set_title('Error Rate vs Min Mass Cutoff') + axes[1].grid(True, linestyle='--', alpha=0.7) + + # Plot 3: Logical error rate vs Rejection rate + axes[2].plot(rejection_rates, error_rates, color='red', lw=2) + axes[2].set_xlabel('Rejection Rate') + axes[2].set_ylabel('Logical Error Rate (Accepted Shots)') + axes[2].set_title('Error Rate vs Rejection Rate') + axes[2].grid(True, linestyle='--', alpha=0.7) + axes[2].set_xlim(0, 1) + + plt.tight_layout() + + # Generate output filename based on input filename + base_name = os.path.splitext(os.path.basename(filename))[0] + out_filename = f"{base_name}_analysis.png" + + # Save to disk instead of displaying + plt.savefig(out_filename, 
dpi=300, bbox_inches='tight') + print(f"Success! Plot saved to disk as: {out_filename}") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(f"Usage: python {sys.argv[0]} ") + else: + analyze_log(sys.argv[1]) diff --git a/src/py/astar/trellis_beam_detcost_ranked.py b/src/py/astar/trellis_beam_detcost_ranked.py index 7315a4d..be934ee 100644 --- a/src/py/astar/trellis_beam_detcost_ranked.py +++ b/src/py/astar/trellis_beam_detcost_ranked.py @@ -1,13 +1,23 @@ #!/usr/bin/env python3 +from __future__ import annotations import argparse import math +import shutil +import sys +import tempfile import time from dataclasses import dataclass +from pathlib import Path +import numpy as np import stim +STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets") +STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS) + + @dataclass(frozen=True) class Fault: q: float @@ -17,6 +27,16 @@ class Fault: likelihood_cost: float +@dataclass(frozen=True) +class DecoderModel: + faults: tuple[Fault, ...] + retiring_masks: tuple[int, ...] + live_masks_after: tuple[int, ...] + future_detcost: tuple[tuple[float, ...], ...] 
+ all_possible_dets_mask: int + max_width: int + + @dataclass(frozen=True) class BeamDecodeResult: predicted_logical: bool | None @@ -28,9 +48,22 @@ class BeamDecodeResult: @dataclass(frozen=True) -class SampledShot: +class DecodingShot: det_mask: int - actual_logical: bool + actual_logical: bool | None + + +@dataclass(frozen=True) +class ExperimentSummary: + predictions: list[bool | None] + num_certified: int + num_low_confidence: int + num_errors: int + num_truth_shots: int + num_scored_shots: int + total_elapsed: float + total_triggered: int + max_width_seen: int def _likelihood_cost(probability: float) -> float: @@ -50,7 +83,31 @@ def _detectors_from_mask(mask: int) -> list[int]: return detectors -def _parse_circuit(circuit: stim.Circuit) -> tuple[list[Fault], list[int], list[int], int]: +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row in future_detcost) + + +def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel: dem = circuit.detector_error_model(decompose_errors=False).flattened() faults: list[Fault] = [] @@ -99,27 +156,18 @@ def _parse_circuit(circuit: stim.Circuit) -> tuple[list[Fault], list[int], list[ active_mask &= ~retiring_masks[i] live_masks_after[i + 1] = active_mask - return faults, retiring_masks, live_masks_after, 
max_width - - -def _future_detcost_by_detector(faults: list[Fault], num_detectors: int) -> list[list[float]]: - future_detcost = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] - next_row = future_detcost[-1] - for fault_index in range(len(faults) - 1, -1, -1): - row = next_row.copy() - fault = faults[fault_index] - det_count = fault.det_mask.bit_count() - if det_count: - ecost = fault.likelihood_cost / det_count - for det_id in _detectors_from_mask(fault.det_mask): - if ecost < row[det_id]: - row[det_id] = ecost - future_detcost[fault_index] = row - next_row = row - return future_detcost + frozen_faults = tuple(faults) + return DecoderModel( + faults=frozen_faults, + retiring_masks=tuple(retiring_masks), + live_masks_after=tuple(live_masks_after), + future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors), + all_possible_dets_mask=all_possible_dets_mask, + max_width=max_width, + ) -def _detcost_penalty(mismatch_mask: int, future_detcost: list[float]) -> float: +def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float: total = 0.0 pending = mismatch_mask @@ -136,54 +184,200 @@ def _detcost_penalty(mismatch_mask: int, future_detcost: list[float]) -> float: return total -def sample_shots(circuit: stim.Circuit, shots: int, seed: int | None) -> list[SampledShot]: +def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray: + arr = np.asarray(data) + if arr.ndim != 2: + raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." 
+ ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: sampler = circuit.compile_detector_sampler(seed=seed) - syndromes, logicals = sampler.sample(shots=shots, separate_observables=True) - out: list[SampledShot] = [] - for shot_index in range(shots): - det_mask = 0 - for detector, fired in enumerate(syndromes[shot_index]): - if fired: - det_mask ^= 1 << detector - out.append(SampledShot(det_mask=det_mask, actual_logical=bool(logicals[shot_index][0]))) - return out + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) -def decode_beam_search_detcost_ranked( +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + common_kwargs = dict( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + if num_observables: + try: + dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"), + _as_bool_2d(obs, expected_cols=num_observables, description="appended observable data"), + ) + except TypeError: + flat = stim.read_shot_data_file(**common_kwargs) + flat = _as_bool_2d( + flat, + expected_cols=num_detectors + num_observables, + description="combined detector/observable input data", + ) + return flat[:, :num_detectors], flat[:, num_detectors:] + + flat = stim.read_shot_data_file(**common_kwargs) + return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None + + +def 
_read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, + num_observables=num_observables, + ) + return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data") + + +def _apply_shot_range( + dets: np.ndarray, + obs: np.ndarray | None, + *, + shot_range_begin: int, + shot_range_end: int, +) -> tuple[np.ndarray, np.ndarray | None]: + if not (shot_range_begin or shot_range_end): + return dets, obs + + if shot_range_end < shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if shot_range_end > len(dets): + raise ValueError( + f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})." + ) + + dets = dets[shot_range_begin:shot_range_end] + if obs is not None: + obs = obs[shot_range_begin:shot_range_end] + return dets, obs + + +def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]: + shots: list[DecodingShot] = [] + for shot_index in range(dets.shape[0]): + actual_logical = None if obs is None else bool(obs[shot_index, 0]) + shots.append( + DecodingShot( + det_mask=_mask_from_bool_row(dets[shot_index]), + actual_logical=actual_logical, + ) + ) + return shots + + +def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str: + if path != "-": + return path + temp_path = str(Path(temp_dir) / f"{stem}.bin") + with open(temp_path, "wb") as f: + f.write(sys.stdin.buffer.read()) + return temp_path + + +def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]: + if path != "-": + return path, False + return str(Path(temp_dir) / f"{stem}.bin"), True + + +def _copy_file_to_stdout(path: str) -> None: + sys.stdout.flush() + with open(path, "rb") as f: + shutil.copyfileobj(f, sys.stdout.buffer) + sys.stdout.buffer.flush() + + 
+def _load_shots( circuit: stim.Circuit, - actual_dets: set[int], + args: argparse.Namespace, + *, + temp_dir: str, +) -> list[DecodingShot]: + if args.in_file: + in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in") + appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0 + dets, obs = _read_detector_shot_arrays( + path=in_path, + fmt=args.in_format, + num_detectors=circuit.num_detectors, + num_observables=appended_obs_count, + ) + + if args.obs_in_file: + obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in") + obs = _read_observable_shot_array( + path=obs_in_path, + fmt=args.obs_in_format, + num_observables=circuit.num_observables, + ) + if len(obs) != len(dets): + raise ValueError("Observable input ended before, or after, the detector shot data.") + else: + dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed) + + dets, obs = _apply_shot_range( + dets, + obs, + shot_range_begin=args.shot_range_begin, + shot_range_end=args.shot_range_end, + ) + return _shots_from_arrays(dets, obs) + + +def decode_beam_search_detcost_ranked( + model: DecoderModel, + actual_dets_mask: int, L: int, ) -> BeamDecodeResult: start_time = time.perf_counter() - faults, retiring_masks, live_masks_after, max_width = _parse_circuit(circuit) - - actual_dets_mask = 0 - for detector in actual_dets: - actual_dets_mask ^= 1 << detector - - all_possible_dets_mask = 0 - for fault in faults: - all_possible_dets_mask |= fault.det_mask - if (actual_dets_mask & ~all_possible_dets_mask) != 0: + if (actual_dets_mask & ~model.all_possible_dets_mask) != 0: return BeamDecodeResult( predicted_logical=None, certified=False, margin=0.0, discarded_mass=0.0, - max_width=0, + max_width=model.max_width, elapsed_seconds=time.perf_counter() - start_time, ) - future_detcost = _future_detcost_by_detector(faults, circuit.num_detectors) - beam = [(0, 1.0, 1.0)] 
discarded_mass = 0.0 - for i, fault in enumerate(faults): + for i, fault in enumerate(model.faults): collapsed_probs: dict[int, list[float]] = {} total_mass = 0.0 - retiring_mask = retiring_masks[i] + retiring_mask = model.retiring_masks[i] if retiring_mask == 0: for state, total, delta in beam: @@ -242,13 +436,13 @@ def decode_beam_search_detcost_ranked( certified=False, margin=0.0, discarded_mass=discarded_mass, - max_width=max_width, + max_width=model.max_width, elapsed_seconds=time.perf_counter() - start_time, ) ranked_states: list[tuple[float, float, int, float]] = [] - live_target_mask = actual_dets_mask & live_masks_after[i + 1] - next_future_detcost = future_detcost[i + 1] + live_target_mask = actual_dets_mask & model.live_masks_after[i + 1] + next_future_detcost = model.future_detcost[i + 1] for state, (total, delta) in collapsed_probs.items(): mismatch_mask = state ^ live_target_mask penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost) @@ -285,7 +479,7 @@ def decode_beam_search_detcost_ranked( certified=False, margin=margin, discarded_mass=discarded_mass, - max_width=max_width, + max_width=model.max_width, elapsed_seconds=time.perf_counter() - start_time, ) return BeamDecodeResult( @@ -293,98 +487,291 @@ def decode_beam_search_detcost_ranked( certified=certified, margin=margin, discarded_mass=discarded_mass, - max_width=max_width, + max_width=model.max_width, elapsed_seconds=time.perf_counter() - start_time, ) -def run_experiment( - circuit_fname: str, - L: int, - sample_num_shots: int, - sample_seed: int | None = None, - print_per_shot: bool = False, +def _print_run_header( + *, + circuit: stim.Circuit, + args: argparse.Namespace, + num_shots: int, + log_stream, ) -> None: - circuit = stim.Circuit.from_file(circuit_fname) - shots = sample_shots(circuit, shots=sample_num_shots, seed=sample_seed) - - print(f"Running on circuit {circuit_fname}") - print(f"Total Detectors: {circuit.num_detectors}") - print(f"Sample 
Seed: {sample_seed}") - print(f"Num Shots: {len(shots)}") - - num_errors = 0 - num_low_confidence = 0 - num_certified = 0 - total_elapsed = 0.0 - total_triggered = 0 - max_width_seen = 0 - - for shot_index, shot in enumerate(shots): - actual_dets = set(_detectors_from_mask(shot.det_mask)) - result = decode_beam_search_detcost_ranked(circuit, actual_dets, L) - success = result.predicted_logical == shot.actual_logical if result.predicted_logical is not None else False - low_confidence = result.predicted_logical is None - - if low_confidence: - num_low_confidence += 1 - elif not success: - num_errors += 1 - if result.certified: - num_certified += 1 - - total_elapsed += result.elapsed_seconds - total_triggered += len(actual_dets) - max_width_seen = max(max_width_seen, result.max_width) - - shots_done = shot_index + 1 - resolved_shots = shots_done - num_low_confidence - error_rate_so_far = num_errors / resolved_shots if resolved_shots else 0.0 + print(f"Running on circuit {args.circuit}", file=log_stream) + print(f"Total Detectors: {circuit.num_detectors}", file=log_stream) + print(f"Total Observables: {circuit.num_observables}", file=log_stream) + if args.in_file: + print(f"Shot Input: {args.in_file}", file=log_stream) + print(f"Shot Input Format: {args.in_format}", file=log_stream) + if args.in_includes_appended_observables: + print("Observable Input: appended to --in", file=log_stream) + elif args.obs_in_file: + print(f"Observable Input: {args.obs_in_file}", file=log_stream) + print(f"Observable Format: {args.obs_in_format}", file=log_stream) + else: + print("Observable Input: none", file=log_stream) + else: + print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: print( - f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " - f"low_conf_so_far={num_low_confidence} error_rate_so_far={error_rate_so_far:.6f} " - 
f"elapsed_total_seconds={total_elapsed:.6f}" + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." ) - if print_per_shot: + model = _build_decoder_model(circuit) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header(circuit=circuit, args=args, num_shots=len(shots), log_stream=log_stream) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_detcost_ranked(model, shot.det_mask, args.beam) + predictions.append(result.predicted_logical) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, result.max_width) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 print( - f"shot={shot_index} 
triggered_detectors={len(actual_dets)} " - f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " - f"success={success} certified={result.certified} " - f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " - f"elapsed_seconds={result.elapsed_seconds:.6f}" + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}", + file=log_stream, ) - print(f"Beam: {L}") - print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}") - print(f"Max Width: {max_width_seen}") - print(f"Certified Shots: {num_certified}") - print(f"Low Confidence: {num_low_confidence}") - print(f"Logical Errors: {num_errors}") - print(f"Total Seconds: {total_elapsed:.6f}") - print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}") + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False + + if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + 
path=output_path, + format=args.out_format, + num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + print(f"Beam: {args.beam}", file=log_stream) + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: {max_width_seen}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + ) def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Run trellis beam decoding ranked by mass minus a detcost-style future penalty." + description=( + "Run trellis beam decoding ranked by mass minus a detcost-style future penalty, " + "with Stim-compatible shot-data I/O options." 
+ ), + allow_abbrev=False, ) parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") - parser.add_argument("--sample-num-shots", type=int, default=1, help="Number of sampled shots.") + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") - parser.add_argument("--print-per-shot", action="store_true", help="Print a detailed line per decoded shot.") - return parser.parse_args() + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." 
+ ), + ) + parser.add_argument( + "--in", + dest="in_file", + default="", + help="File to read detection events from (use - for stdin).", + ) + parser.add_argument( + "--in-format", + "--in_format", + dest="in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--in-includes-appended-observables", + "--in_includes_appended_observables", + dest="in_includes_appended_observables", + action="store_true", + help="Assume the observable flips are appended to each shot in --in.", + ) + parser.add_argument( + "--obs-in", + "--obs_in", + dest="obs_in_file", + default="", + help="File to read observable flips from (use - for stdin).", + ) + parser.add_argument( + "--obs-in-format", + "--obs_in_format", + dest="obs_in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--out", + dest="out_file", + default="", + help="File to write predicted observable flips to (use - for stdout).", + ) + parser.add_argument( + "--out-format", + "--out_format", + dest="out_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--print-per-shot", + action="store_true", + help="Print a detailed line per decoded shot.", + ) + args = parser.parse_args() + + if args.sample_num_shots is None: + # Preserve the original script's one-shot default while still allowing + # file input without requiring --sample-num-shots 0. 
+ args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either --in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args if __name__ == "__main__": - args = _parse_args() - if args.sample_num_shots <= 0: - raise ValueError("--sample-num-shots must be positive.") - run_experiment( - args.circuit, - L=args.beam, - sample_num_shots=args.sample_num_shots, - sample_seed=args.sample_seed, - print_per_shot=args.print_per_shot, - ) + run_experiment(_parse_args()) diff --git a/src/py/astar/trellis_beam_iterative_forward_backward.py b/src/py/astar/trellis_beam_iterative_forward_backward.py new file mode 100644 index 0000000..6217639 --- /dev/null +++ b/src/py/astar/trellis_beam_iterative_forward_backward.py @@ -0,0 +1,1249 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import math +import shutil +import 
sys +import tempfile +import time +from collections import Counter +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import stim + + +STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets") +STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS) + + +@dataclass(frozen=True) +class Fault: + q: float + p: float + delta_scale: float + det_mask: int + likelihood_cost: float + + +@dataclass(frozen=True) +class DirectionalModel: + faults: tuple[Fault, ...] + retiring_masks: tuple[int, ...] + frontier_masks_after_step: tuple[int, ...] + future_detcost: tuple[tuple[float, ...], ...] + cut_after_step: tuple[int, ...] + direction_name: str + + +@dataclass(frozen=True) +class FrontierSnapshot: + local_states: tuple[int, ...] + masses: tuple[float, ...] + kept_total_mass: float + discarded_mass: float + one_masses: tuple[float, ...] + + +@dataclass(frozen=True) +class DecodePassSummary: + pass_index: int + direction: str + final_delta: float + discarded_mass: float + elapsed_seconds: float + exact_upper_hits: int + exact_upper_misses: int + marginal_tighter_count: int + opposite_selected_count: int + detcost_selected_count: int + mean_frontier_width: float + max_frontier_width: int + mean_beam_size: float + max_beam_size: int + opposite_available_steps: int + + +@dataclass(frozen=True) +class DecoderModel: + faults: tuple[Fault, ...] + forward_model: DirectionalModel + backward_model: DirectionalModel + frontier_masks_by_cut: tuple[int, ...] + frontier_detector_ids_by_cut: tuple[tuple[int, ...], ...] + frontier_global_bit_masks_by_cut: tuple[tuple[int, ...], ...] + all_possible_dets_mask: int + max_width: int + repeated_frontier_mask_count: int + max_frontier_mask_repeat: int + + +@dataclass(frozen=True) +class BeamDecodeResult: + predicted_logical: bool | None + certified: bool + margin: float + discarded_mass: float + max_width: int + elapsed_seconds: float + pass_summaries: tuple[DecodePassSummary, ...] 
+ + +@dataclass(frozen=True) +class DecodingShot: + det_mask: int + actual_logical: bool | None + + +@dataclass(frozen=True) +class ExperimentSummary: + predictions: list[bool | None] + num_certified: int + num_low_confidence: int + num_errors: int + num_truth_shots: int + num_scored_shots: int + total_elapsed: float + total_triggered: int + max_width_seen: int + + +@dataclass(frozen=True) +class _BeamPassOutcome: + failed: bool + final_delta: float + discarded_mass: float + snapshots: tuple[FrontierSnapshot | None, ...] + summary: DecodePassSummary + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _detectors_from_mask(mask: int) -> list[int]: + detectors: list[int] = [] + while mask: + low_bit = mask & -mask + detectors.append(low_bit.bit_length() - 1) + mask ^= low_bit + return detectors + + +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = fault.det_mask.bit_count() + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in _detectors_from_mask(fault.det_mask): + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row in future_detcost) + + +def _build_directional_model( + *, + faults_in_order: tuple[Fault, ...], + num_detectors: int, + cut_after_step: tuple[int, ...], + direction_name: str, +) -> DirectionalModel: + last_seen_index: dict[int, int] = {} + for fault_index, 
fault in enumerate(faults_in_order): + for det_id in _detectors_from_mask(fault.det_mask): + last_seen_index[det_id] = fault_index + + retiring_masks = [0] * len(faults_in_order) + for det_id, fault_index in last_seen_index.items(): + retiring_masks[fault_index] |= 1 << det_id + + frontier_masks_after_step = [0] * (len(faults_in_order) + 1) + active_mask = 0 + for fault_index, fault in enumerate(faults_in_order): + active_mask |= fault.det_mask + active_mask &= ~retiring_masks[fault_index] + frontier_masks_after_step[fault_index + 1] = active_mask + + return DirectionalModel( + faults=faults_in_order, + retiring_masks=tuple(retiring_masks), + frontier_masks_after_step=tuple(frontier_masks_after_step), + future_detcost=_future_detcost_by_detector(faults_in_order, num_detectors), + cut_after_step=cut_after_step, + direction_name=direction_name, + ) + + +def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel: + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + + for inst in dem: + if inst.type != "error": + continue + + p = float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + faults.append( + Fault( + q=1.0 - p, + p=p, + delta_scale=(-p if flip_l0 else p), + det_mask=det_mask, + likelihood_cost=_likelihood_cost(p), + ) + ) + all_possible_dets_mask |= det_mask + + frozen_faults = tuple(faults) + num_faults = len(frozen_faults) + + forward_model = _build_directional_model( + faults_in_order=frozen_faults, + num_detectors=circuit.num_detectors, + cut_after_step=tuple(range(num_faults + 1)), + direction_name="forward", + ) + backward_model = _build_directional_model( + faults_in_order=tuple(reversed(frozen_faults)), + num_detectors=circuit.num_detectors, + 
cut_after_step=tuple(num_faults - step for step in range(num_faults + 1)), + direction_name="backward", + ) + + for cut in range(num_faults + 1): + forward_mask = forward_model.frontier_masks_after_step[cut] + backward_mask = backward_model.frontier_masks_after_step[num_faults - cut] + if forward_mask != backward_mask: + raise ValueError( + "Internal frontier alignment check failed: the forward and backward cuts did not produce the " + f"same frontier detector set at cut {cut}." + ) + + frontier_masks_by_cut = forward_model.frontier_masks_after_step + frontier_detector_ids_by_cut = tuple( + tuple(_detectors_from_mask(mask)) for mask in frontier_masks_by_cut + ) + frontier_global_bit_masks_by_cut = tuple( + tuple(1 << det_id for det_id in detector_ids) + for detector_ids in frontier_detector_ids_by_cut + ) + repeated_frontier_counts = Counter(frontier_masks_by_cut) + repeated_frontier_mask_count = sum(1 for count in repeated_frontier_counts.values() if count > 1) + max_frontier_mask_repeat = max(repeated_frontier_counts.values(), default=0) + max_width = max((mask.bit_count() for mask in frontier_masks_by_cut), default=0) + + return DecoderModel( + faults=frozen_faults, + forward_model=forward_model, + backward_model=backward_model, + frontier_masks_by_cut=frontier_masks_by_cut, + frontier_detector_ids_by_cut=frontier_detector_ids_by_cut, + frontier_global_bit_masks_by_cut=frontier_global_bit_masks_by_cut, + all_possible_dets_mask=all_possible_dets_mask, + max_width=max_width, + repeated_frontier_mask_count=repeated_frontier_mask_count, + max_frontier_mask_repeat=max_frontier_mask_repeat, + ) + + +def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float: + total = 0.0 + pending = mismatch_mask + + while pending: + low_bit = pending & -pending + detector = low_bit.bit_length() - 1 + pending ^= low_bit + + best = future_detcost[detector] + if best == math.inf: + return math.inf + total += best + + return total + + +def 
_compress_global_state_to_local_state(global_state: int, global_bit_masks: tuple[int, ...]) -> int: + local_state = 0 + for local_index, global_bit in enumerate(global_bit_masks): + if global_state & global_bit: + local_state |= 1 << local_index + return local_state + + +def _record_frontier_snapshot( + *, + model: DecoderModel, + cut: int, + beam: list[tuple[int, float, float]], + discarded_mass: float, +) -> FrontierSnapshot: + global_bit_masks = model.frontier_global_bit_masks_by_cut[cut] + one_masses = [0.0] * len(global_bit_masks) + local_states: list[int] = [] + masses: list[float] = [] + + for state, total, _ in beam: + local_state = 0 + for local_index, global_bit in enumerate(global_bit_masks): + if state & global_bit: + local_state |= 1 << local_index + one_masses[local_index] += total + local_states.append(local_state) + masses.append(total) + + return FrontierSnapshot( + local_states=tuple(local_states), + masses=tuple(masses), + kept_total_mass=sum(masses), + discarded_mass=discarded_mass, + one_masses=tuple(one_masses), + ) + + +def _opposite_pass_cost_lower_bound( + *, + current_state: int, + live_target_mask: int, + snapshot: FrontierSnapshot, + snapshot_lookup: dict[int, float], + frontier_global_bit_masks: tuple[int, ...], +) -> tuple[float, bool, bool]: + if not frontier_global_bit_masks: + return 0.0, False, False + + compatible_other_state = current_state ^ live_target_mask + local_compatible_state = _compress_global_state_to_local_state( + global_state=compatible_other_state, + global_bit_masks=frontier_global_bit_masks, + ) + + # The opposite pass only records lower bounds on the surviving frontier-state + # masses, together with an upper bound on all omitted mass. Therefore + # exact_mass + discarded_mass is still an admissible upper bound on the true + # compatible-state mass, and -log(upper_bound) is an admissible lower bound + # on the remaining cost. 
+ missing_upper_bound = min(1.0, max(0.0, snapshot.discarded_mass))
+ exact_mass = snapshot_lookup.get(local_compatible_state)
+ exact_hit = exact_mass is not None  # diagnostic: did the opposite pass keep this exact state?
+ if exact_mass is None:
+ exact_upper_bound = missing_upper_bound
+ else:
+ exact_upper_bound = min(1.0, exact_mass + missing_upper_bound)
+
+ kept_total_mass = min(1.0, max(0.0, snapshot.kept_total_mass))
+ marginal_upper_bound = 1.0  # tightest per-detector marginal bound found so far
+ for local_index, observed_one_mass in enumerate(snapshot.one_masses):
+ if (local_compatible_state >> local_index) & 1:
+ upper_bound = min(1.0, observed_one_mass + missing_upper_bound)
+ else:
+ observed_zero_mass = max(0.0, kept_total_mass - observed_one_mass)
+ upper_bound = min(1.0, observed_zero_mass + missing_upper_bound)
+ if upper_bound < marginal_upper_bound:
+ marginal_upper_bound = upper_bound
+
+ used_marginal_bound = marginal_upper_bound < exact_upper_bound
+ compatible_upper_bound = min(exact_upper_bound, marginal_upper_bound)
+ if compatible_upper_bound <= 0.0:
+ return math.inf, exact_hit, used_marginal_bound  # zero compatible mass: state ruled out by opposite pass
+ return -math.log(compatible_upper_bound), exact_hit, used_marginal_bound
+
+
+def _run_beam_pass(
+ *,
+ model: DecoderModel,
+ directional_model: DirectionalModel,
+ actual_dets_mask: int,
+ L: int,
+ pass_index: int,
+ opposite_snapshots: tuple[FrontierSnapshot | None, ...] 
| None, +) -> _BeamPassOutcome: + pass_start_time = time.perf_counter() + beam: list[tuple[int, float, float]] = [(0, 1.0, 1.0)] + discarded_mass = 0.0 + + num_faults = len(directional_model.faults) + snapshots: list[FrontierSnapshot | None] = [None] * (num_faults + 1) + initial_cut = directional_model.cut_after_step[0] + snapshots[initial_cut] = _record_frontier_snapshot( + model=model, + cut=initial_cut, + beam=beam, + discarded_mass=discarded_mass, + ) + + exact_upper_hits = 0 + exact_upper_misses = 0 + marginal_tighter_count = 0 + opposite_selected_count = 0 + detcost_selected_count = 0 + opposite_available_steps = 0 + frontier_width_total = len(model.frontier_detector_ids_by_cut[initial_cut]) + frontier_width_steps = 1 + max_frontier_width = len(model.frontier_detector_ids_by_cut[initial_cut]) + beam_size_total = len(beam) + beam_size_steps = 1 + max_beam_size = len(beam) + + for fault_index, fault in enumerate(directional_model.faults): + collapsed_probs: dict[int, list[float]] = {} + total_mass = 0.0 + retiring_mask = directional_model.retiring_masks[fault_index] + + if retiring_mask == 0: + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + total_mass += absent_total + entry = collapsed_probs.get(state) + if entry is None: + collapsed_probs[state] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + total_mass += present_total + entry = collapsed_probs.get(toggled) + if entry is None: + collapsed_probs[toggled] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + else: + expected_bits = actual_dets_mask & retiring_mask + keep_mask = ~retiring_mask + for state, total, delta in beam: + absent_total = total * fault.q + absent_delta = delta * fault.q + if (state & retiring_mask) == expected_bits: + shrunk = state 
& keep_mask + total_mass += absent_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [absent_total, absent_delta] + else: + entry[0] += absent_total + entry[1] += absent_delta + + toggled = state ^ fault.det_mask + present_total = total * fault.p + present_delta = delta * fault.delta_scale + if (toggled & retiring_mask) == expected_bits: + shrunk = toggled & keep_mask + total_mass += present_total + entry = collapsed_probs.get(shrunk) + if entry is None: + collapsed_probs[shrunk] = [present_total, present_delta] + else: + entry[0] += present_total + entry[1] += present_delta + + if total_mass == 0.0: + summary = DecodePassSummary( + pass_index=pass_index, + direction=directional_model.direction_name, + final_delta=0.0, + discarded_mass=discarded_mass, + elapsed_seconds=time.perf_counter() - pass_start_time, + exact_upper_hits=exact_upper_hits, + exact_upper_misses=exact_upper_misses, + marginal_tighter_count=marginal_tighter_count, + opposite_selected_count=opposite_selected_count, + detcost_selected_count=detcost_selected_count, + mean_frontier_width=frontier_width_total / max(1, frontier_width_steps), + max_frontier_width=max_frontier_width, + mean_beam_size=beam_size_total / max(1, beam_size_steps), + max_beam_size=max_beam_size, + opposite_available_steps=opposite_available_steps, + ) + return _BeamPassOutcome( + failed=True, + final_delta=0.0, + discarded_mass=discarded_mass, + snapshots=tuple(snapshots), + summary=summary, + ) + + next_cut = directional_model.cut_after_step[fault_index + 1] + frontier_mask = directional_model.frontier_masks_after_step[fault_index + 1] + live_target_mask = actual_dets_mask & frontier_mask + next_future_detcost = directional_model.future_detcost[fault_index + 1] + frontier_global_bit_masks = model.frontier_global_bit_masks_by_cut[next_cut] + + opposite_snapshot = None if opposite_snapshots is None else opposite_snapshots[next_cut] + opposite_lookup: dict[int, float] | None = None + if 
opposite_snapshot is not None: + opposite_available_steps += 1 + opposite_lookup = {state: mass for state, mass in zip(opposite_snapshot.local_states, opposite_snapshot.masses)} + + ranked_states: list[tuple[float, float, int, float]] = [] + for state, (total, delta) in collapsed_probs.items(): + mismatch_mask = state ^ live_target_mask + heuristic_cost = _detcost_penalty( + mismatch_mask=mismatch_mask, + future_detcost=next_future_detcost, + ) + if opposite_snapshot is not None and opposite_lookup is not None: + opposite_cost, exact_hit, used_marginal_bound = _opposite_pass_cost_lower_bound( + current_state=state, + live_target_mask=live_target_mask, + snapshot=opposite_snapshot, + snapshot_lookup=opposite_lookup, + frontier_global_bit_masks=frontier_global_bit_masks, + ) + if exact_hit: + exact_upper_hits += 1 + else: + exact_upper_misses += 1 + if used_marginal_bound: + marginal_tighter_count += 1 + if opposite_cost > heuristic_cost: + heuristic_cost = opposite_cost + opposite_selected_count += 1 + else: + detcost_selected_count += 1 + else: + detcost_selected_count += 1 + + if heuristic_cost == math.inf: + rank_score = -math.inf + else: + rank_score = math.log(total) - heuristic_cost + ranked_states.append((rank_score, total, state, delta)) + + dropped_mass = 0.0 + if len(ranked_states) > L: + ranked_states.sort(reverse=True) + kept = ranked_states[:L] + beam = [(state, total, delta) for _, total, state, delta in kept] + kept_mass = sum(total for _, total, _, _ in kept) + dropped_mass = total_mass - kept_mass + else: + beam = [(state, total, delta) for _, total, state, delta in ranked_states] + + inv_total_mass = 1.0 / total_mass + discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass + beam = [ + (state, total * inv_total_mass, delta * inv_total_mass) + for state, total, delta in beam + ] + + snapshots[next_cut] = _record_frontier_snapshot( + model=model, + cut=next_cut, + beam=beam, + discarded_mass=discarded_mass, + ) + + frontier_width = 
len(model.frontier_detector_ids_by_cut[next_cut]) + frontier_width_total += frontier_width + frontier_width_steps += 1 + max_frontier_width = max(max_frontier_width, frontier_width) + beam_size = len(beam) + beam_size_total += beam_size + beam_size_steps += 1 + max_beam_size = max(max_beam_size, beam_size) + + _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0)) + summary = DecodePassSummary( + pass_index=pass_index, + direction=directional_model.direction_name, + final_delta=final_delta, + discarded_mass=discarded_mass, + elapsed_seconds=time.perf_counter() - pass_start_time, + exact_upper_hits=exact_upper_hits, + exact_upper_misses=exact_upper_misses, + marginal_tighter_count=marginal_tighter_count, + opposite_selected_count=opposite_selected_count, + detcost_selected_count=detcost_selected_count, + mean_frontier_width=frontier_width_total / max(1, frontier_width_steps), + max_frontier_width=max_frontier_width, + mean_beam_size=beam_size_total / max(1, beam_size_steps), + max_beam_size=max_beam_size, + opposite_available_steps=opposite_available_steps, + ) + return _BeamPassOutcome( + failed=False, + final_delta=final_delta, + discarded_mass=discarded_mass, + snapshots=tuple(snapshots), + summary=summary, + ) + + +def decode_beam_search_iterative( + model: DecoderModel, + actual_dets_mask: int, + L: int, + *, + num_passes: int, +) -> BeamDecodeResult: + start_time = time.perf_counter() + + if (actual_dets_mask & ~model.all_possible_dets_mask) != 0: + return BeamDecodeResult( + predicted_logical=None, + certified=False, + margin=0.0, + discarded_mass=0.0, + max_width=model.max_width, + elapsed_seconds=time.perf_counter() - start_time, + pass_summaries=(), + ) + + if num_passes <= 0: + raise ValueError("num_passes must be positive.") + + opposite_snapshots: tuple[FrontierSnapshot | None, ...] 
| None = None
+ pass_summaries: list[DecodePassSummary] = []
+ last_outcome: _BeamPassOutcome | None = None
+
+ for pass_offset in range(num_passes):
+ pass_index = pass_offset + 1
+ directional_model = model.forward_model if pass_offset % 2 == 0 else model.backward_model  # even offsets run forward, odd offsets backward
+ last_outcome = _run_beam_pass(
+ model=model,
+ directional_model=directional_model,
+ actual_dets_mask=actual_dets_mask,
+ L=L,
+ pass_index=pass_index,
+ opposite_snapshots=opposite_snapshots,
+ )
+ pass_summaries.append(last_outcome.summary)
+ if last_outcome.failed:
+ return BeamDecodeResult(
+ predicted_logical=None,
+ certified=False,
+ margin=0.0,
+ discarded_mass=last_outcome.discarded_mass,
+ max_width=model.max_width,
+ elapsed_seconds=time.perf_counter() - start_time,
+ pass_summaries=tuple(pass_summaries),
+ )
+ opposite_snapshots = last_outcome.snapshots
+
+ assert last_outcome is not None
+ final_delta = last_outcome.final_delta
+ margin = abs(final_delta)
+ certified = margin > last_outcome.discarded_mass  # margin must beat everything the beam ever dropped
+
+ if final_delta == 0.0:
+ return BeamDecodeResult(
+ predicted_logical=None,
+ certified=False,
+ margin=margin,
+ discarded_mass=last_outcome.discarded_mass,
+ max_width=model.max_width,
+ elapsed_seconds=time.perf_counter() - start_time,
+ pass_summaries=tuple(pass_summaries),
+ )
+ return BeamDecodeResult(
+ predicted_logical=final_delta < 0.0,  # sign of the final L0 delta picks the logical value
+ certified=certified,
+ margin=margin,
+ discarded_mass=last_outcome.discarded_mass,
+ max_width=model.max_width,
+ elapsed_seconds=time.perf_counter() - start_time,
+ pass_summaries=tuple(pass_summaries),
+ )
+
+
+def decode_beam_search_detcost_ranked(
+ model: DecoderModel,
+ actual_dets_mask: int,
+ L: int,
+) -> BeamDecodeResult:
+ # Single-pass convenience wrapper around the iterative decoder.
+ return decode_beam_search_iterative(
+ model=model,
+ actual_dets_mask=actual_dets_mask,
+ L=L,
+ num_passes=1,
+ )
+
+
+def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray:
+ arr = np.asarray(data)
+ if arr.ndim != 2:
+ raise ValueError(f"Expected {description} to 
be a 2D array but got shape {arr.shape!r}.") + if arr.shape[1] != expected_cols: + raise ValueError( + f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}." + ) + if arr.dtype != np.bool_: + arr = arr.astype(np.bool_, copy=False) + return arr + + +def _sample_shot_arrays( + circuit: stim.Circuit, + *, + shots: int, + seed: int | None, +) -> tuple[np.ndarray, np.ndarray]: + sampler = circuit.compile_detector_sampler(seed=seed) + dets, obs = sampler.sample(shots=shots, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"), + _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"), + ) + + +def _read_detector_shot_arrays( + *, + path: str, + fmt: str, + num_detectors: int, + num_observables: int, +) -> tuple[np.ndarray, np.ndarray | None]: + common_kwargs = dict( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=num_detectors, + num_observables=num_observables, + ) + + if num_observables: + try: + dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True) + return ( + _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"), + _as_bool_2d(obs, expected_cols=num_observables, description="appended observable data"), + ) + except TypeError: + flat = stim.read_shot_data_file(**common_kwargs) + flat = _as_bool_2d( + flat, + expected_cols=num_detectors + num_observables, + description="combined detector/observable input data", + ) + return flat[:, :num_detectors], flat[:, num_detectors:] + + flat = stim.read_shot_data_file(**common_kwargs) + return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None + + +def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray: + obs = stim.read_shot_data_file( + path=path, + format=fmt, + bit_packed=False, + num_measurements=0, + num_detectors=0, 
+ num_observables=num_observables,
+ )
+ return _as_bool_2d(obs, expected_cols=num_observables, description="observable input data")
+
+
+def _apply_shot_range(
+ dets: np.ndarray,
+ obs: np.ndarray | None,
+ *,
+ shot_range_begin: int,
+ shot_range_end: int,
+) -> tuple[np.ndarray, np.ndarray | None]:
+ if not (shot_range_begin or shot_range_end):  # (0, 0) means decode every shot
+ return dets, obs
+
+ if shot_range_end < shot_range_begin:
+ raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
+ if shot_range_end > len(dets):
+ raise ValueError(
+ f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})."
+ )
+
+ dets = dets[shot_range_begin:shot_range_end]
+ if obs is not None:
+ obs = obs[shot_range_begin:shot_range_end]
+ return dets, obs
+
+
+def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]:
+ shots: list[DecodingShot] = []
+ for shot_index in range(dets.shape[0]):
+ actual_logical = None if obs is None else bool(obs[shot_index, 0])  # only observable 0 is tracked
+ shots.append(
+ DecodingShot(
+ det_mask=_mask_from_bool_row(dets[shot_index]),
+ actual_logical=actual_logical,
+ )
+ )
+ return shots
+
+
+def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str:
+ if path != "-":  # "-" means spool stdin to a temp file
+ return path
+ temp_path = str(Path(temp_dir) / f"{stem}.bin")
+ with open(temp_path, "wb") as f:
+ f.write(sys.stdin.buffer.read())
+ return temp_path
+
+
+def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]:
+ if path != "-":
+ return path, False
+ return str(Path(temp_dir) / f"{stem}.bin"), True  # second element: copy to stdout afterwards
+
+
+def _copy_file_to_stdout(path: str) -> None:
+ sys.stdout.flush()
+ with open(path, "rb") as f:
+ shutil.copyfileobj(f, sys.stdout.buffer)
+ sys.stdout.buffer.flush()
+
+
+def _load_shots(
+ circuit: stim.Circuit,
+ args: argparse.Namespace,
+ *,
+ temp_dir: str,
+) -> list[DecodingShot]:
+ if args.in_file:
+ in_path = _resolve_stdin_path_if_needed(args.in_file, 
temp_dir=temp_dir, stem="shots_in") + appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0 + dets, obs = _read_detector_shot_arrays( + path=in_path, + fmt=args.in_format, + num_detectors=circuit.num_detectors, + num_observables=appended_obs_count, + ) + + if args.obs_in_file: + obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in") + obs = _read_observable_shot_array( + path=obs_in_path, + fmt=args.obs_in_format, + num_observables=circuit.num_observables, + ) + if len(obs) != len(dets): + raise ValueError("Observable input ended before, or after, the detector shot data.") + else: + dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed) + + dets, obs = _apply_shot_range( + dets, + obs, + shot_range_begin=args.shot_range_begin, + shot_range_end=args.shot_range_end, + ) + return _shots_from_arrays(dets, obs) + + +def _print_run_header( + *, + circuit: stim.Circuit, + model: DecoderModel, + args: argparse.Namespace, + num_shots: int, + log_stream, +) -> None: + print(f"Running on circuit {args.circuit}", file=log_stream) + print(f"Total Detectors: {circuit.num_detectors}", file=log_stream) + print(f"Total Observables: {circuit.num_observables}", file=log_stream) + print(f"Beam: {args.beam}", file=log_stream) + print(f"Num Passes: {args.num_passes}", file=log_stream) + print("Frontier Matching: keyed by cut index (forward/backward cut check verified)", file=log_stream) + if args.num_passes > 1: + print( + f"Repeated Frontiers: {model.repeated_frontier_mask_count} repeated detector-set masks; " + f"max repeat={model.max_frontier_mask_repeat}", + file=log_stream, + ) + if args.in_file: + print(f"Shot Input: {args.in_file}", file=log_stream) + print(f"Shot Input Format: {args.in_format}", file=log_stream) + if args.in_includes_appended_observables: + print("Observable Input: appended to --in", file=log_stream) + elif args.obs_in_file: + print(f"Observable 
Input: {args.obs_in_file}", file=log_stream) + print(f"Observable Format: {args.obs_in_format}", file=log_stream) + else: + print("Observable Input: none", file=log_stream) + else: + print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: + print( + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def _print_pass_summary( + *, + pass_summary: DecodePassSummary, + log_stream, +) -> None: + print( + f" pass={pass_summary.pass_index} direction={pass_summary.direction} " + f"final_delta={pass_summary.final_delta:.6e} discarded_mass={pass_summary.discarded_mass:.6e} " + f"exact_hits={pass_summary.exact_upper_hits} exact_misses={pass_summary.exact_upper_misses} " + f"marginal_tighter={pass_summary.marginal_tighter_count} " + f"opp_selected={pass_summary.opposite_selected_count} detcost_selected={pass_summary.detcost_selected_count} " + f"opp_steps={pass_summary.opposite_available_steps} " + f"mean_frontier_width={pass_summary.mean_frontier_width:.2f} max_frontier_width={pass_summary.max_frontier_width} " + f"mean_beam={pass_summary.mean_beam_size:.2f} max_beam={pass_summary.max_beam_size} " + f"elapsed_seconds={pass_summary.elapsed_seconds:.6f}", + file=log_stream, + ) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." 
+ ) + + model = _build_decoder_model(circuit) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header(circuit=circuit, model=model, args=args, num_shots=len(shots), log_stream=log_stream) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + + pass_aggregates = [ + { + "count": 0, + "elapsed_seconds": 0.0, + "exact_upper_hits": 0, + "exact_upper_misses": 0, + "marginal_tighter_count": 0, + "opposite_selected_count": 0, + "detcost_selected_count": 0, + "opposite_available_steps": 0, + "mean_frontier_width_sum": 0.0, + "max_frontier_width": 0, + "mean_beam_size_sum": 0.0, + "max_beam_size": 0, + } + for _ in range(args.num_passes) + ] + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_iterative( + model, + shot.det_mask, + args.beam, + num_passes=args.num_passes, + ) + predictions.append(result.predicted_logical) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, result.max_width) + + for pass_summary in result.pass_summaries: + agg = pass_aggregates[pass_summary.pass_index - 1] + agg["count"] += 1 + agg["elapsed_seconds"] += pass_summary.elapsed_seconds + agg["exact_upper_hits"] += 
pass_summary.exact_upper_hits + agg["exact_upper_misses"] += pass_summary.exact_upper_misses + agg["marginal_tighter_count"] += pass_summary.marginal_tighter_count + agg["opposite_selected_count"] += pass_summary.opposite_selected_count + agg["detcost_selected_count"] += pass_summary.detcost_selected_count + agg["opposite_available_steps"] += pass_summary.opposite_available_steps + agg["mean_frontier_width_sum"] += pass_summary.mean_frontier_width + agg["max_frontier_width"] = max(agg["max_frontier_width"], pass_summary.max_frontier_width) + agg["mean_beam_size_sum"] += pass_summary.mean_beam_size + agg["max_beam_size"] = max(agg["max_beam_size"], pass_summary.max_beam_size) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 + print( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}", + file=log_stream, + ) + + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + for pass_summary in result.pass_summaries: + _print_pass_summary(pass_summary=pass_summary, log_stream=log_stream) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False + + 
if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + path=output_path, + format=args.out_format, + num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: {max_width_seen}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + print("Pass Diagnostics:", file=log_stream) + for pass_index, agg in enumerate(pass_aggregates, start=1): + if agg["count"] == 0: + continue + print( + f" pass={pass_index} direction={'forward' if pass_index % 2 == 1 else 'backward'} " + f"mean_elapsed_seconds={agg['elapsed_seconds'] / agg['count']:.6f} " + f"exact_hits={agg['exact_upper_hits']} exact_misses={agg['exact_upper_misses']} " + f"marginal_tighter={agg['marginal_tighter_count']} " + f"opp_selected={agg['opposite_selected_count']} detcost_selected={agg['detcost_selected_count']} " + f"opp_steps={agg['opposite_available_steps']} " + f"mean_frontier_width={agg['mean_frontier_width_sum'] / agg['count']:.2f} " + 
f"max_frontier_width={agg['max_frontier_width']} " + f"mean_beam={agg['mean_beam_size_sum'] / agg['count']:.2f} " + f"max_beam={agg['max_beam_size']}", + file=log_stream, + ) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run trellis beam decoding with detcost-style future penalties and optional iterative " + "forward/backward cross-pass frontier-mass heuristics, with Stim-compatible shot-data I/O options." + ), + allow_abbrev=False, + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.") + parser.add_argument( + "--num-passes", + "--num_passes", + dest="num_passes", + type=int, + default=1, + help=( + "Number of alternating beam passes to run. Pass 1 is the original forward pass, pass 2 is backward, " + "pass 3 is forward again, and so on." + ), + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." 
+ ), + ) + parser.add_argument( + "--in", + dest="in_file", + default="", + help="File to read detection events from (use - for stdin).", + ) + parser.add_argument( + "--in-format", + "--in_format", + dest="in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--in-includes-appended-observables", + "--in_includes_appended_observables", + dest="in_includes_appended_observables", + action="store_true", + help="Assume the observable flips are appended to each shot in --in.", + ) + parser.add_argument( + "--obs-in", + "--obs_in", + dest="obs_in_file", + default="", + help="File to read observable flips from (use - for stdin).", + ) + parser.add_argument( + "--obs-in-format", + "--obs_in_format", + dest="obs_in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--out", + dest="out_file", + default="", + help="File to write predicted observable flips to (use - for stdout).", + ) + parser.add_argument( + "--out-format", + "--out_format", + dest="out_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--print-per-shot", + action="store_true", + help="Print a detailed line per decoded shot, plus one line per decoding pass.", + ) + args = parser.parse_args() + + if args.sample_num_shots is None: + # Preserve the original script's one-shot default while still allowing + # file input without requiring --sample-num-shots 0. 
+ args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.num_passes <= 0: + raise ValueError("--num-passes must be positive.") + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either --in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/py/astar/trellis_beam_opt_singleton_lp_ranked.py b/src/py/astar/trellis_beam_opt_singleton_lp_ranked.py new file mode 100644 index 0000000..634e465 --- /dev/null +++ b/src/py/astar/trellis_beam_opt_singleton_lp_ranked.py @@ -0,0 +1,1218 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import math +import shutil +import sys +import tempfile +import time +from bisect import bisect_left +from collections import OrderedDict +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +import 
numpy as np
+
+try: # pragma: no cover - optional at runtime in this environment.
+ import stim # type: ignore
+except ModuleNotFoundError: # pragma: no cover - exercised when Stim is unavailable.
+ stim = None  # absence is reported at use sites, not at import time
+
+try: # pragma: no cover - optional at runtime.
+ from scipy.optimize import linprog # type: ignore
+ from scipy.sparse import csr_matrix # type: ignore
+except ModuleNotFoundError: # pragma: no cover - exercised if SciPy is unavailable.
+ linprog = None
+ csr_matrix = None
+
+
+STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets")
+STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS)
+INF = float("inf")
+
+
+@dataclass(frozen=True)
+class Fault:
+ # One error mechanism; delta_scale presumably encodes the signed L0 update — TODO confirm against the model builder.
+ q: float
+ p: float
+ delta_scale: float
+ det_mask: int
+ detector_ids: tuple[int, ...]
+ likelihood_cost: float
+
+
+@dataclass(frozen=True)
+class DecoderModel:
+ faults: tuple[Fault, ...]
+ retiring_masks: tuple[int, ...]
+ live_masks_after: tuple[int, ...]
+ plain_future_detcost: tuple[tuple[float, ...], ...]
+ detector_to_faults: tuple[tuple[int, ...], ...] 
+ all_possible_dets_mask: int
+ max_width: int
+ num_detectors: int
+
+
+@dataclass(frozen=True)
+class BeamDecodeResult:
+ predicted_logical: bool | None
+ certified: bool
+ margin: float
+ discarded_mass: float
+ max_width: int
+ elapsed_seconds: float
+ heuristic_calls: int = 0
+ cache_hits: int = 0
+ lp_calls: int = 0
+ lp_seconds: float = 0.0
+
+
+@dataclass(frozen=True)
+class DecodingShot:
+ det_mask: int
+ actual_logical: bool | None
+
+
+@dataclass(frozen=True)
+class ExperimentSummary:
+ predictions: list[bool | None]
+ num_certified: int
+ num_low_confidence: int
+ num_errors: int
+ num_truth_shots: int
+ num_scored_shots: int
+ total_elapsed: float
+ total_triggered: int
+ max_width_seen: int
+ total_heuristic_calls: int
+ total_cache_hits: int
+ total_lp_calls: int
+ total_lp_seconds: float
+
+
+@dataclass
+class ShotSingletonLPContext:
+ row_index: int
+ detector_fault_offsets: list[int]
+ seen_fault_marks: list[int]
+ current_mark: int = 0
+
+ def next_mark(self) -> int:
+ self.current_mark += 1
+ if self.current_mark >= (1 << 60):
+ self.seen_fault_marks[:] = [0] * len(self.seen_fault_marks)
+ self.current_mark = 1  # wrap: clear stale marks before reusing small values
+ return self.current_mark
+
+
+class UnionFind:
+ def __init__(self, n: int) -> None:
+ self.parent = list(range(n))
+ self.rank = [0] * n
+
+ def find(self, x: int) -> int:
+ while self.parent[x] != x:
+ self.parent[x] = self.parent[self.parent[x]]  # path halving
+ x = self.parent[x]
+ return x
+
+ def union(self, a: int, b: int) -> None:
+ ra = self.find(a)
+ rb = self.find(b)
+ if ra == rb:
+ return
+ if self.rank[ra] < self.rank[rb]:  # union by rank
+ self.parent[ra] = rb
+ elif self.rank[ra] > self.rank[rb]:
+ self.parent[rb] = ra
+ else:
+ self.parent[rb] = ra
+ self.rank[ra] += 1
+
+
+class OptimalSingletonLPEvaluator:
+ """Evaluates the exact singleton-budget LP on a suffix of future faults. 

    The dual LP is
        maximize   sum_d y_d
        subject to sum_{d in support(e) ∩ M} y_d <= w_e   for each future fault e
                   y_d >= 0
    where M is the current residual live-detector mismatch mask.

    Results are cached across shots by (suffix_row, mismatch_mask). Within one shot,
    the suffix row advances monotonically, so per-detector pointers into the future
    fault lists can be updated incrementally instead of re-bisecting each time.
    """

    def __init__(
        self,
        model: DecoderModel,
        *,
        use_cache: bool = True,
        cache_max_entries: int = 0,
        split_components: bool = True,
    ) -> None:
        self.model = model
        self.use_cache = use_cache
        # 0 means "unbounded cache"; positive values give an LRU cap.
        self.cache_max_entries = cache_max_entries
        self.split_components = split_components
        # OrderedDict doubles as the LRU structure (move_to_end / popitem(last=False)).
        self.cache: OrderedDict[tuple[int, int], float] = OrderedDict()
        # Lifetime counters; callers snapshot/diff these for per-shot stats.
        self.heuristic_calls = 0
        self.cache_hits = 0
        self.lp_calls = 0
        self.lp_seconds = 0.0

    def clear_cache(self) -> None:
        """Drop all cached LP values."""
        self.cache.clear()

    def begin_shot(self) -> ShotSingletonLPContext:
        """Return a fresh cursor context positioned before the first fault."""
        return ShotSingletonLPContext(
            row_index=0,
            detector_fault_offsets=[0] * self.model.num_detectors,
            seen_fault_marks=[0] * len(self.model.faults),
        )

    def advance_past_fault(self, context: ShotSingletonLPContext, fault_index: int) -> None:
        """Move the suffix start past `fault_index`, bumping only that fault's detectors.

        Correct only when faults are advanced in order, which is how the
        decoding loop calls it; other detectors' offsets are already current.
        """
        context.row_index = fault_index + 1
        target_row = context.row_index
        fault = self.model.faults[fault_index]
        for detector in fault.detector_ids:
            future_faults = self.model.detector_to_faults[detector]
            pos = context.detector_fault_offsets[detector]
            while pos < len(future_faults) and future_faults[pos] < target_row:
                pos += 1
            context.detector_fault_offsets[detector] = pos

    def evaluate(self, context: ShotSingletonLPContext, mismatch_mask: int) -> float:
        """Return the exact singleton lower bound for clearing `mismatch_mask`.

        Returns INF when some mismatched detector is not touched by any
        remaining fault (the residual is infeasible).
        """
        self.heuristic_calls += 1

        if mismatch_mask == 0:
            return 0.0

        cache_key = (context.row_index, mismatch_mask)
        if self.use_cache:
            cached = self.cache.get(cache_key)
            if cached is not None:
                self.cache_hits += 1
                self.cache.move_to_end(cache_key)
                return cached

        if linprog is None or csr_matrix is None:
            raise RuntimeError(
                "The exact singleton-LP heuristic requires SciPy (scipy.optimize.linprog)."
            )

        # Collect, per distinct support mask, the cheapest future fault having it.
        # The generation mark guarantees each fault is inspected once even when
        # it covers several mismatched detectors.
        mark = context.next_mark()
        seen_fault_marks = context.seen_fault_marks
        support_to_weight: dict[int, float] = {}
        covered_mask = 0

        for detector in _detectors_from_mask(mismatch_mask):
            future_faults = self.model.detector_to_faults[detector]
            start = context.detector_fault_offsets[detector]
            for fault_index in future_faults[start:]:
                if seen_fault_marks[fault_index] == mark:
                    continue
                seen_fault_marks[fault_index] = mark

                fault = self.model.faults[fault_index]
                support_mask = fault.det_mask & mismatch_mask
                if support_mask == 0:
                    continue
                covered_mask |= support_mask
                previous = support_to_weight.get(support_mask)
                if previous is None or fault.likelihood_cost < previous:
                    support_to_weight[support_mask] = fault.likelihood_cost

        if covered_mask != mismatch_mask:
            # Some mismatched detector can never be cleared: infeasible.
            return self._store(cache_key, INF)

        if len(support_to_weight) == 1:
            # One constraint covering everything: the dual optimum is its weight.
            only_value = next(iter(support_to_weight.values()))
            return self._store(cache_key, only_value)

        if mismatch_mask.bit_count() == 1:
            # Single detector: optimum is the cheapest covering fault.
            best = min(support_to_weight.values())
            return self._store(cache_key, best)

        start_time = time.perf_counter()
        value = self._solve_support_system(support_to_weight=support_to_weight, mismatch_mask=mismatch_mask)
        self.lp_seconds += time.perf_counter() - start_time
        return self._store(cache_key, value)

    def _store(self, cache_key: tuple[int, int], value: float) -> float:
        """Cache `value` under `cache_key` (LRU-evicting if capped) and return it."""
        if self.use_cache:
            self.cache[cache_key] = value
            self.cache.move_to_end(cache_key)
            if self.cache_max_entries > 0:
                while len(self.cache) > self.cache_max_entries:
                    self.cache.popitem(last=False)
        return value

    def _solve_support_system(self, *, support_to_weight: dict[int, float], mismatch_mask: int) -> float:
        """Solve the dual LP, optionally split into independent detector components."""
        active_detectors = _detectors_from_mask(mismatch_mask)
        if not active_detectors:
            return 0.0

        detector_index = {detector: i for i, detector in enumerate(active_detectors)}

        if not self.split_components:
            return self._solve_component_lp(
                supports=tuple(support_to_weight.items()),
                detector_index=detector_index,
                component_detectors=tuple(active_detectors),
            )

        # Union detectors that co-occur in a support mask; the LP decomposes
        # across the resulting connected components.
        uf = UnionFind(len(active_detectors))
        support_bits_cache: dict[int, tuple[int, ...]] = {}
        for support_mask in support_to_weight:
            bits = _detectors_from_mask(support_mask)
            support_bits_cache[support_mask] = tuple(bits)
            if len(bits) > 1:
                base = detector_index[bits[0]]
                for detector in bits[1:]:
                    uf.union(base, detector_index[detector])

        component_detectors: dict[int, list[int]] = {}
        for detector in active_detectors:
            root = uf.find(detector_index[detector])
            component_detectors.setdefault(root, []).append(detector)

        component_supports: dict[int, list[tuple[int, float]]] = {root: [] for root in component_detectors}
        for support_mask, weight in support_to_weight.items():
            bits = support_bits_cache[support_mask]
            # Any bit of the support identifies the component; use the first.
            root = uf.find(detector_index[bits[0]])
            component_supports[root].append((support_mask, weight))

        total = 0.0
        for root, detectors in component_detectors.items():
            supports = component_supports[root]
            if len(detectors) == 1:
                # Singleton component: optimum is the cheapest covering weight.
                total += min(weight for _support_mask, weight in supports)
                continue
            total += self._solve_component_lp(
                supports=tuple(supports),
                detector_index=detector_index,
                component_detectors=tuple(detectors),
            )
        return total

    def _solve_component_lp(
        self,
        *,
        supports: tuple[tuple[int, float], ...],
        # NOTE(review): detector_index is accepted but unused; the method builds
        # its own local_index from component_detectors. Candidate for removal.
        detector_index: dict[int, int],
        component_detectors: tuple[int, ...],
    ) -> float:
        """Solve one component's dual LP with HiGHS; INF on infeasible/unbounded."""
        if linprog is None or csr_matrix is None:
            raise RuntimeError(
                "The exact singleton-LP heuristic requires SciPy (scipy.optimize.linprog)."
            )

        local_index = {detector: i for i, detector in enumerate(component_detectors)}
        row_indices: list[int] = []
        col_indices: list[int] = []
        data: list[float] = []
        rhs: list[float] = []

        # One <=-constraint row per support mask, coefficient 1 per member detector.
        for row, (support_mask, weight) in enumerate(supports):
            rhs.append(weight)
            pending = support_mask
            while pending:
                low_bit = pending & -pending
                detector = low_bit.bit_length() - 1
                pending ^= low_bit
                col_indices.append(local_index[detector])
                row_indices.append(row)
                data.append(1.0)

        a_ub = csr_matrix(
            (data, (row_indices, col_indices)),
            shape=(len(supports), len(component_detectors)),
            dtype=np.float64,
        )
        self.lp_calls += 1
        # Maximize sum(y) by minimizing -sum(y).
        result = linprog(
            c=-np.ones(len(component_detectors), dtype=np.float64),
            A_ub=a_ub,
            b_ub=np.array(rhs, dtype=np.float64),
            bounds=[(0.0, None)] * len(component_detectors),
            method="highs",
        )
        if result.status == 0:
            return max(0.0, float(-result.fun))
        if result.status in {2, 3}:
            # Infeasible/unbounded dual: report an infinite lower bound.
            return INF
        raise RuntimeError(f"linprog failed with status={result.status}: {result.message}")


def _require_stim() -> None:
    """Raise when the optional stim dependency is missing."""
    if stim is None:
        raise RuntimeError(
            "This script requires stim for CLI operation. Install stim, or import the module and build models manually."
+ ) + + +def _likelihood_cost(probability: float) -> float: + if probability <= 0.0: + return math.inf + if probability >= 1.0: + return 0.0 + return -math.log(probability / (1.0 - probability)) + + +def _iter_mask_bits(mask: int) -> Iterable[int]: + while mask: + low_bit = mask & -mask + yield low_bit.bit_length() - 1 + mask ^= low_bit + + +def _detectors_from_mask(mask: int) -> list[int]: + return list(_iter_mask_bits(mask)) + + +def _mask_from_bool_row(row: np.ndarray) -> int: + mask = 0 + for index in np.flatnonzero(row): + mask |= 1 << int(index) + return mask + + +def _future_detcost_by_detector(faults: tuple[Fault, ...], num_detectors: int) -> tuple[tuple[float, ...], ...]: + future_detcost: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)] + next_row = future_detcost[-1] + for fault_index in range(len(faults) - 1, -1, -1): + row = next_row.copy() + fault = faults[fault_index] + det_count = len(fault.detector_ids) + if det_count: + ecost = fault.likelihood_cost / det_count + for det_id in fault.detector_ids: + if ecost < row[det_id]: + row[det_id] = ecost + future_detcost[fault_index] = row + next_row = row + return tuple(tuple(row) for row in future_detcost) + + +def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel: + _require_stim() + dem = circuit.detector_error_model(decompose_errors=False).flattened() + + faults: list[Fault] = [] + all_possible_dets_mask = 0 + last_seen_index: dict[int, int] = {} + detector_to_faults_lists: list[list[int]] = [[] for _ in range(circuit.num_detectors)] + + for inst in dem: + if inst.type != "error": + continue + + p = float(inst.args_copy()[0]) + det_mask = 0 + flip_l0 = 0 + for target in inst.targets_copy(): + if target.is_separator(): + continue + if target.is_relative_detector_id(): + det_mask ^= 1 << target.val + elif target.is_logical_observable_id() and target.val == 0: + flip_l0 ^= 1 + + detector_ids = tuple(_detectors_from_mask(det_mask)) + fault = Fault( + q=1.0 - p, + 
            p=p,
            # Sign encodes whether this fault flips logical observable L0.
            delta_scale=(-p if flip_l0 else p),
            det_mask=det_mask,
            detector_ids=detector_ids,
            likelihood_cost=_likelihood_cost(p),
        )
        faults.append(fault)
        all_possible_dets_mask |= det_mask
        fault_index = len(faults) - 1
        for det_id in detector_ids:
            # Overwritten each time, so this ends as the LAST fault touching det_id.
            last_seen_index[det_id] = fault_index
            detector_to_faults_lists[det_id].append(fault_index)

    # retiring_masks[i]: detectors whose final appearance is fault i.
    retiring_masks = [0] * len(faults)
    for det_id, index in last_seen_index.items():
        retiring_masks[index] |= 1 << det_id

    # live_masks_after[i]: detectors still in play after processing fault i-1;
    # max_width tracks the peak live-detector count before retirement.
    live_masks_after = [0] * (len(faults) + 1)
    active_mask = 0
    max_width = 0
    for i, fault in enumerate(faults):
        active_mask |= fault.det_mask
        max_width = max(max_width, active_mask.bit_count())
        active_mask &= ~retiring_masks[i]
        live_masks_after[i + 1] = active_mask

    frozen_faults = tuple(faults)
    return DecoderModel(
        faults=frozen_faults,
        retiring_masks=tuple(retiring_masks),
        live_masks_after=tuple(live_masks_after),
        plain_future_detcost=_future_detcost_by_detector(frozen_faults, circuit.num_detectors),
        detector_to_faults=tuple(tuple(v) for v in detector_to_faults_lists),
        all_possible_dets_mask=all_possible_dets_mask,
        max_width=max_width,
        num_detectors=circuit.num_detectors,
    )


def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float:
    """Sum the per-detector future cost over all mismatched detectors.

    Returns infinity as soon as any mismatched detector is uncoverable.
    """
    total = 0.0
    pending = mismatch_mask

    while pending:
        low_bit = pending & -pending
        detector = low_bit.bit_length() - 1
        pending ^= low_bit

        best = future_detcost[detector]
        if best == math.inf:
            return math.inf
        total += best

    return total


def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray:
    """Validate that `data` is 2D with `expected_cols` columns and coerce to bool.

    Raises ValueError (mentioning `description`) on shape mismatch.
    """
    arr = np.asarray(data)
    if arr.ndim != 2:
        raise ValueError(f"Expected {description} to be a 2D array but got shape {arr.shape!r}.")
    if arr.shape[1] != expected_cols:
        raise ValueError(
            f"Expected {description} to have {expected_cols} columns but got {arr.shape[1]}."
        )
    if arr.dtype != np.bool_:
        arr = arr.astype(np.bool_, copy=False)
    return arr


def _sample_shot_arrays(
    circuit: stim.Circuit,
    *,
    shots: int,
    seed: int | None,
) -> tuple[np.ndarray, np.ndarray]:
    """Sample `shots` detector/observable rows from the circuit's sampler."""
    _require_stim()
    sampler = circuit.compile_detector_sampler(seed=seed)
    dets, obs = sampler.sample(shots=shots, separate_observables=True)
    return (
        _as_bool_2d(dets, expected_cols=circuit.num_detectors, description="sampled detector data"),
        _as_bool_2d(obs, expected_cols=circuit.num_observables, description="sampled observable data"),
    )


def _read_detector_shot_arrays(
    *,
    path: str,
    fmt: str,
    num_detectors: int,
    num_observables: int,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Read detector shots (and appended observables, if any) from `path`.

    Returns (dets, obs) with obs=None when num_observables is 0. Falls back
    to the combined-columns read when this stim version's read_shot_data_file
    lacks separate_observables (TypeError path).
    """
    _require_stim()
    common_kwargs = dict(
        path=path,
        format=fmt,
        bit_packed=False,
        num_measurements=0,
        num_detectors=num_detectors,
        num_observables=num_observables,
    )

    if num_observables:
        try:
            dets, obs = stim.read_shot_data_file(**common_kwargs, separate_observables=True)
            return (
                _as_bool_2d(dets, expected_cols=num_detectors, description="input detector data"),
                _as_bool_2d(obs, expected_cols=num_observables, description="appended observable data"),
            )
        except TypeError:
            # Older stim: observables come appended as extra columns; split them.
            flat = stim.read_shot_data_file(**common_kwargs)
            flat = _as_bool_2d(
                flat,
                expected_cols=num_detectors + num_observables,
                description="combined detector/observable input data",
            )
            return flat[:, :num_detectors], flat[:, num_detectors:]

    flat = stim.read_shot_data_file(**common_kwargs)
    return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None


def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray:
    """Read a file containing only observable flips, as a bool 2D array."""
    _require_stim()
    obs = stim.read_shot_data_file(
        path=path,
        format=fmt,
        bit_packed=False,
        num_measurements=0,
        num_detectors=0,
        num_observables=num_observables,
    )
    return _as_bool_2d(obs, expected_cols=num_observables, description="observable input 
data")


def _apply_shot_range(
    dets: np.ndarray,
    obs: np.ndarray | None,
    *,
    shot_range_begin: int,
    shot_range_end: int,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Slice both arrays to [begin, end); (0, 0) means keep everything."""
    if not (shot_range_begin or shot_range_end):
        return dets, obs

    if shot_range_end < shot_range_begin:
        raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
    if shot_range_end > len(dets):
        raise ValueError(
            f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})."
        )

    dets = dets[shot_range_begin:shot_range_end]
    if obs is not None:
        obs = obs[shot_range_begin:shot_range_end]
    return dets, obs


def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]:
    """Convert parallel detector/observable rows into DecodingShot records.

    Only observable column 0 is read (the decoder tracks a single logical).
    """
    shots: list[DecodingShot] = []
    for shot_index in range(dets.shape[0]):
        actual_logical = None if obs is None else bool(obs[shot_index, 0])
        shots.append(
            DecodingShot(
                det_mask=_mask_from_bool_row(dets[shot_index]),
                actual_logical=actual_logical,
            )
        )
    return shots


def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str:
    """Spool stdin to a temp file when `path` is '-'; otherwise return it as-is."""
    if path != "-":
        return path
    temp_path = str(Path(temp_dir) / f"{stem}.bin")
    with open(temp_path, "wb") as f:
        f.write(sys.stdin.buffer.read())
    return temp_path


def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]:
    """Return (output_path, copy_to_stdout_afterwards) for `path` ('-' = stdout)."""
    if path != "-":
        return path, False
    return str(Path(temp_dir) / f"{stem}.bin"), True


def _copy_file_to_stdout(path: str) -> None:
    """Stream the file's bytes to stdout, flushing around the binary write."""
    sys.stdout.flush()
    with open(path, "rb") as f:
        shutil.copyfileobj(f, sys.stdout.buffer)
    sys.stdout.buffer.flush()


def _load_shots(
    circuit: stim.Circuit,
    args: argparse.Namespace,
    *,
    temp_dir: str,
) -> list[DecodingShot]:
    """Load shots from --in/--obs-in, or sample them, then apply the shot range."""
    if args.in_file:
        in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in")
        appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0
        dets, obs 
= _read_detector_shot_arrays(
            path=in_path,
            fmt=args.in_format,
            num_detectors=circuit.num_detectors,
            num_observables=appended_obs_count,
        )

        if args.obs_in_file:
            obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in")
            obs = _read_observable_shot_array(
                path=obs_in_path,
                fmt=args.obs_in_format,
                num_observables=circuit.num_observables,
            )
            if len(obs) != len(dets):
                raise ValueError("Observable input ended before, or after, the detector shot data.")
    else:
        dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed)

    dets, obs = _apply_shot_range(
        dets,
        obs,
        shot_range_begin=args.shot_range_begin,
        shot_range_end=args.shot_range_end,
    )
    return _shots_from_arrays(dets, obs)


def decode_beam_search_singleton_lp_ranked(
    model: DecoderModel,
    actual_dets_mask: int,
    L: int,
    *,
    heuristic: str,
    evaluator: OptimalSingletonLPEvaluator | None = None,
) -> BeamDecodeResult:
    """Decode one shot with a width-L trellis beam ranked by mass minus penalty.

    Beam entries are (state, total, delta): the residual detector state, its
    probability mass, and the signed mass difference P(L0 unflipped) - P(flipped).
    Raises ValueError on an unknown `heuristic`.
    """
    start_time = time.perf_counter()

    if heuristic not in {"opt_singleton_lp", "plain_detcost"}:
        raise ValueError(f"Unsupported heuristic {heuristic!r}.")

    # A triggered detector no fault can touch makes the shot undecodable.
    if (actual_dets_mask & ~model.all_possible_dets_mask) != 0:
        return BeamDecodeResult(
            predicted_logical=None,
            certified=False,
            margin=0.0,
            discarded_mass=0.0,
            max_width=model.max_width,
            elapsed_seconds=time.perf_counter() - start_time,
        )

    if heuristic == "opt_singleton_lp":
        if evaluator is None:
            evaluator = OptimalSingletonLPEvaluator(model)
        context = evaluator.begin_shot()
        # Snapshot lifetime counters so the result reports per-shot deltas.
        start_heuristic_calls = evaluator.heuristic_calls
        start_cache_hits = evaluator.cache_hits
        start_lp_calls = evaluator.lp_calls
        start_lp_seconds = evaluator.lp_seconds
    else:
        context = None
        start_heuristic_calls = 0
        start_cache_hits = 0
        start_lp_calls = 0
        start_lp_seconds = 0.0

    beam = [(0, 1.0, 1.0)]
    discarded_mass = 0.0

    for i, fault in enumerate(model.faults):
        # state -> [total_mass, delta_mass] after branching on this fault.
        collapsed_probs: dict[int, list[float]] = {}
        total_mass = 0.0
        retiring_mask = model.retiring_masks[i]

        if retiring_mask == 0:
            # No detectors retire here: branch every beam entry on fault
            # absent (weight q) / present (weight p, state XOR det_mask).
            for state, total, delta in beam:
                absent_total = total * fault.q
                absent_delta = delta * fault.q
                total_mass += absent_total
                entry = collapsed_probs.get(state)
                if entry is None:
                    collapsed_probs[state] = [absent_total, absent_delta]
                else:
                    entry[0] += absent_total
                    entry[1] += absent_delta

                toggled = state ^ fault.det_mask
                present_total = total * fault.p
                present_delta = delta * fault.delta_scale
                total_mass += present_total
                entry = collapsed_probs.get(toggled)
                if entry is None:
                    collapsed_probs[toggled] = [present_total, present_delta]
                else:
                    entry[0] += present_total
                    entry[1] += present_delta
        else:
            # Retiring detectors must match the observed bits exactly;
            # branches that disagree are dropped, survivors shed those bits.
            expected_bits = actual_dets_mask & retiring_mask
            keep_mask = ~retiring_mask
            for state, total, delta in beam:
                absent_total = total * fault.q
                absent_delta = delta * fault.q
                if (state & retiring_mask) == expected_bits:
                    shrunk = state & keep_mask
                    total_mass += absent_total
                    entry = collapsed_probs.get(shrunk)
                    if entry is None:
                        collapsed_probs[shrunk] = [absent_total, absent_delta]
                    else:
                        entry[0] += absent_total
                        entry[1] += absent_delta

                toggled = state ^ fault.det_mask
                present_total = total * fault.p
                present_delta = delta * fault.delta_scale
                if (toggled & retiring_mask) == expected_bits:
                    shrunk = toggled & keep_mask
                    total_mass += present_total
                    entry = collapsed_probs.get(shrunk)
                    if entry is None:
                        collapsed_probs[shrunk] = [present_total, present_delta]
                    else:
                        entry[0] += present_total
                        entry[1] += present_delta

        if total_mass == 0.0:
            # Every branch was inconsistent with the observed retiring bits.
            return BeamDecodeResult(
                predicted_logical=None,
                certified=False,
                margin=0.0,
                discarded_mass=discarded_mass,
                max_width=model.max_width,
                elapsed_seconds=time.perf_counter() - start_time,
                heuristic_calls=(0 if evaluator is None else evaluator.heuristic_calls - start_heuristic_calls),
                cache_hits=(0 if evaluator is None else evaluator.cache_hits - start_cache_hits),
                lp_calls=(0 if evaluator is None else evaluator.lp_calls - start_lp_calls),
                lp_seconds=(0.0 if evaluator is None else evaluator.lp_seconds - start_lp_seconds),
            )

        live_target_mask = actual_dets_mask & model.live_masks_after[i + 1]
        if context is not None:
            evaluator.advance_past_fault(context, i)

        # Rank surviving states by log-mass minus the future-cost penalty.
        ranked_states: list[tuple[float, float, int, float]] = []
        for state, (total, delta) in collapsed_probs.items():
            mismatch_mask = state ^ live_target_mask
            if heuristic == "plain_detcost":
                penalty = _detcost_penalty(
                    mismatch_mask=mismatch_mask,
                    future_detcost=model.plain_future_detcost[i + 1],
                )
            else:
                assert evaluator is not None and context is not None
                penalty = evaluator.evaluate(context, mismatch_mask)
            if penalty == math.inf:
                rank_score = -math.inf
            else:
                # NOTE(review): assumes total > 0 here; mass from a p==0 or
                # p==1 fault could make math.log raise — confirm upstream.
                rank_score = math.log(total) - penalty
            ranked_states.append((rank_score, total, state, delta))

        dropped_mass = 0.0
        if len(ranked_states) > L:
            ranked_states.sort(reverse=True)
            kept = ranked_states[:L]
            beam = [(state, total, delta) for _, total, state, delta in kept]
            kept_mass = sum(total for _, total, _, _ in kept)
            dropped_mass = total_mass - kept_mass
        else:
            beam = [(state, total, delta) for _, total, state, delta in ranked_states]

        # Renormalize the beam each step; the accumulated discarded mass is
        # rescaled by the same factor to stay comparable with the margin.
        inv_total_mass = 1.0 / total_mass
        discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass
        beam = [
            (state, total * inv_total_mass, delta * inv_total_mass)
            for state, total, delta in beam
        ]

    # After all faults, state 0 is the branch fully consistent with the shot.
    _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0))
    margin = abs(final_delta)
    certified = margin > discarded_mass

    result = BeamDecodeResult(
        predicted_logical=None if final_delta == 0.0 else (final_delta < 0.0),
        certified=(False if final_delta == 0.0 else certified),
        margin=margin,
        discarded_mass=discarded_mass,
        max_width=model.max_width,
        elapsed_seconds=time.perf_counter() - start_time,
        heuristic_calls=(0 if evaluator is None else evaluator.heuristic_calls - start_heuristic_calls),
        cache_hits=(0 if evaluator is None else evaluator.cache_hits - start_cache_hits),
        lp_calls=(0 if evaluator is None else evaluator.lp_calls - start_lp_calls),
        lp_seconds=(0.0 if evaluator is None else evaluator.lp_seconds - start_lp_seconds),
    )
    return result


def _print_run_header(
    *,
    circuit: stim.Circuit,
    args: argparse.Namespace,
    num_shots: int,
    log_stream,
    evaluator: OptimalSingletonLPEvaluator | None,
) -> None:
    """Print the run configuration banner to `log_stream`."""
    print(f"Running on circuit {args.circuit}", file=log_stream)
    print(f"Total Detectors: {circuit.num_detectors}", file=log_stream)
    print(f"Total Observables: {circuit.num_observables}", file=log_stream)
    print(f"Heuristic: {args.heuristic}", file=log_stream)
    if args.heuristic == "opt_singleton_lp":
        print(
            f"Singleton LP Cache: {'on' if not args.no_singleton_lp_cache else 'off'}",
            file=log_stream,
        )
        if evaluator is not None and evaluator.cache_max_entries > 0:
            print(f"Cache Max Entries: {evaluator.cache_max_entries}", file=log_stream)
        else:
            print("Cache Max Entries: unlimited", file=log_stream)
        print(
            f"Component Splitting: {'on' if not args.no_singleton_lp_component_splitting else 'off'}",
            file=log_stream,
        )
    if args.in_file:
        print(f"Shot Input: {args.in_file}", file=log_stream)
        print(f"Shot Input Format: {args.in_format}", file=log_stream)
        if args.in_includes_appended_observables:
            print("Observable Input: appended to --in", file=log_stream)
        elif args.obs_in_file:
            print(f"Observable Input: {args.obs_in_file}", file=log_stream)
            print(f"Observable Format: {args.obs_in_format}", file=log_stream)
        else:
            print("Observable Input: none", file=log_stream)
    else:
        print(f"Sample Seed: {args.sample_seed}", file=log_stream)
        print(f"Requested Shots: {args.sample_num_shots}", file=log_stream)
    if args.shot_range_begin or args.shot_range_end:
        print(
            f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})",
            file=log_stream,
        )
    print(f"Num Shots: {num_shots}", file=log_stream)


def run_experiment(args: argparse.Namespace) -> ExperimentSummary:
    """Decode all requested shots, print progress/summary, and return totals.

    Raises ValueError when the circuit does not have exactly one observable.
    """
    _require_stim()
    circuit = stim.Circuit.from_file(args.circuit)
    if circuit.num_observables != 1:
        raise ValueError(
            "This decoder currently supports exactly one logical observable, because it only tracks L0. "
            f"The circuit has {circuit.num_observables} observables."
        )

    model = _build_decoder_model(circuit)
    evaluator = None
    if args.heuristic == "opt_singleton_lp":
        evaluator = OptimalSingletonLPEvaluator(
            model,
            use_cache=not args.no_singleton_lp_cache,
            cache_max_entries=args.singleton_lp_cache_max_entries,
            split_components=not args.no_singleton_lp_component_splitting,
        )
    # Keep logs off stdout when predictions are being streamed there.
    log_stream = sys.stderr if args.out_file == "-" else sys.stdout

    with tempfile.TemporaryDirectory() as temp_dir:
        shots = _load_shots(circuit, args, temp_dir=temp_dir)
        _print_run_header(
            circuit=circuit,
            args=args,
            num_shots=len(shots),
            log_stream=log_stream,
            evaluator=evaluator,
        )

        num_errors = 0
        num_low_confidence = 0
        num_certified = 0
        num_truth_shots = 0
        num_scored_shots = 0
        total_elapsed = 0.0
        total_triggered = 0
        max_width_seen = 0
        total_heuristic_calls = 0
        total_cache_hits = 0
        total_lp_calls = 0
        total_lp_seconds = 0.0
        predictions: list[bool | None] = []

        for shot_index, shot in enumerate(shots):
            if args.singleton_lp_clear_cache_between_shots and evaluator is not None:
                evaluator.clear_cache()

            result = decode_beam_search_singleton_lp_ranked(
                model,
                shot.det_mask,
                args.beam,
                heuristic=args.heuristic,
                evaluator=evaluator,
            )
            predictions.append(result.predicted_logical)

            # success is None when either side of the comparison is unknown.
            success: bool | None
            if shot.actual_logical is None or result.predicted_logical is None:
                success = None
            else:
                success = result.predicted_logical == shot.actual_logical

            if result.predicted_logical is None:
                num_low_confidence += 1
            if shot.actual_logical is not None:
                num_truth_shots += 1
            if success is not None:
                num_scored_shots += 1
                if not success:
                    num_errors += 1
                # NOTE(review): certified shots are only counted among scored
                # (truth-labeled, confident) shots — confirm that is intended.
                if result.certified:
                    num_certified += 1

            total_elapsed += result.elapsed_seconds
            total_heuristic_calls += result.heuristic_calls
            total_cache_hits += result.cache_hits
            total_lp_calls += result.lp_calls
            total_lp_seconds += result.lp_seconds
            triggered_dets = shot.det_mask.bit_count()
            total_triggered += triggered_dets
            max_width_seen = max(max_width_seen, result.max_width)

            shots_done = shot_index + 1
            error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0
            progress = (
                f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} "
                f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} "
                f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f}"
            )
            if args.print_heuristic_stats:
                progress += (
                    f" heuristic_calls_so_far={total_heuristic_calls} cache_hits_so_far={total_cache_hits} "
                    f"lp_calls_so_far={total_lp_calls} lp_seconds_so_far={total_lp_seconds:.6f}"
                )
            print(progress, file=log_stream)

            if args.print_per_shot:
                line = (
                    f"shot={shot_index} triggered_detectors={triggered_dets} "
                    f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} "
                    f"success={success} certified={result.certified} "
                    f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} "
                    f"elapsed_seconds={result.elapsed_seconds:.6f}"
                )
                if args.print_heuristic_stats:
                    line += (
                        f" heuristic_calls={result.heuristic_calls} cache_hits={result.cache_hits} "
                        f"lp_calls={result.lp_calls} lp_seconds={result.lp_seconds:.6f}"
                    )
                print(line, file=log_stream)

        if args.out_file:
            output_path, copy_to_stdout = _resolve_stdout_path_if_needed(
                args.out_file,
                temp_dir=temp_dir,
                stem="predictions_out",
            )
            # Low-confidence (None) predictions are written as False; warned below.
            prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_)
            for shot_index, predicted_logical in enumerate(predictions):
                prediction_data[shot_index, 0] = bool(predicted_logical) if predicted_logical is not None else False

            if args.out_format == "ptb64" and len(prediction_data) % 64 != 0:
                raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.")

            stim.write_shot_data_file(
                data=prediction_data,
                path=output_path,
                format=args.out_format,
                num_measurements=0,
                num_detectors=0,
                num_observables=circuit.num_observables,
            )
            if copy_to_stdout:
                _copy_file_to_stdout(output_path)
            if num_low_confidence:
                print(
                    f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result "
                    "files can only store bits, not unknown values.",
                    file=log_stream,
                )

        print(f"Beam: {args.beam}", file=log_stream)
        print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream)
        print(f"Max Width: {max_width_seen}", file=log_stream)
        print(f"Certified Shots: {num_certified}", file=log_stream)
        print(f"Low Confidence: {num_low_confidence}", file=log_stream)
        print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream)
        print(f"Scored Shots: {num_scored_shots}", file=log_stream)
        if num_truth_shots:
            print(f"Logical Errors: {num_errors}", file=log_stream)
        else:
            print("Logical Errors: n/a", file=log_stream)
        print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream)
        print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream)
        if args.print_heuristic_stats:
            print(f"Heuristic Calls: {total_heuristic_calls}", file=log_stream)
            print(f"LP Cache Hits: {total_cache_hits}", file=log_stream)
            print(f"LP Solves: {total_lp_calls}", file=log_stream)
            print(f"LP Seconds: {total_lp_seconds:.6f}", file=log_stream)
            if evaluator is not None:
                print(f"Cache Entries: {len(evaluator.cache)}", file=log_stream)

    return ExperimentSummary(
        predictions=predictions,
        num_certified=num_certified,
        num_low_confidence=num_low_confidence,
        num_errors=num_errors,
        num_truth_shots=num_truth_shots,
        num_scored_shots=num_scored_shots,
        total_elapsed=total_elapsed,
        total_triggered=total_triggered,
        max_width_seen=max_width_seen,
        total_heuristic_calls=total_heuristic_calls,
        total_cache_hits=total_cache_hits,
        total_lp_calls=total_lp_calls,
        total_lp_seconds=total_lp_seconds,
    )


def _parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse sys.argv, and validate argument combinations."""
    parser = argparse.ArgumentParser(
        description=(
            "Run trellis beam decoding ranked by mass minus an exact optimal singleton-LP future penalty, "
            "with Stim-compatible shot-data I/O options."
        ),
        allow_abbrev=False,
    )
    parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.")
    parser.add_argument("--beam", type=int, default=1000, help="Beam width cutoff.")
    parser.add_argument(
        "--heuristic",
        choices=("opt_singleton_lp", "plain_detcost"),
        default="opt_singleton_lp",
        help=(
            "Future-penalty heuristic used for ranking beam states. "
            "'opt_singleton_lp' uses the exact optimal singleton LP; 'plain_detcost' recovers the original decoder."
        ),
    )
    parser.add_argument(
        "--sample-num-shots",
        type=int,
        default=None,
        help="Number of sampled shots. Defaults to 1 unless --in is provided.",
    )
    parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.")
    parser.add_argument(
        "--shot-range-begin",
        type=int,
        default=0,
        help=(
            "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. "
            "Otherwise only decode shots in [begin, end)."
        ),
    )
    parser.add_argument(
        "--shot-range-end",
        type=int,
        default=0,
        help=(
            "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. "
            "Otherwise only decode shots in [begin, end)."
        ),
    )
    parser.add_argument(
        "--in",
        dest="in_file",
        default="",
        help="File to read detection events from (use - for stdin).",
    )
    # Underscore aliases mirror stim's CLI flag spelling.
    parser.add_argument(
        "--in-format",
        "--in_format",
        dest="in_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--in-includes-appended-observables",
        "--in_includes_appended_observables",
        dest="in_includes_appended_observables",
        action="store_true",
        help="Assume the observable flips are appended to each shot in --in.",
    )
    parser.add_argument(
        "--obs-in",
        "--obs_in",
        dest="obs_in_file",
        default="",
        help="File to read observable flips from (use - for stdin).",
    )
    parser.add_argument(
        "--obs-in-format",
        "--obs_in_format",
        dest="obs_in_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--out",
        dest="out_file",
        default="",
        help="File to write predicted observable flips to (use - for stdout).",
    )
    parser.add_argument(
        "--out-format",
        "--out_format",
        dest="out_format",
        choices=STIM_RESULT_FORMATS,
        default="01",
        help=f"Format of the file written by --out ({STIM_RESULT_FORMATS_HELP}).",
    )
    parser.add_argument(
        "--no-singleton-lp-cache",
        action="store_true",
        help="Disable reuse of exact singleton-LP values across shots.",
    )
    parser.add_argument(
        "--singleton-lp-cache-max-entries",
        type=int,
        default=0,
        help="Optional LRU cap on cached exact singleton-LP states. 0 means unlimited.",
    )
    parser.add_argument(
        "--singleton-lp-clear-cache-between-shots",
        action="store_true",
        help="Clear the exact singleton-LP cache before every shot.",
    )
    parser.add_argument(
        "--no-singleton-lp-component-splitting",
        action="store_true",
        help="Disable decomposition of the singleton LP into disconnected detector components.",
    )
    parser.add_argument(
        "--print-heuristic-stats",
        action="store_true",
        help="Print exact singleton-LP and cache statistics during the run.",
    )
    parser.add_argument(
        "--print-per-shot",
        action="store_true",
        help="Print a detailed line per decoded shot.",
    )
    args = parser.parse_args()

    # Default shot count: sample one shot unless an input file supplies shots.
    if args.sample_num_shots is None:
        args.sample_num_shots = 0 if args.in_file else 1

    if args.beam <= 0:
        raise ValueError("--beam must be positive.")
    if args.sample_num_shots < 0:
        raise ValueError("--sample-num-shots must be non-negative.")
    if args.sample_seed is not None and args.sample_seed < 0:
        raise ValueError("--sample-seed must be non-negative.")
    if args.shot_range_begin < 0 or args.shot_range_end < 0:
        raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.")
    if args.shot_range_end < args.shot_range_begin:
        raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
    if args.in_includes_appended_observables and args.obs_in_file:
        raise ValueError(
            "Choose either --in-includes-appended-observables or --obs-in, not both."
+ ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + if args.singleton_lp_cache_max_entries < 0: + raise ValueError("--singleton-lp-cache-max-entries must be non-negative.") + if args.heuristic == "plain_detcost" and ( + args.no_singleton_lp_cache + or args.singleton_lp_cache_max_entries + or args.singleton_lp_clear_cache_between_shots + or args.no_singleton_lp_component_splitting + ): + # Allowed but pointless; keep the CLI permissive. + pass + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/tesseract_ftl.cc b/src/tesseract_ftl.cc index 9b6f884..da2cc5b 100644 --- a/src/tesseract_ftl.cc +++ b/src/tesseract_ftl.cc @@ -321,9 +321,8 @@ SingletonComponentSolveResult solve_singleton_component_lp( throw std::runtime_error("Constraint generation exceeded the number of unique constraints."); } - DenseSimplexResult simplex = - solve_dense_primal_packing_lp(num_local_detectors, constraints, selected_indices, - &seed_budgets); + DenseSimplexResult simplex = solve_dense_primal_packing_lp(num_local_detectors, constraints, + selected_indices, &seed_budgets); result.simplex_solves++; if (simplex.unbounded) { result.unbounded = true; @@ -430,8 +429,8 @@ std::string TesseractFTLConfig::str() { ss << "subset_detcost_size=" << subset_detcost_size << ", "; ss << "ignore_blocked_errors_in_heuristic=" << ignore_blocked_errors_in_heuristic << ", "; ss << "num_min_dets_to_consider=" << num_min_dets_to_consider << ", "; - ss << "detector_choice_policy=" - << detector_choice_policy_to_string(detector_choice_policy) << ", "; 
+ ss << "detector_choice_policy=" << detector_choice_policy_to_string(detector_choice_policy) + << ", "; ss << "error_order_policy=" << error_order_policy_to_string(error_order_policy) << ", "; ss << "root_det_order_count=" << root_det_order_count << ", "; ss << "root_det_order_depth=" << root_det_order_depth << ", "; @@ -656,11 +655,10 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } } const auto candidate_stop_time = std::chrono::high_resolution_clock::now(); - stats.component_candidate_total_seconds += - std::chrono::duration_cast(candidate_stop_time - - candidate_start_time) - .count() / - 1e6; + stats.component_candidate_total_seconds += std::chrono::duration_cast( + candidate_stop_time - candidate_start_time) + .count() / + 1e6; const auto union_start_time = std::chrono::high_resolution_clock::now(); UnionFind uf(active_detectors.size()); @@ -711,8 +709,7 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c const auto dedup_start_time = std::chrono::high_resolution_clock::now(); std::vector, double, IntVectorHash>>> - min_rhs_by_pattern( - result.components.size()); + min_rhs_by_pattern(result.components.size()); std::vector local_hits; local_hits.reserve(16); @@ -786,10 +783,10 @@ TesseractFTLDecoder::SingletonBuildResult TesseractFTLDecoder::build_singleton_c } } const auto finalize_stop_time = std::chrono::high_resolution_clock::now(); - stats.component_finalize_total_seconds += - std::chrono::duration_cast(finalize_stop_time - finalize_start_time) - .count() / - 1e6; + stats.component_finalize_total_seconds += std::chrono::duration_cast( + finalize_stop_time - finalize_start_time) + .count() / + 1e6; return result; } @@ -860,11 +857,10 @@ TesseractFTLDecoder::ExactSubsetSolution TesseractFTLDecoder::solve_exact_subset component.cheapest_constraint_for_local_detector, seed_budgets); const auto simplex_stop_time = std::chrono::high_resolution_clock::now(); stats.simplex_calls++; - 
stats.simplex_total_seconds += - std::chrono::duration_cast(simplex_stop_time - - simplex_start_time) - .count() / - 1e6; + stats.simplex_total_seconds += std::chrono::duration_cast( + simplex_stop_time - simplex_start_time) + .count() / + 1e6; stats.lp_calls += component_result.simplex_solves; if (component_result.unbounded) { @@ -963,8 +959,9 @@ std::vector TesseractFTLDecoder::select_min_detectors( double budget; }; - const size_t order_count = - depth < config.root_det_order_depth ? std::min(config.root_det_order_count, config.det_orders.size()) : 1; + const size_t order_count = depth < config.root_det_order_depth + ? std::min(config.root_det_order_count, config.det_orders.size()) + : 1; std::vector seen(num_detectors, 0); std::vector candidates; candidates.reserve(detectors.count()); @@ -1411,7 +1408,8 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio continue; } - double child_h = project_from_exact_solution(exact_solution, child_detectors, prefix_blocked); + double child_h = + project_from_exact_solution(exact_solution, child_detectors, prefix_blocked); stats.projected_nodes_generated++; children_projected++; if (child_h == INF_D) { @@ -1442,9 +1440,9 @@ void TesseractFTLDecoder::decode_to_errors(const std::vector& detectio if (config.exact_child_refine_count > 0 && children_exactly_refined < config.exact_child_refine_count) { - ExactSubsetSolution child_exact = solve_exact_subset_lp( - detector_state_arena[(size_t)child.detector_state_idx], prefix_blocked, - child.warm_solution_idx); + ExactSubsetSolution child_exact = + solve_exact_subset_lp(detector_state_arena[(size_t)child.detector_state_idx], + prefix_blocked, child.warm_solution_idx); if (child_exact.value == INF_D) { children_infeasible++; stats.total_children_infeasible++; diff --git a/src/tesseract_ftl_main.cc b/src/tesseract_ftl_main.cc index 9ca6fc0..92758b0 100644 --- a/src/tesseract_ftl_main.cc +++ b/src/tesseract_ftl_main.cc @@ -304,13 +304,15 @@ int main(int argc, 
char* argv[]) { .flag() .store_into(args.ignore_blocked_errors_in_heuristic); program.add_argument("--num-min-dets-to-consider") - .help("Experimental: when expanding a node, branch on the first N active detectors in the " - "selected detector order.") + .help( + "Experimental: when expanding a node, branch on the first N active detectors in the " + "selected detector order.") .default_value(size_t(1)) .store_into(args.num_min_dets_to_consider); program.add_argument("--detector-choice-policy") - .help("Experimental detector pivot policy: order, fewest_incident_errors, " - "largest_budget, or largest_budget_per_incident.") + .help( + "Experimental detector pivot policy: order, fewest_incident_errors, " + "largest_budget, or largest_budget_per_incident.") .default_value(std::string("order")) .store_into(args.detector_choice_policy); program.add_argument("--error-order-policy") @@ -326,8 +328,9 @@ int main(int argc, char* argv[]) { .default_value(size_t(0)) .store_into(args.root_det_order_depth); program.add_argument("--exact-child-refine-count") - .help("Experimental exact mode: immediately LP-refine the first N generated children per " - "expanded node.") + .help( + "Experimental exact mode: immediately LP-refine the first N generated children per " + "expanded node.") .default_value(size_t(0)) .store_into(args.exact_child_refine_count); diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 0d6e728..3737a0d 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -15,28 +15,20 @@ #include "tesseract_trellis.h" #include -#include -#include -#include +#include #include #include +#if defined(__BMI2__) && \ + (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) +#include +#endif +#include #include #include -#include -#include #include #include "utils.h" -namespace std { -template <> -struct hash> { - size_t operator()(const boost::dynamic_bitset<>& bs) const { - return boost::hash_value(bs); - } -}; -} // 
namespace std - namespace { struct Fault { @@ -48,61 +40,39 @@ struct Fault { std::vector detectors; }; -struct LayerFault { - size_t error_index; - double log_q; - double log_p; - uint64_t obs_mask; - boost::dynamic_bitset<> local_det_mask; - boost::dynamic_bitset<> retiring_mask; - boost::dynamic_bitset<> expected_retiring_bits; - std::vector surviving_local_indices; -}; - -struct SmallLayerFault { - size_t error_index; - double q; - double p; - uint64_t obs_flip_bit; - uint64_t local_det_mask; - uint64_t retiring_mask; - uint64_t expected_retiring_bits; - std::vector surviving_local_indices; -}; - struct PackedMass { uint64_t key; double mass; double penalty; }; -struct StateMass { +struct SmallStateGroup { uint64_t state; double mass; double penalty; + size_t begin; + size_t end; }; -struct ObsAggregate { - double log_mass = -INF; +struct WidePackedMass { + std::vector state_words; + uint64_t obs_mask; + double mass; + double penalty; }; -struct FrontierAggregate { - double total_log_mass = -INF; - std::unordered_map obs_entries; +struct WideStateMass { + std::vector state_words; + double mass; + double penalty; }; -double logsumexp2(double a, double b) { - if (a == -INF) return b; - if (b == -INF) return a; - if (a < b) std::swap(a, b); - return a + std::log1p(std::exp(b - a)); -} - -void add_obs_mass(FrontierAggregate& aggregate, uint64_t obs_mask, double log_mass) { - aggregate.total_log_mass = logsumexp2(aggregate.total_log_mass, log_mass); - auto& obs = aggregate.obs_entries[obs_mask]; - obs.log_mass = logsumexp2(obs.log_mass, log_mass); -} +struct BranchPenaltyUpdate { + bool absent_valid = true; + bool present_valid = true; + double absent_penalty = 0.0; + double present_penalty = 0.0; +}; std::vector parse_faults(const std::vector& errors, size_t num_observables) { std::vector faults; @@ -110,7 +80,9 @@ std::vector parse_faults(const std::vector& errors, size_t for (size_t error_index = 0; error_index < errors.size(); ++error_index) { const auto& error 
= errors[error_index]; const double p = error.get_probability(); - if (p <= 0) continue; + if (p <= 0) { + continue; + } Fault fault; fault.error_index = error_index; fault.likelihood_cost = error.likelihood_cost; @@ -132,18 +104,9 @@ std::vector parse_faults(const std::vector& errors, size_t return faults; } -boost::dynamic_bitset<> project_state(const boost::dynamic_bitset<>& state, - const std::vector& surviving_local_indices) { - boost::dynamic_bitset<> out(surviving_local_indices.size()); - for (size_t k = 0; k < surviving_local_indices.size(); ++k) { - out[k] = state[surviving_local_indices[k]]; - } - return out; -} - -std::vector build_layer_faults(const std::vector& faults, size_t num_detectors, - const std::vector& detections, - size_t* max_frontier_width_seen) { +bool build_small_layer_templates(const std::vector& faults, size_t num_detectors, + std::vector* layers, + size_t* max_frontier_width_seen) { std::vector last_seen(num_detectors, std::numeric_limits::max()); for (size_t i = 0; i < faults.size(); ++i) { for (int d : faults[i].detectors) { @@ -151,22 +114,15 @@ std::vector build_layer_faults(const std::vector& faults, siz } } - boost::dynamic_bitset<> actual_dets(num_detectors); - for (uint64_t d : detections) { - if (d >= num_detectors) { - throw std::runtime_error("Detector index out of range."); - } - actual_dets.flip(d); - } - std::vector active_detectors; active_detectors.reserve(num_detectors); std::vector global_to_local(num_detectors, -1); - std::vector layers; - layers.reserve(faults.size()); + layers->clear(); + layers->reserve(faults.size()); *max_frontier_width_seen = 0; for (size_t i = 0; i < faults.size(); ++i) { + const size_t previous_width = active_detectors.size(); for (int d : faults[i].detectors) { if (global_to_local[d] == -1) { global_to_local[d] = active_detectors.size(); @@ -175,30 +131,37 @@ std::vector build_layer_faults(const std::vector& faults, siz } *max_frontier_width_seen = std::max(*max_frontier_width_seen, 
active_detectors.size()); - LayerFault layer{ - .error_index = faults[i].error_index, - .log_q = faults[i].log_q, - .log_p = faults[i].log_p, - .obs_mask = faults[i].obs_mask, - .local_det_mask = boost::dynamic_bitset<>(active_detectors.size()), - .retiring_mask = boost::dynamic_bitset<>(active_detectors.size()), - .expected_retiring_bits = boost::dynamic_bitset<>(active_detectors.size()), + if (*max_frontier_width_seen > 63) { + return false; + } + + TesseractTrellisSmallLayerTemplate layer{ + .q = std::exp(faults[i].log_q), + .p = std::exp(faults[i].log_p), + .obs_flip_bit = faults[i].obs_mask & 1, + .local_det_mask = 0, + .retiring_mask = 0, + .surviving_mask = 0, + .projected_fault_mask = 0, + .previous_width = previous_width, .surviving_local_indices = {}, + .current_active_detectors = active_detectors, + .next_frontier_costs = {}, + .detcost_transition = {}, }; - for (int d : faults[i].detectors) { - layer.local_det_mask.set(global_to_local[d]); + layer.local_det_mask ^= uint64_t{1} << global_to_local[d]; } - for (size_t local = 0; local < active_detectors.size(); ++local) { const int d = active_detectors[local]; if (last_seen[d] == i) { - layer.retiring_mask.set(local); - layer.expected_retiring_bits[local] = actual_dets[d]; + layer.retiring_mask ^= uint64_t{1} << local; } else { - layer.surviving_local_indices.push_back(local); + layer.surviving_local_indices.push_back((uint8_t)local); } } + uint64_t live_mask = (uint64_t{1} << active_detectors.size()) - 1; + layer.surviving_mask = live_mask & ~layer.retiring_mask; std::vector next_active; next_active.reserve(layer.surviving_local_indices.size()); @@ -209,14 +172,29 @@ std::vector build_layer_faults(const std::vector& faults, siz next_active.push_back(d); } active_detectors = std::move(next_active); - layers.push_back(std::move(layer)); + layers->push_back(std::move(layer)); } - return layers; + + return true; } -bool build_small_layer_templates(const std::vector& faults, size_t num_detectors, - 
std::vector* layers, - size_t* max_frontier_width_seen) { +void build_small_detector_layer_refs( + const std::vector& layers, size_t num_detectors, + std::vector>* refs) { + refs->assign(num_detectors, {}); + for (size_t layer_index = 0; layer_index < layers.size(); ++layer_index) { + const auto& layer = layers[layer_index]; + for (size_t local = 0; local < layer.current_active_detectors.size(); ++local) { + int detector = layer.current_active_detectors[local]; + (*refs)[(size_t)detector].push_back( + {static_cast(layer_index), static_cast(local)}); + } + } +} + +void build_wide_layer_templates(const std::vector& faults, size_t num_detectors, + std::vector* layers, + size_t* max_frontier_width_seen) { std::vector last_seen(num_detectors, std::numeric_limits::max()); for (size_t i = 0; i < faults.size(); ++i) { for (int d : faults[i].detectors) { @@ -241,31 +219,22 @@ bool build_small_layer_templates(const std::vector& faults, size_t num_de } *max_frontier_width_seen = std::max(*max_frontier_width_seen, active_detectors.size()); - if (*max_frontier_width_seen > 63) { - return false; - } - - TesseractTrellisSmallLayerTemplate layer{ + TesseractTrellisWideLayerTemplate layer{ .q = std::exp(faults[i].log_q), .p = std::exp(faults[i].log_p), - .obs_flip_bit = faults[i].obs_mask & 1, - .local_det_mask = 0, - .retiring_mask = 0, + .obs_mask = faults[i].obs_mask, .previous_width = previous_width, .surviving_local_indices = {}, .current_active_detectors = active_detectors, + .projected_fault_mask_words = {}, .next_frontier_costs = {}, .detcost_transition = {}, }; - for (int d : faults[i].detectors) { - layer.local_det_mask ^= uint64_t{1} << global_to_local[d]; - } + for (size_t local = 0; local < active_detectors.size(); ++local) { const int d = active_detectors[local]; - if (last_seen[d] == i) { - layer.retiring_mask ^= uint64_t{1} << local; - } else { - layer.surviving_local_indices.push_back((uint8_t)local); + if (last_seen[d] != i) { + 
layer.surviving_local_indices.push_back((uint32_t)local); } } @@ -280,47 +249,11 @@ bool build_small_layer_templates(const std::vector& faults, size_t num_de active_detectors = std::move(next_active); layers->push_back(std::move(layer)); } - - return true; -} - -uint64_t project_small_state(uint64_t state, const std::vector& surviving_local_indices) { - uint64_t out = 0; - for (size_t k = 0; k < surviving_local_indices.size(); ++k) { - out |= ((state >> surviving_local_indices[k]) & 1ULL) << k; - } - return out; -} - -uint64_t compute_target_bits(const std::vector& active_detectors, - const boost::dynamic_bitset<>& actual_dets) { - uint64_t target_bits = 0; - for (size_t local = 0; local < active_detectors.size(); ++local) { - if (actual_dets[(size_t)active_detectors[local]]) { - target_bits |= uint64_t{1} << local; - } - } - return target_bits; -} - -double compute_penalty_from_scratch(uint64_t mismatch_mask, - const std::vector& aligned_future_costs) { - double total = 0.0; - while (mismatch_mask) { - uint64_t low_bit = mismatch_mask & -mismatch_mask; - int detector = std::countr_zero(low_bit); - mismatch_mask ^= low_bit; - double best = aligned_future_costs[(size_t)detector]; - if (best == INF) { - return INF; - } - total += best; - } - return total; } +template void build_future_detcost_transitions(const std::vector& faults, size_t num_detectors, - std::vector* layers, + std::vector* layers, std::vector* initial_future_detcost) { std::vector current_row(num_detectors, INF); for (size_t fault_index = faults.size(); fault_index-- > 0;) { @@ -329,45 +262,241 @@ void build_future_detcost_transitions(const std::vector& faults, size_t n layer.next_frontier_costs.resize(layer.surviving_local_indices.size(), INF); for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { - int global_detector = layer.current_active_detectors[layer.surviving_local_indices[next_local]]; + size_t current_local = 
(size_t)layer.surviving_local_indices[next_local]; + int global_detector = layer.current_active_detectors[current_local]; layer.next_frontier_costs[next_local] = current_row[(size_t)global_detector]; } - std::array current_to_next; - current_to_next.fill(-1); + std::vector current_to_next(layer.current_active_detectors.size(), -1); for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { - current_to_next[layer.surviving_local_indices[next_local]] = (int8_t)next_local; + current_to_next[(size_t)layer.surviving_local_indices[next_local]] = (int32_t)next_local; } - layer.detcost_transition.fault_local_indices.clear(); - layer.detcost_transition.next_local_indices.clear(); - layer.detcost_transition.current_costs.clear(); - layer.detcost_transition.next_costs.clear(); - layer.detcost_transition.fault_local_indices.reserve(fault.detectors.size()); - layer.detcost_transition.next_local_indices.reserve(fault.detectors.size()); - layer.detcost_transition.current_costs.reserve(fault.detectors.size()); - layer.detcost_transition.next_costs.reserve(fault.detectors.size()); + auto& transition = layer.detcost_transition; + transition.fault_local_indices.clear(); + transition.next_local_indices.clear(); + transition.current_costs.clear(); + transition.next_costs.clear(); + transition.fault_local_indices.reserve(fault.detectors.size()); + transition.next_local_indices.reserve(fault.detectors.size()); + transition.current_costs.reserve(fault.detectors.size()); + transition.next_costs.reserve(fault.detectors.size()); if (!fault.detectors.empty()) { double ecost = fault.likelihood_cost / fault.detectors.size(); for (int detector : fault.detectors) { - auto it = std::find(layer.current_active_detectors.begin(), layer.current_active_detectors.end(), - detector); + auto it = std::find(layer.current_active_detectors.begin(), + layer.current_active_detectors.end(), detector); if (it == layer.current_active_detectors.end()) { throw 
std::runtime_error("Missing detector in active frontier while preparing detcost."); } - uint8_t local = (uint8_t)std::distance(layer.current_active_detectors.begin(), it); + uint32_t local = (uint32_t)std::distance(layer.current_active_detectors.begin(), it); double next_cost = current_row[(size_t)detector]; double current_cost = std::min(ecost, next_cost); - layer.detcost_transition.fault_local_indices.push_back(local); - layer.detcost_transition.next_local_indices.push_back(current_to_next[local]); - layer.detcost_transition.current_costs.push_back(current_cost); - layer.detcost_transition.next_costs.push_back(next_cost); + transition.fault_local_indices.push_back(local); + transition.next_local_indices.push_back(current_to_next[local]); + transition.current_costs.push_back(current_cost); + transition.next_costs.push_back(next_cost); current_row[(size_t)detector] = current_cost; } } } - *initial_future_detcost = std::move(current_row); + + if (initial_future_detcost != nullptr) { + *initial_future_detcost = std::move(current_row); + } +} + +size_t num_state_words(size_t num_bits) { + return (num_bits + 63) >> 6; +} + +uint64_t compact_bits_u64(uint64_t value, uint64_t mask) { +#if defined(__BMI2__) && \ + (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) + return _pext_u64(value, mask); +#else + uint64_t out = 0; + uint64_t out_bit = 1; + while (mask) { + uint64_t keep = mask & -mask; + if (value & keep) { + out |= out_bit; + } + mask ^= keep; + out_bit <<= 1; + } + return out; +#endif +} + +bool get_state_bit(const std::vector& state_words, size_t bit, size_t logical_width) { + if (bit >= logical_width) { + return false; + } + size_t word = bit >> 6; + if (word >= state_words.size()) { + return false; + } + return (state_words[word] >> (bit & 63)) & 1ULL; +} + +void xor_state_words(std::vector& state_words, const std::vector& mask_words) { + if (state_words.size() < mask_words.size()) { + state_words.resize(mask_words.size(), 0); 
+ } + for (size_t k = 0; k < mask_words.size(); ++k) { + state_words[k] ^= mask_words[k]; + } +} + +std::vector project_wide_state(const std::vector& state_words, + size_t logical_width, + const std::vector& surviving_local_indices) { + std::vector out(num_state_words(surviving_local_indices.size()), 0); + for (size_t next_local = 0; next_local < surviving_local_indices.size(); ++next_local) { + size_t current_local = (size_t)surviving_local_indices[next_local]; + if (get_state_bit(state_words, current_local, logical_width)) { + out[next_local >> 6] ^= uint64_t{1} << (next_local & 63); + } + } + return out; +} + +bool wide_state_less(const std::vector& a, const std::vector& b) { + if (a.size() != b.size()) { + return a.size() < b.size(); + } + for (size_t k = a.size(); k-- > 0;) { + if (a[k] != b[k]) { + return a[k] < b[k]; + } + } + return false; +} + +bool wide_state_zero(const std::vector& state_words) { + for (uint64_t word : state_words) { + if (word != 0) { + return false; + } + } + return true; +} + +uint64_t project_small_state(uint64_t state, uint64_t surviving_mask) { + return compact_bits_u64(state, surviving_mask); +} + +double compute_initial_penalty_for_target_bits(uint64_t target_bits, + const std::vector& active_detectors, + const std::vector& initial_future_detcost) { + double total = 0.0; + while (target_bits) { + uint64_t low_bit = target_bits & -target_bits; + size_t local = (size_t)std::countr_zero(low_bit); + target_bits ^= low_bit; + double best = initial_future_detcost[(size_t)active_detectors[local]]; + if (best == INF) { + return INF; + } + total += best; + } + return total; +} + +double compute_initial_penalty_for_active_detectors( + const std::vector& active_detectors, const boost::dynamic_bitset<>& actual_dets, + const std::vector& initial_future_detcost) { + double total = 0.0; + for (int detector : active_detectors) { + if (!actual_dets[(size_t)detector]) { + continue; + } + double best = initial_future_detcost[(size_t)detector]; + if 
(best == INF) { + return INF; + } + total += best; + } + return total; +} + +BranchPenaltyUpdate compute_small_branch_update(uint64_t base_state, size_t previous_width, + double current_penalty, + uint64_t current_target_bits, + const TesseractTrellisDetcostTransition& transition, + bool compute_penalties) { + BranchPenaltyUpdate update; + update.absent_penalty = compute_penalties ? current_penalty : 0.0; + update.present_penalty = compute_penalties ? current_penalty : 0.0; + + for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) { + size_t local = transition.fault_local_indices[k]; + bool state_bit = local < previous_width && ((base_state >> local) & 1ULL); + bool target_bit = (current_target_bits >> local) & 1ULL; + bool mismatch = state_bit ^ target_bit; + int32_t next_local = transition.next_local_indices[k]; + + if (next_local < 0) { + if (mismatch) { + update.absent_valid = false; + } else { + update.present_valid = false; + } + } + + if (!compute_penalties) { + continue; + } + + double prev_contrib = (local < previous_width && mismatch) ? transition.current_costs[k] : 0.0; + double absent_contrib = (next_local >= 0 && mismatch) ? transition.next_costs[k] : 0.0; + double present_contrib = (next_local >= 0 && !mismatch) ? transition.next_costs[k] : 0.0; + update.absent_penalty += absent_contrib - prev_contrib; + update.present_penalty += present_contrib - prev_contrib; + } + + return update; +} + +BranchPenaltyUpdate compute_wide_branch_update(const std::vector& base_state_words, + size_t previous_width, double current_penalty, + const std::vector& current_active_detectors, + const boost::dynamic_bitset<>& actual_dets, + const TesseractTrellisDetcostTransition& transition, + bool compute_penalties) { + BranchPenaltyUpdate update; + update.absent_penalty = compute_penalties ? current_penalty : 0.0; + update.present_penalty = compute_penalties ? 
current_penalty : 0.0; + + for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) { + size_t local = transition.fault_local_indices[k]; + bool state_bit = get_state_bit(base_state_words, local, previous_width); + bool target_bit = actual_dets[(size_t)current_active_detectors[local]]; + bool mismatch = state_bit ^ target_bit; + int32_t next_local = transition.next_local_indices[k]; + + if (next_local < 0) { + if (mismatch) { + update.absent_valid = false; + } else { + update.present_valid = false; + } + } + + if (!compute_penalties) { + continue; + } + + double prev_contrib = (local < previous_width && mismatch) ? transition.current_costs[k] : 0.0; + double absent_contrib = (next_local >= 0 && mismatch) ? transition.next_costs[k] : 0.0; + double present_contrib = (next_local >= 0 && !mismatch) ? transition.next_costs[k] : 0.0; + update.absent_penalty += absent_contrib - prev_contrib; + update.present_penalty += present_contrib - prev_contrib; + } + + return update; } uint64_t pack_small_key(uint64_t state, uint64_t obs_flip_bit) { @@ -397,53 +526,87 @@ void normalize_items(std::vector& items) { } } -std::vector merge_equal_keys(std::vector& items) { +void normalize_items(std::vector& items) { + double total_mass = 0.0; + for (const auto& item : items) { + total_mass += item.mass; + } + if (total_mass == 0.0) { + items.clear(); + return; + } + double inv = 1.0 / total_mass; + for (auto& item : items) { + item.mass *= inv; + } +} + +void merge_equal_keys_inplace(std::vector& items) { + if (items.empty()) { + return; + } + std::sort(items.begin(), items.end(), + [](const PackedMass& a, const PackedMass& b) { return a.key < b.key; }); + size_t out = 0; + for (size_t i = 1; i < items.size(); ++i) { + if (items[i].key == items[out].key) { + items[out].mass += items[i].mass; + } else { + ++out; + if (out != i) { + items[out] = std::move(items[i]); + } + } + } + items.resize(out + 1); +} + +void merge_equal_keys_inplace(std::vector& items) { if (items.empty()) { 
- return {}; + return; } - std::sort(items.begin(), items.end(), [](const PackedMass& a, const PackedMass& b) { - return a.key < b.key; + std::sort(items.begin(), items.end(), [](const WidePackedMass& a, const WidePackedMass& b) { + if (wide_state_less(a.state_words, b.state_words)) { + return true; + } + if (wide_state_less(b.state_words, a.state_words)) { + return false; + } + return a.obs_mask < b.obs_mask; }); - std::vector merged; - merged.reserve(items.size()); - uint64_t cur_key = items[0].key; - double cur_mass = items[0].mass; - double cur_penalty = items[0].penalty; + + size_t out = 0; for (size_t i = 1; i < items.size(); ++i) { - if (items[i].key == cur_key) { - cur_mass += items[i].mass; + if (items[i].obs_mask == items[out].obs_mask && + items[i].state_words == items[out].state_words) { + items[out].mass += items[i].mass; } else { - merged.push_back({cur_key, cur_mass, cur_penalty}); - cur_key = items[i].key; - cur_mass = items[i].mass; - cur_penalty = items[i].penalty; + ++out; + if (out != i) { + items[out] = std::move(items[i]); + } } } - merged.push_back({cur_key, cur_mass, cur_penalty}); - return merged; + items.resize(out + 1); } -std::vector accumulate_state_masses_from_entries(const std::vector& entries) { - std::vector totals; +std::vector accumulate_state_masses_from_entries( + const std::vector& entries) { + std::vector totals; if (entries.empty()) { return totals; } totals.reserve(entries.size()); - uint64_t cur_state = unpack_small_state(entries[0].key); - double cur_mass = entries[0].mass; - double cur_penalty = entries[0].penalty; + WideStateMass current{entries[0].state_words, entries[0].mass, entries[0].penalty}; for (size_t i = 1; i < entries.size(); ++i) { - uint64_t s = unpack_small_state(entries[i].key); - if (s == cur_state) { - cur_mass += entries[i].mass; + if (entries[i].state_words == current.state_words) { + current.mass += entries[i].mass; } else { - totals.push_back({cur_state, cur_mass, cur_penalty}); - cur_state = s; - 
cur_mass = entries[i].mass; - cur_penalty = entries[i].penalty; + totals.push_back(std::move(current)); + current = {entries[i].state_words, entries[i].mass, entries[i].penalty}; } } - totals.push_back({cur_state, cur_mass, cur_penalty}); + totals.push_back(std::move(current)); return totals; } @@ -457,7 +620,48 @@ double branch_score(const PackedMass& item, TesseractTrellisRankingMode ranking_ return std::log(item.mass) - item.penalty; } -double state_score(const StateMass& item, TesseractTrellisRankingMode ranking_mode) { +double branch_score(const WidePackedMass& item, TesseractTrellisRankingMode ranking_mode) { + if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { + return item.mass; + } + if (item.penalty == INF || item.mass == 0.0) { + return -INF; + } + return std::log(item.mass) - item.penalty; +} + +double state_score(const SmallStateGroup& item, TesseractTrellisRankingMode ranking_mode) { + if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { + return item.mass; + } + if (item.penalty == INF || item.mass == 0.0) { + return -INF; + } + return std::log(item.mass) - item.penalty; +} + +std::vector collect_small_state_groups(const std::vector& entries) { + std::vector groups; + if (entries.empty()) { + return groups; + } + groups.reserve(entries.size()); + size_t begin = 0; + while (begin < entries.size()) { + uint64_t state = unpack_small_state(entries[begin].key); + double mass = 0.0; + size_t end = begin; + while (end < entries.size() && unpack_small_state(entries[end].key) == state) { + mass += entries[end].mass; + ++end; + } + groups.push_back({state, mass, entries[begin].penalty, begin, end}); + begin = end; + } + return groups; +} + +double state_score(const WideStateMass& item, TesseractTrellisRankingMode ranking_mode) { if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { return item.mass; } @@ -472,34 +676,172 @@ void keep_top_states(std::vector& entries, size_t beam_width, if (entries.empty()) { return; } + auto groups = 
collect_small_state_groups(entries); + if (groups.size() <= beam_width) { + return; + } + + std::vector keep_indices(groups.size()); + std::iota(keep_indices.begin(), keep_indices.end(), 0); + std::nth_element(keep_indices.begin(), keep_indices.begin() + beam_width, keep_indices.end(), + [&groups, ranking_mode](size_t a, size_t b) { + return state_score(groups[a], ranking_mode) > + state_score(groups[b], ranking_mode); + }); + keep_indices.resize(beam_width); + std::sort(keep_indices.begin(), keep_indices.end(), + [&groups](size_t a, size_t b) { return groups[a].begin < groups[b].begin; }); + + std::vector kept; + size_t kept_entries = 0; + for (size_t idx : keep_indices) { + kept_entries += groups[idx].end - groups[idx].begin; + } + kept.reserve(kept_entries); + for (size_t idx : keep_indices) { + const auto& group = groups[idx]; + for (size_t k = group.begin; k < group.end; ++k) { + kept.push_back(entries[k]); + } + } + entries = std::move(kept); +} + +void keep_top_states(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { + if (entries.empty()) { + return; + } auto totals = accumulate_state_masses_from_entries(entries); if (totals.size() <= beam_width) { return; } std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(), - [ranking_mode](const StateMass& a, const StateMass& b) { + [ranking_mode](const WideStateMass& a, const WideStateMass& b) { return state_score(a, ranking_mode) > state_score(b, ranking_mode); }); totals.resize(beam_width); - std::sort(totals.begin(), totals.end(), [](const StateMass& a, const StateMass& b) { - return a.state < b.state; + std::sort(totals.begin(), totals.end(), [](const WideStateMass& a, const WideStateMass& b) { + return wide_state_less(a.state_words, b.state_words); }); - std::vector kept; + std::vector kept; kept.reserve(entries.size()); size_t ti = 0; - for (const auto& item : entries) { - uint64_t s = unpack_small_state(item.key); - while (ti < totals.size() && 
totals[ti].state < s) { + for (auto& item : entries) { + while (ti < totals.size() && wide_state_less(totals[ti].state_words, item.state_words)) { ++ti; } - if (ti < totals.size() && totals[ti].state == s) { - kept.push_back(item); + if (ti < totals.size() && item.state_words == totals[ti].state_words) { + kept.push_back(std::move(item)); } } entries = std::move(kept); } +void keep_best_state_representatives(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { + if (entries.empty()) { + return; + } + if (beam_width == 0) { + entries.clear(); + return; + } + + std::vector representative_indices; + representative_indices.reserve(entries.size()); + size_t begin = 0; + while (begin < entries.size()) { + uint64_t state = unpack_small_state(entries[begin].key); + size_t best = begin; + double best_score = branch_score(entries[begin], ranking_mode); + size_t end = begin + 1; + while (end < entries.size() && unpack_small_state(entries[end].key) == state) { + double score = branch_score(entries[end], ranking_mode); + if (score > best_score) { + best = end; + best_score = score; + } + ++end; + } + representative_indices.push_back(best); + begin = end; + } + + if (representative_indices.size() > beam_width) { + std::nth_element(representative_indices.begin(), representative_indices.begin() + beam_width, + representative_indices.end(), [&entries, ranking_mode](size_t a, size_t b) { + double sa = branch_score(entries[a], ranking_mode); + double sb = branch_score(entries[b], ranking_mode); + if (sa != sb) { + return sa > sb; + } + return a < b; + }); + representative_indices.resize(beam_width); + } + std::sort(representative_indices.begin(), representative_indices.end()); + + std::vector kept; + kept.reserve(representative_indices.size()); + for (size_t idx : representative_indices) { + kept.push_back(entries[idx]); + } + entries = std::move(kept); +} + +void keep_best_state_representatives(std::vector& entries, size_t beam_width, + 
TesseractTrellisRankingMode ranking_mode) { + if (entries.empty()) { + return; + } + if (beam_width == 0) { + entries.clear(); + return; + } + + std::vector representative_indices; + representative_indices.reserve(entries.size()); + size_t begin = 0; + while (begin < entries.size()) { + size_t best = begin; + double best_score = branch_score(entries[begin], ranking_mode); + size_t end = begin + 1; + while (end < entries.size() && entries[end].state_words == entries[begin].state_words) { + double score = branch_score(entries[end], ranking_mode); + if (score > best_score) { + best = end; + best_score = score; + } + ++end; + } + representative_indices.push_back(best); + begin = end; + } + + if (representative_indices.size() > beam_width) { + std::nth_element(representative_indices.begin(), representative_indices.begin() + beam_width, + representative_indices.end(), [&entries, ranking_mode](size_t a, size_t b) { + double sa = branch_score(entries[a], ranking_mode); + double sb = branch_score(entries[b], ranking_mode); + if (sa != sb) { + return sa > sb; + } + return a < b; + }); + representative_indices.resize(beam_width); + } + std::sort(representative_indices.begin(), representative_indices.end()); + + std::vector kept; + kept.reserve(representative_indices.size()); + for (size_t idx : representative_indices) { + kept.push_back(std::move(entries[idx])); + } + entries = std::move(kept); +} + void keep_top_branch_entries(std::vector& entries, size_t beam_width, TesseractTrellisRankingMode ranking_mode) { if (entries.size() <= beam_width) { @@ -512,6 +854,42 @@ void keep_top_branch_entries(std::vector& entries, size_t beam_width entries.resize(beam_width); } +void keep_top_branch_entries(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { + if (entries.size() <= beam_width) { + return; + } + std::nth_element(entries.begin(), entries.begin() + beam_width, entries.end(), + [ranking_mode](const WidePackedMass& a, const WidePackedMass& b) 
{ + return branch_score(a, ranking_mode) > branch_score(b, ranking_mode); + }); + entries.resize(beam_width); +} + +void prepare_projected_fault_masks(std::vector* layers) { + for (auto& layer : *layers) { + layer.projected_fault_mask = 0; + for (int32_t next_local : layer.detcost_transition.next_local_indices) { + if (next_local >= 0) { + layer.projected_fault_mask ^= uint64_t{1} << next_local; + } + } + } +} + +void prepare_projected_fault_masks(std::vector* layers) { + for (auto& layer : *layers) { + layer.projected_fault_mask_words.assign(num_state_words(layer.surviving_local_indices.size()), + 0); + for (int32_t next_local : layer.detcost_transition.next_local_indices) { + if (next_local >= 0) { + size_t local = (size_t)next_local; + layer.projected_fault_mask_words[local >> 6] ^= uint64_t{1} << (local & 63); + } + } + } +} + } // namespace TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) @@ -532,16 +910,25 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) } auto faults = parse_faults(errors, num_observables); + + size_t wide_frontier_width = 0; + build_wide_layer_templates(faults, num_detectors, &wide_layer_templates, &wide_frontier_width); + build_future_detcost_transitions(faults, num_detectors, &wide_layer_templates, + &initial_future_detcost); + prepare_projected_fault_masks(&wide_layer_templates); + size_t small_frontier_width = 0; has_small_layer_templates = num_observables <= 1 && - build_small_layer_templates(faults, num_detectors, &small_layer_templates, &small_frontier_width); + build_small_layer_templates(faults, num_detectors, &small_layer_templates, + &small_frontier_width); if (has_small_layer_templates) { - build_future_detcost_transitions(faults, num_detectors, &small_layer_templates, - &initial_future_detcost); - } else if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { - throw std::invalid_argument( - "future-detcost ranking is currently implemented 
only for the packed small trellis path"); + build_future_detcost_transitions(faults, num_detectors, &small_layer_templates, nullptr); + prepare_projected_fault_masks(&small_layer_templates); + build_small_detector_layer_refs(small_layer_templates, num_detectors, + &small_detector_layer_refs); + scratch_small_current_target_bits.assign(small_layer_templates.size(), 0); + scratch_small_expected_retiring_bits.assign(small_layer_templates.size(), 0); } } @@ -559,86 +946,96 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection total_mass_obs0 = 0; total_mass_obs1 = 0; - boost::dynamic_bitset<> actual_dets(num_detectors); - for (uint64_t d : detections) { - if (d >= num_detectors || !all_possible_detectors[d]) { - low_confidence_flag = true; - return; + if (has_small_layer_templates) { + std::fill(scratch_small_current_target_bits.begin(), scratch_small_current_target_bits.end(), + 0); + std::fill(scratch_small_expected_retiring_bits.begin(), + scratch_small_expected_retiring_bits.end(), 0); + + for (uint64_t d : detections) { + if (d >= num_detectors || !all_possible_detectors[d]) { + low_confidence_flag = true; + return; + } + for (const auto& ref : small_detector_layer_refs[(size_t)d]) { + scratch_small_current_target_bits[ref.layer_index] ^= uint64_t{1} << ref.local_index; + } } - actual_dets.flip((size_t)d); - } - if (has_small_layer_templates) { - max_frontier_width_seen = 0; - std::vector current_target_bits_per_layer(small_layer_templates.size()); - std::vector next_target_bits_per_layer(small_layer_templates.size()); - std::vector expected_retiring_bits_per_layer(small_layer_templates.size()); for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { const auto& layer = small_layer_templates[layer_index]; - max_frontier_width_seen = std::max(max_frontier_width_seen, layer.current_active_detectors.size()); - uint64_t current_target_bits = compute_target_bits(layer.current_active_detectors, actual_dets); - 
current_target_bits_per_layer[layer_index] = current_target_bits; - expected_retiring_bits_per_layer[layer_index] = current_target_bits & layer.retiring_mask; - - uint64_t next_target_bits = 0; - for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { - uint8_t current_local = layer.surviving_local_indices[next_local]; - next_target_bits |= ((current_target_bits >> current_local) & 1ULL) << next_local; - } - next_target_bits_per_layer[layer_index] = next_target_bits; + max_frontier_width_seen = + std::max(max_frontier_width_seen, layer.current_active_detectors.size()); + scratch_small_expected_retiring_bits[layer_index] = + scratch_small_current_target_bits[layer_index] & layer.retiring_mask; } - std::vector beam_entries; double initial_penalty = 0.0; - if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { - std::vector initial_frontier_costs; - if (!small_layer_templates.empty()) { - const auto& first_layer = small_layer_templates.front(); - initial_frontier_costs.resize(first_layer.current_active_detectors.size(), INF); - for (size_t local = 0; local < first_layer.current_active_detectors.size(); ++local) { - initial_frontier_costs[local] = - initial_future_detcost[(size_t)first_layer.current_active_detectors[local]]; - } - } - initial_penalty = compute_penalty_from_scratch( - current_target_bits_per_layer.empty() ? 
0 : current_target_bits_per_layer.front(), - initial_frontier_costs); + if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && + !small_layer_templates.empty()) { + initial_penalty = compute_initial_penalty_for_target_bits( + scratch_small_current_target_bits.front(), + small_layer_templates.front().current_active_detectors, initial_future_detcost); } + + std::vector beam_entries; + std::vector next_entries; + beam_entries.reserve(config.beam_width * 2 + 2); + next_entries.reserve(config.beam_width * 4 + 4); beam_entries.push_back({pack_small_key(0, 0), 1.0, initial_penalty}); max_beam_size_seen = 1; + const bool compute_penalties = + config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked; for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { const auto& layer = small_layer_templates[layer_index]; - const uint64_t next_target_bits = next_target_bits_per_layer[layer_index]; - const uint64_t expected_retiring_bits = expected_retiring_bits_per_layer[layer_index]; + const uint64_t current_target_bits = scratch_small_current_target_bits[layer_index]; + const uint64_t expected_retiring_bits = scratch_small_expected_retiring_bits[layer_index]; + auto t0 = std::chrono::high_resolution_clock::now(); - std::vector next_entries; + next_entries.clear(); next_entries.reserve(beam_entries.size() * 2); for (const auto& item : beam_entries) { ++num_states_expanded; const uint64_t base_state = unpack_small_state(item.key); const uint64_t base_obs = unpack_small_obs(item.key); - if (((base_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) { - uint64_t projected_state = project_small_state(base_state, layer.surviving_local_indices); - next_entries.push_back({pack_small_key(projected_state, base_obs), item.mass * layer.q, 0.0}); + BranchPenaltyUpdate update; + if (compute_penalties) { + update = compute_small_branch_update(base_state, layer.previous_width, item.penalty, + current_target_bits, 
layer.detcost_transition, true); + } else { + update.absent_valid = + (((base_state ^ expected_retiring_bits) & layer.retiring_mask) == 0); + uint64_t toggled_state = base_state ^ layer.local_det_mask; + update.present_valid = + (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask) == 0); } - uint64_t toggled_state = base_state ^ layer.local_det_mask; - if (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask) == 0) { - uint64_t projected_state = project_small_state(toggled_state, layer.surviving_local_indices); - next_entries.push_back( - {pack_small_key(projected_state, base_obs ^ layer.obs_flip_bit), item.mass * layer.p, 0.0}); + if (!update.absent_valid && !update.present_valid) { + continue; + } + + uint64_t projected_state = project_small_state(base_state, layer.surviving_mask); + if (update.absent_valid && layer.q != 0.0) { + next_entries.push_back({pack_small_key(projected_state, base_obs), item.mass * layer.q, + update.absent_penalty}); + } + if (update.present_valid && layer.p != 0.0) { + next_entries.push_back({pack_small_key(projected_state ^ layer.projected_fault_mask, + base_obs ^ layer.obs_flip_bit), + item.mass * layer.p, update.present_penalty}); } } auto t1 = std::chrono::high_resolution_clock::now(); - time_collapse_seconds += + time_expand_seconds += std::chrono::duration_cast(t1 - t0).count() / 1e6; - beam_entries = std::move(next_entries); + beam_entries.swap(next_entries); bool at_checkpoint = ((layer_index + 1) % config.merge_interval == 0) || (layer_index + 1 == small_layer_templates.size()); if (!at_checkpoint) { + normalize_items(beam_entries); max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); if (beam_entries.empty()) { low_confidence_flag = true; @@ -649,21 +1046,16 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection auto t2a = std::chrono::high_resolution_clock::now(); if (config.prune_mode != TesseractTrellisPruneMode::NoMerge) { - beam_entries = 
merge_equal_keys(beam_entries); + merge_equal_keys_inplace(beam_entries); } auto t2 = std::chrono::high_resolution_clock::now(); time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; - if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked) { - for (auto& item : beam_entries) { - item.penalty = compute_penalty_from_scratch(unpack_small_state(item.key) ^ next_target_bits, - layer.next_frontier_costs); - } - } - if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { keep_top_states(beam_entries, config.beam_width, config.ranking_mode); + } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { + keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || config.prune_mode == TesseractTrellisPruneMode::NoMerge) { keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); @@ -673,13 +1065,16 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection low_confidence_flag = true; return; } - if (config.prune_mode == TesseractTrellisPruneMode::NoMerge) { + if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { + auto post_groups = collect_small_state_groups(beam_entries); + num_states_merged += post_groups.size(); + max_beam_size_seen = std::max(max_beam_size_seen, post_groups.size()); + } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { num_states_merged += beam_entries.size(); max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); } else { - auto post_totals = accumulate_state_masses_from_entries(beam_entries); - num_states_merged += post_totals.size(); - max_beam_size_seen = std::max(max_beam_size_seen, post_totals.size()); + num_states_merged += beam_entries.size(); + max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); } auto t3 = std::chrono::high_resolution_clock::now(); time_truncate_seconds 
+= @@ -707,98 +1102,132 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection time_reconstruct_seconds += std::chrono::duration_cast(tr1 - tr0).count() / 1e6; } else { - auto faults = parse_faults(errors, num_observables); - auto layers = build_layer_faults(faults, num_detectors, detections, &max_frontier_width_seen); - std::unordered_map, FrontierAggregate> beam; - FrontierAggregate init; - add_obs_mass(init, 0, 0.0); - beam.emplace(boost::dynamic_bitset<>(0), std::move(init)); + boost::dynamic_bitset<> actual_dets(num_detectors); + for (uint64_t d : detections) { + if (d >= num_detectors || !all_possible_detectors[d]) { + low_confidence_flag = true; + return; + } + actual_dets.flip((size_t)d); + } + max_frontier_width_seen = 0; + for (const auto& layer : wide_layer_templates) { + max_frontier_width_seen = + std::max(max_frontier_width_seen, layer.current_active_detectors.size()); + } + + double initial_penalty = 0.0; + if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && + !wide_layer_templates.empty()) { + initial_penalty = compute_initial_penalty_for_active_detectors( + wide_layer_templates.front().current_active_detectors, actual_dets, + initial_future_detcost); + } + + std::vector beam_entries; + std::vector next_entries; + beam_entries.reserve(config.beam_width * 2 + 2); + next_entries.reserve(config.beam_width * 4 + 4); + beam_entries.push_back({{}, 0, 1.0, initial_penalty}); max_beam_size_seen = 1; - for (const auto& layer : layers) { + for (size_t layer_index = 0; layer_index < wide_layer_templates.size(); ++layer_index) { + const auto& layer = wide_layer_templates[layer_index]; + const bool compute_penalties = + config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked; + auto t0 = std::chrono::high_resolution_clock::now(); - std::unordered_map, FrontierAggregate> expanded; - expanded.reserve(beam.size() * 2 + 1); + next_entries.clear(); + next_entries.reserve(beam_entries.size() * 2); - for 
(const auto& [state, aggregate] : beam) { + for (const auto& item : beam_entries) { ++num_states_expanded; - boost::dynamic_bitset<> base_state = state; - base_state.resize(layer.local_det_mask.size()); - for (const auto& [obs_mask, obs] : aggregate.obs_entries) { - auto& absent_bucket = expanded[base_state]; - add_obs_mass(absent_bucket, obs_mask, obs.log_mass + layer.log_q); - - boost::dynamic_bitset<> present_state = base_state ^ layer.local_det_mask; - auto& present_bucket = expanded[present_state]; - add_obs_mass(present_bucket, obs_mask ^ layer.obs_mask, obs.log_mass + layer.log_p); + BranchPenaltyUpdate update = compute_wide_branch_update( + item.state_words, layer.previous_width, item.penalty, layer.current_active_detectors, + actual_dets, layer.detcost_transition, compute_penalties); + + if (!update.absent_valid && !update.present_valid) { + continue; + } + + std::vector projected_state = project_wide_state( + item.state_words, layer.previous_width, layer.surviving_local_indices); + if (update.absent_valid && layer.q != 0.0) { + next_entries.push_back( + {projected_state, item.obs_mask, item.mass * layer.q, update.absent_penalty}); + } + if (update.present_valid && layer.p != 0.0) { + std::vector projected_toggled = projected_state; + xor_state_words(projected_toggled, layer.projected_fault_mask_words); + next_entries.push_back({std::move(projected_toggled), item.obs_mask ^ layer.obs_mask, + item.mass * layer.p, update.present_penalty}); } } auto t1 = std::chrono::high_resolution_clock::now(); time_expand_seconds += std::chrono::duration_cast(t1 - t0).count() / 1e6; - std::unordered_map, FrontierAggregate> collapsed; - collapsed.reserve(expanded.size()); - for (auto& [state, aggregate] : expanded) { - if (((state & layer.retiring_mask) ^ layer.expected_retiring_bits).any()) { - continue; - } - boost::dynamic_bitset<> projected = project_state(state, layer.surviving_local_indices); - auto& out = collapsed[projected]; - for (const auto& [obs_mask, obs] : 
aggregate.obs_entries) { - add_obs_mass(out, obs_mask, obs.log_mass); + beam_entries.swap(next_entries); + bool at_checkpoint = ((layer_index + 1) % config.merge_interval == 0) || + (layer_index + 1 == wide_layer_templates.size()); + if (!at_checkpoint) { + normalize_items(beam_entries); + max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); + if (beam_entries.empty()) { + low_confidence_flag = true; + return; } + continue; + } + + auto t2a = std::chrono::high_resolution_clock::now(); + if (config.prune_mode != TesseractTrellisPruneMode::NoMerge) { + merge_equal_keys_inplace(beam_entries); } auto t2 = std::chrono::high_resolution_clock::now(); time_collapse_seconds += - std::chrono::duration_cast(t2 - t1).count() / 1e6; + std::chrono::duration_cast(t2 - t2a).count() / 1e6; - num_states_merged += collapsed.size(); - if (collapsed.empty()) { + if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { + keep_top_states(beam_entries, config.beam_width, config.ranking_mode); + } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { + keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); + } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || + config.prune_mode == TesseractTrellisPruneMode::NoMerge) { + keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); + } + normalize_items(beam_entries); + if (beam_entries.empty()) { low_confidence_flag = true; return; } - - std::vector, FrontierAggregate>> next_beam; - next_beam.reserve(collapsed.size()); - for (auto& item : collapsed) { - next_beam.push_back(std::move(item)); - } - if (next_beam.size() > config.beam_width) { - std::nth_element(next_beam.begin(), next_beam.begin() + config.beam_width, next_beam.end(), - [](const auto& a, const auto& b) { - return a.second.total_log_mass > b.second.total_log_mass; - }); - next_beam.resize(config.beam_width); - } - max_beam_size_seen = 
std::max(max_beam_size_seen, next_beam.size()); - - beam.clear(); - beam.reserve(next_beam.size()); - for (auto& item : next_beam) { - beam.emplace(std::move(item.first), std::move(item.second)); + if (config.prune_mode == TesseractTrellisPruneMode::NoMerge) { + num_states_merged += beam_entries.size(); + max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); + } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { + num_states_merged += beam_entries.size(); + max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); + } else { + auto post_totals = accumulate_state_masses_from_entries(beam_entries); + num_states_merged += post_totals.size(); + max_beam_size_seen = std::max(max_beam_size_seen, post_totals.size()); } auto t3 = std::chrono::high_resolution_clock::now(); time_truncate_seconds += std::chrono::duration_cast(t3 - t2).count() / 1e6; } - auto it = beam.find(boost::dynamic_bitset<>(0)); - if (it == beam.end() || it->second.obs_entries.empty()) { - low_confidence_flag = true; - return; - } - - const auto& final_entry = it->second; auto tr0 = std::chrono::high_resolution_clock::now(); - if (final_entry.obs_entries.empty()) { - low_confidence_flag = true; - return; + for (const auto& item : beam_entries) { + if (!wide_state_zero(item.state_words)) { + continue; + } + if (item.obs_mask == 0) { + total_mass_obs0 += item.mass; + } else if (item.obs_mask == 1) { + total_mass_obs1 += item.mass; + } } - auto it0 = final_entry.obs_entries.find(0); - auto it1 = final_entry.obs_entries.find(1); - total_mass_obs0 = it0 == final_entry.obs_entries.end() ? 0.0 : std::exp(it0->second.log_mass); - total_mass_obs1 = it1 == final_entry.obs_entries.end() ? 
0.0 : std::exp(it1->second.log_mass); if (total_mass_obs0 == 0.0 && total_mass_obs1 == 0.0) { low_confidence_flag = true; return; @@ -813,8 +1242,8 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection std::cout << "trellis beam_width=" << config.beam_width << " frontier_width=" << max_frontier_width_seen << " states_expanded=" << num_states_expanded - << " states_merged=" << num_states_merged - << " max_beam=" << max_beam_size_seen << std::endl; + << " states_merged=" << num_states_merged << " max_beam=" << max_beam_size_seen + << std::endl; } } diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h index 7715b98..897d489 100644 --- a/src/tesseract_trellis.h +++ b/src/tesseract_trellis.h @@ -16,6 +16,7 @@ #define TESSERACT_TRELLIS_DECODER_H #include +#include #include #include "common.h" @@ -23,6 +24,7 @@ enum class TesseractTrellisPruneMode { MergedStates, + KeepBest, BranchEntries, NoMerge, }; @@ -33,8 +35,8 @@ enum class TesseractTrellisRankingMode { }; struct TesseractTrellisDetcostTransition { - std::vector fault_local_indices; - std::vector next_local_indices; + std::vector fault_local_indices; + std::vector next_local_indices; std::vector current_costs; std::vector next_costs; }; @@ -45,6 +47,8 @@ struct TesseractTrellisSmallLayerTemplate { uint64_t obs_flip_bit = 0; uint64_t local_det_mask = 0; uint64_t retiring_mask = 0; + uint64_t surviving_mask = 0; + uint64_t projected_fault_mask = 0; size_t previous_width = 0; std::vector surviving_local_indices; std::vector current_active_detectors; @@ -52,6 +56,23 @@ struct TesseractTrellisSmallLayerTemplate { TesseractTrellisDetcostTransition detcost_transition; }; +struct TesseractTrellisSmallDetectorLayerRef { + uint32_t layer_index = 0; + uint8_t local_index = 0; +}; + +struct TesseractTrellisWideLayerTemplate { + double q = 0; + double p = 0; + uint64_t obs_mask = 0; + size_t previous_width = 0; + std::vector surviving_local_indices; + std::vector current_active_detectors; + 
std::vector projected_fault_mask_words; + std::vector next_frontier_costs; + TesseractTrellisDetcostTransition detcost_transition; +}; + struct TesseractTrellisConfig { stim::DetectorErrorModel dem; size_t beam_width = 1024; @@ -91,6 +112,10 @@ struct TesseractTrellisDecoder { boost::dynamic_bitset<> all_possible_detectors; bool has_small_layer_templates = false; std::vector small_layer_templates; + std::vector> small_detector_layer_refs; + std::vector scratch_small_current_target_bits; + std::vector scratch_small_expected_retiring_bits; + std::vector wide_layer_templates; std::vector initial_future_detcost; }; diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc index d4362aa..22c44e1 100644 --- a/src/tesseract_trellis_main.cc +++ b/src/tesseract_trellis_main.cc @@ -28,6 +28,7 @@ namespace { TesseractTrellisPruneMode parse_prune_mode(const std::string& value) { if (value == "merged") return TesseractTrellisPruneMode::MergedStates; + if (value == "keep-best") return TesseractTrellisPruneMode::KeepBest; if (value == "branch") return TesseractTrellisPruneMode::BranchEntries; if (value == "none") return TesseractTrellisPruneMode::NoMerge; throw std::invalid_argument("Unknown trellis prune mode: " + value); @@ -253,7 +254,9 @@ int main(int argc, char* argv[]) { program.add_argument("--sample-seed") .default_value(static_cast(std::random_device()())) .store_into(args.sample_seed); - program.add_argument("--shot-range-begin").default_value(size_t(0)).store_into(args.shot_range_begin); + program.add_argument("--shot-range-begin") + .default_value(size_t(0)) + .store_into(args.shot_range_begin); program.add_argument("--shot-range-end").default_value(size_t(0)).store_into(args.shot_range_end); program.add_argument("--in").default_value(std::string("")).store_into(args.in_fname); program.add_argument("--in-format", "--in_format") @@ -263,14 +266,18 @@ int main(int argc, char* argv[]) { .default_value(false) .store_into(args.append_observables) .flag(); 
- program.add_argument("--obs_in", "--obs-in").default_value(std::string("")).store_into(args.obs_in_fname); + program.add_argument("--obs_in", "--obs-in") + .default_value(std::string("")) + .store_into(args.obs_in_fname); program.add_argument("--obs-in-format", "--obs_in_format") .default_value(std::string("")) .store_into(args.obs_in_format); program.add_argument("--out").default_value(std::string("")).store_into(args.out_fname); program.add_argument("--out-format").default_value(std::string("")).store_into(args.out_format); program.add_argument("--dem-out").default_value(std::string("")).store_into(args.dem_out_fname); - program.add_argument("--stats-out").default_value(std::string("")).store_into(args.stats_out_fname); + program.add_argument("--stats-out") + .default_value(std::string("")) + .store_into(args.stats_out_fname); program.add_argument("--threads") .default_value(size_t( std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) @@ -278,7 +285,15 @@ int main(int argc, char* argv[]) { program.add_argument("--beam").default_value(size_t(1024)).store_into(args.beam_width); program.add_argument("--merge-interval").default_value(size_t(1)).store_into(args.merge_interval); program.add_argument("--prune-mode") - .help("Trellis pruning mode: merged, branch, or none") + .help( + "Trellis pruning mode: merged, keep-best, branch, or none. " + "merged sums probabilities of all branches with the same residual detection events. " + "keep-best keeps only the single highest-probability branch for each residual detection " + "state. " + "branch ranks branches individually, but still merges exact duplicate (state, " + "observable) entries first. 
" + "none skips even that exact-duplicate merge, so identical branches may occupy multiple " + "beam slots.") .default_value(std::string("merged")) .store_into(args.prune_mode); program.add_argument("--ranking-mode") @@ -370,11 +385,9 @@ int main(int argc, char* argv[]) { << " max_beam = " << max_beam_size_per_shot[shot_index] << " frontier_width = " << max_frontier_width_per_shot[shot_index] << " total_time_seconds = " << total_time_seconds << std::endl; - std::cout << "branch_masses" - << " obs0=" << mass0_predicted[shot_index] + std::cout << "branch_masses" << " obs0=" << mass0_predicted[shot_index] << " obs1=" << mass1_predicted[shot_index] << std::endl; - std::cout << "phase_times_seconds" - << " expand=" << time_expand_per_shot[shot_index] + std::cout << "phase_times_seconds" << " expand=" << time_expand_per_shot[shot_index] << " collapse=" << time_collapse_per_shot[shot_index] << " truncate=" << time_truncate_per_shot[shot_index] << " reconstruct=" << time_reconstruct_per_shot[shot_index] << std::endl; @@ -384,7 +397,8 @@ int main(int argc, char* argv[]) { }); if (!args.dem_out_fname.empty()) { - throw std::invalid_argument("--dem-out is not supported by tesseract_trellis without path reconstruction."); + throw std::invalid_argument( + "--dem-out is not supported by tesseract_trellis without path reconstruction."); } bool print_final_stats = true; From dc4a51e4f19c56a201e5f0910e8e7530835fc415 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sat, 18 Apr 2026 13:23:23 -0700 Subject: [PATCH 17/25] add small optimization based on radix sort --- src/tesseract_trellis.cc | 141 +++++++++++++++++++--------------- src/tesseract_trellis_main.cc | 7 +- 2 files changed, 81 insertions(+), 67 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 3737a0d..76eb3e2 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -15,6 +15,7 @@ #include "tesseract_trellis.h" #include +#include #include #include #include @@ -49,7 +50,7 @@ 
struct PackedMass { struct SmallStateGroup { uint64_t state; double mass; - double penalty; + double score; size_t begin; size_t end; }; @@ -541,12 +542,44 @@ void normalize_items(std::vector& items) { } } +void radix_sort_packed_masses_by_key(std::vector& items) { + if (items.size() <= 1) { + return; + } + + thread_local std::vector buffer; + buffer.resize(items.size()); + + PackedMass* src = items.data(); + PackedMass* dst = buffer.data(); + constexpr size_t RADIX = 256; + std::array counts; + + for (size_t shift = 0; shift < 64; shift += 8) { + counts.fill(0); + for (size_t k = 0; k < items.size(); ++k) { + ++counts[(src[k].key >> shift) & 0xFF]; + } + + size_t total = 0; + for (size_t k = 0; k < RADIX; ++k) { + size_t count = counts[k]; + counts[k] = total; + total += count; + } + + for (size_t k = 0; k < items.size(); ++k) { + dst[counts[(src[k].key >> shift) & 0xFF]++] = src[k]; + } + std::swap(src, dst); + } +} + void merge_equal_keys_inplace(std::vector& items) { if (items.empty()) { return; } - std::sort(items.begin(), items.end(), - [](const PackedMass& a, const PackedMass& b) { return a.key < b.key; }); + radix_sort_packed_masses_by_key(items); size_t out = 0; for (size_t i = 1; i < items.size(); ++i) { if (items[i].key == items[out].key) { @@ -610,37 +643,27 @@ std::vector accumulate_state_masses_from_entries( return totals; } -double branch_score(const PackedMass& item, TesseractTrellisRankingMode ranking_mode) { +double score_mass_and_penalty(double mass, double penalty, + TesseractTrellisRankingMode ranking_mode) { if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { - return item.mass; + return mass; } - if (item.penalty == INF || item.mass == 0.0) { + if (penalty == INF || mass == 0.0) { return -INF; } - return std::log(item.mass) - item.penalty; + return std::log(mass) - penalty; } -double branch_score(const WidePackedMass& item, TesseractTrellisRankingMode ranking_mode) { - if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { - 
return item.mass; - } - if (item.penalty == INF || item.mass == 0.0) { - return -INF; - } - return std::log(item.mass) - item.penalty; +double branch_score(const PackedMass& item, TesseractTrellisRankingMode ranking_mode) { + return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); } -double state_score(const SmallStateGroup& item, TesseractTrellisRankingMode ranking_mode) { - if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { - return item.mass; - } - if (item.penalty == INF || item.mass == 0.0) { - return -INF; - } - return std::log(item.mass) - item.penalty; +double branch_score(const WidePackedMass& item, TesseractTrellisRankingMode ranking_mode) { + return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); } -std::vector collect_small_state_groups(const std::vector& entries) { +std::vector collect_small_state_groups( + const std::vector& entries, TesseractTrellisRankingMode ranking_mode) { std::vector groups; if (entries.empty()) { return groups; @@ -655,66 +678,57 @@ std::vector collect_small_state_groups(const std::vector& entries, size_t beam_width, - TesseractTrellisRankingMode ranking_mode) { +size_t keep_top_states(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { if (entries.empty()) { - return; + return 0; } - auto groups = collect_small_state_groups(entries); + auto groups = collect_small_state_groups(entries, ranking_mode); if (groups.size() <= beam_width) { - return; + return groups.size(); } - std::vector keep_indices(groups.size()); - std::iota(keep_indices.begin(), keep_indices.end(), 0); - std::nth_element(keep_indices.begin(), keep_indices.begin() + beam_width, keep_indices.end(), - [&groups, ranking_mode](size_t a, size_t b) { - return state_score(groups[a], ranking_mode) > - state_score(groups[b], ranking_mode); + std::nth_element(groups.begin(), groups.begin() + beam_width, groups.end(), + [](const SmallStateGroup& a, const SmallStateGroup& b) { + return a.score > 
b.score; }); - keep_indices.resize(beam_width); - std::sort(keep_indices.begin(), keep_indices.end(), - [&groups](size_t a, size_t b) { return groups[a].begin < groups[b].begin; }); + groups.resize(beam_width); std::vector kept; size_t kept_entries = 0; - for (size_t idx : keep_indices) { - kept_entries += groups[idx].end - groups[idx].begin; + for (const auto& group : groups) { + kept_entries += group.end - group.begin; } kept.reserve(kept_entries); - for (size_t idx : keep_indices) { - const auto& group = groups[idx]; + for (const auto& group : groups) { for (size_t k = group.begin; k < group.end; ++k) { - kept.push_back(entries[k]); + kept.push_back(std::move(entries[k])); } } entries = std::move(kept); + return groups.size(); } -void keep_top_states(std::vector& entries, size_t beam_width, - TesseractTrellisRankingMode ranking_mode) { +size_t keep_top_states(std::vector& entries, size_t beam_width, + TesseractTrellisRankingMode ranking_mode) { if (entries.empty()) { - return; + return 0; } auto totals = accumulate_state_masses_from_entries(entries); if (totals.size() <= beam_width) { - return; + return totals.size(); } std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(), [ranking_mode](const WideStateMass& a, const WideStateMass& b) { @@ -737,6 +751,7 @@ void keep_top_states(std::vector& entries, size_t beam_width, } } entries = std::move(kept); + return totals.size(); } void keep_best_state_representatives(std::vector& entries, size_t beam_width, @@ -1052,8 +1067,9 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; + size_t kept_states = 0; if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - keep_top_states(beam_entries, config.beam_width, config.ranking_mode); + kept_states = keep_top_states(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { 
keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || @@ -1066,9 +1082,8 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection return; } if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - auto post_groups = collect_small_state_groups(beam_entries); - num_states_merged += post_groups.size(); - max_beam_size_seen = std::max(max_beam_size_seen, post_groups.size()); + num_states_merged += kept_states; + max_beam_size_seen = std::max(max_beam_size_seen, kept_states); } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { num_states_merged += beam_entries.size(); max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); @@ -1188,8 +1203,9 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; + size_t kept_states = 0; if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - keep_top_states(beam_entries, config.beam_width, config.ranking_mode); + kept_states = keep_top_states(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || @@ -1208,9 +1224,8 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection num_states_merged += beam_entries.size(); max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); } else { - auto post_totals = accumulate_state_masses_from_entries(beam_entries); - num_states_merged += post_totals.size(); - max_beam_size_seen = std::max(max_beam_size_seen, post_totals.size()); + num_states_merged += kept_states; + max_beam_size_seen = std::max(max_beam_size_seen, kept_states); } auto t3 = 
std::chrono::high_resolution_clock::now(); time_truncate_seconds += diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc index 22c44e1..f7febfb 100644 --- a/src/tesseract_trellis_main.cc +++ b/src/tesseract_trellis_main.cc @@ -384,14 +384,13 @@ int main(int argc, char* argv[]) { << " states_merged = " << num_states_merged_per_shot[shot_index] << " max_beam = " << max_beam_size_per_shot[shot_index] << " frontier_width = " << max_frontier_width_per_shot[shot_index] - << " total_time_seconds = " << total_time_seconds << std::endl; + << " total_time_seconds = " << total_time_seconds << '\n'; std::cout << "branch_masses" << " obs0=" << mass0_predicted[shot_index] - << " obs1=" << mass1_predicted[shot_index] << std::endl; + << " obs1=" << mass1_predicted[shot_index] << '\n'; std::cout << "phase_times_seconds" << " expand=" << time_expand_per_shot[shot_index] << " collapse=" << time_collapse_per_shot[shot_index] << " truncate=" << time_truncate_per_shot[shot_index] - << " reconstruct=" << time_reconstruct_per_shot[shot_index] << std::endl; - std::cout.flush(); + << " reconstruct=" << time_reconstruct_per_shot[shot_index] << '\n'; } return num_errors < args.max_errors; }); From 4af03328a1b09ddd5e7fe8e6b4afe90f351bc978 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sat, 18 Apr 2026 17:33:30 -0700 Subject: [PATCH 18/25] add python prototype of the --beam-eps to trellis and C++ cli flag for it --- .../trellis_beam_detcost_ranked_threshold.py | 1100 +++++++++++++++++ src/tesseract_trellis.cc | 208 +++- src/tesseract_trellis.h | 8 + src/tesseract_trellis_main.cc | 31 + 4 files changed, 1327 insertions(+), 20 deletions(-) create mode 100644 src/py/astar/trellis_beam_detcost_ranked_threshold.py diff --git a/src/py/astar/trellis_beam_detcost_ranked_threshold.py b/src/py/astar/trellis_beam_detcost_ranked_threshold.py new file mode 100644 index 0000000..9e2afc4 --- /dev/null +++ b/src/py/astar/trellis_beam_detcost_ranked_threshold.py @@ -0,0 +1,1100 @@ 
#!/usr/bin/env python3
"""Beam-search trellis decoder prototype with detcost-ranked, thresholded pruning.

Sweeps a trellis over the faults of a stim detector error model, keeping a
pruned beam of residual-detector bitmask states.  Beam pruning supports a
fixed width, a score-delta threshold, and a retained-mass threshold (see
``BEAM_PRUNE_MODES``).
"""
from __future__ import annotations

import argparse
import math
import shutil
import sys
import tempfile
import time
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import stim


# Result-file formats accepted/produced via stim's shot-data readers/writers.
STIM_RESULT_FORMATS = ("01", "b8", "r8", "ptb64", "hits", "dets")
STIM_RESULT_FORMATS_HELP = "/".join(STIM_RESULT_FORMATS)
# Supported beam-pruning strategies (see BeamPruningConfig).
BEAM_PRUNE_MODES = ("fixed", "delta", "mass")
BEAM_PRUNE_MODES_HELP = "/".join(BEAM_PRUNE_MODES)


@dataclass(frozen=True)
class Fault:
    """One DEM error mechanism, precomputed for the trellis sweep."""

    q: float  # 1 - p: probability the fault does not fire.
    p: float  # Probability the fault fires.
    delta_scale: float  # +p if the fault leaves L0 alone, -p if it flips L0.
    det_mask: int  # Bitmask of detectors toggled when the fault fires.
    likelihood_cost: float  # -log(p / (1 - p)); inf when p == 0.
    

@dataclass
class SearchState:
    errs: np.ndarray
    blocked_errs: np.ndarray
    dets: np.ndarray
    det_counts: np.ndarray
    g_cost: float


@dataclass(frozen=True)
class DecoderModel:
    """Immutable trellis model derived from a circuit's detector error model."""

    faults: tuple[Fault, ...]  # Faults in DEM order.
    retiring_masks: tuple[int, ...]  # Per fault: detectors last touched at that fault.
    live_masks_after: tuple[int, ...]  # Detectors still mutable after fault i.
    future_detcost: tuple[tuple[float, ...], ...]  # Suffix min cost per detector.
    all_possible_dets_mask: int  # Union of every fault's det_mask.
    max_width: int  # Peak popcount of simultaneously-live detectors.


@dataclass(frozen=True)
class BeamPruningConfig:
    """How the beam is truncated after each trellis step."""

    mode: str  # One of BEAM_PRUNE_MODES.
    beam_width: int  # Used by "fixed" mode.
    score_delta: float | None  # Used by "delta" mode.
    mass_epsilon: float | None  # Used by "mass" mode (fraction of mass to drop).
    hard_cap: int | None  # Optional absolute cap for "delta"/"mass" modes.


@dataclass(frozen=True)
class BeamDecodeResult:
    """Outcome of decoding a single shot."""

    predicted_logical: bool | None  # None when the decoder cannot decide.
    certified: bool  # True when margin exceeds all discarded mass.
    margin: float  # |mass(L0=0) - mass(L0=1)| among surviving branches.
    discarded_mass: float  # Normalized mass pruned away over the sweep.
    kept_state_counts: tuple[int, ...]  # Beam size retained after each fault.
    max_width: int  # Copied from the model, for reporting.
    elapsed_seconds: float  # Wall-clock decode time for this shot.


@dataclass(frozen=True)
class DecodingShot:
    """One shot's syndrome plus its (optional) ground-truth observable."""

    det_mask: int  # Bitmask of fired detectors.
    actual_logical: bool | None  # None when no truth data is available.


@dataclass(frozen=True)
class IntegerSeriesSummary:
    """Order statistics of a series of non-negative integers."""

    count: int
    minimum: int | None
    median: float | None
    mean: float | None
    maximum: int | None


@dataclass
class IntegerHistogramAccumulator:
    """Streaming min/median/mean/max over integers via an exact histogram."""

    count: int = 0
    total: int = 0
    minimum: int | None = None
    maximum: int | None = None
    histogram: dict[int, int] = field(default_factory=dict)

    def add(self, value: int) -> None:
        """Fold one value into the running histogram and extrema."""
        self.count += 1
        self.total += value
        if self.minimum is None or value < self.minimum:
            self.minimum = value
        if self.maximum is None or value > self.maximum:
            self.maximum = value
        self.histogram[value] = self.histogram.get(value, 0) + 1

    def add_many(self, values: tuple[int, ...] | list[int]) -> None:
        """Fold a batch of values, one at a time."""
        for value in values:
            self.add(value)

    def summary(self) -> IntegerSeriesSummary:
        """Compute the summary; the median is exact (walks the sorted histogram)."""
        if self.count == 0:
            return IntegerSeriesSummary(
                count=0,
                minimum=None,
                median=None,
                mean=None,
                maximum=None,
            )

        # The median is the average of the values at ranks lower_target and
        # upper_target (equal ranks for odd counts).
        lower_target = (self.count - 1) // 2
        upper_target = self.count // 2
        seen = 0
        lower_value: int | None = None
        upper_value: int | None = None
        for value in sorted(self.histogram):
            seen += self.histogram[value]
            if lower_value is None and seen > lower_target:
                lower_value = value
            if upper_value is None and seen > upper_target:
                upper_value = value
                break

        assert lower_value is not None and upper_value is not None
        return IntegerSeriesSummary(
            count=self.count,
            minimum=self.minimum,
            median=(lower_value + upper_value) / 2.0,
            mean=self.total / self.count,
            maximum=self.maximum,
        )


@dataclass(frozen=True)
class ExperimentSummary:
    """Aggregate statistics over a whole decoding run."""

    predictions: list[bool | None]  # Per-shot predicted L0 (None = undecided).
    num_certified: int
    num_low_confidence: int
    num_errors: int
    num_truth_shots: int  # Shots that carried ground-truth observables.
    num_scored_shots: int  # Shots with both a prediction and a truth label.
    total_elapsed: float
    total_triggered: int  # Total fired detectors across shots.
    max_width_seen: int
    kept_state_summary: IntegerSeriesSummary
def _likelihood_cost(probability: float) -> float:
    """Negative log-odds ``-log(p / (1 - p))`` of a fault firing.

    Degenerate probabilities clamp to the extremes: p >= 1 costs nothing,
    p <= 0 is impossible (infinite cost).
    """
    if probability >= 1.0:
        return 0.0
    if probability <= 0.0:
        return math.inf
    odds = probability / (1.0 - probability)
    return -math.log(odds)


def _detectors_from_mask(mask: int) -> list[int]:
    """Indices of the set bits of ``mask``, in ascending order."""
    indices: list[int] = []
    remaining = mask
    while remaining:
        lsb = remaining & -remaining  # isolate lowest set bit
        remaining ^= lsb
        indices.append(lsb.bit_length() - 1)
    return indices


def _mask_from_bool_row(row: np.ndarray) -> int:
    """Pack a boolean vector into an int bitmask (bit i set iff row[i])."""
    bits = 0
    for index in np.flatnonzero(row).tolist():
        bits |= 1 << index
    return bits


def _future_detcost_by_detector(
    faults: tuple[Fault, ...], num_detectors: int
) -> tuple[tuple[float, ...], ...]:
    """Suffix minima of per-detector fault costs.

    Row i, column d holds the cheapest ``likelihood_cost / |dets|`` over all
    faults j >= i that touch detector d (inf when no remaining fault can).
    Row len(faults) is all-inf.
    """
    rows: list[list[float]] = [[math.inf] * num_detectors for _ in range(len(faults) + 1)]
    for i in reversed(range(len(faults))):
        current = list(rows[i + 1])
        fault = faults[i]
        touched = _detectors_from_mask(fault.det_mask)
        if touched:
            # Split the fault's cost evenly across the detectors it fires.
            share = fault.likelihood_cost / len(touched)
            for det in touched:
                current[det] = min(current[det], share)
        rows[i] = current
    return tuple(tuple(row) for row in rows)


def _build_decoder_model(circuit: stim.Circuit) -> DecoderModel:
    """Translate the circuit's flattened detector error model into a DecoderModel.

    Faults keep their DEM order.  A detector "retires" at the last fault that
    can toggle it, which lets the trellis drop that bit from the beam state.
    """
    flattened = circuit.detector_error_model(decompose_errors=False).flattened()

    faults: list[Fault] = []
    union_det_mask = 0
    final_touch: dict[int, int] = {}  # detector id -> index of last fault touching it

    for instruction in flattened:
        if instruction.type != "error":
            continue

        probability = float(instruction.args_copy()[0])
        dets = 0
        toggles_l0 = 0
        for target in instruction.targets_copy():
            if target.is_separator():
                continue
            if target.is_relative_detector_id():
                dets ^= 1 << target.val
            elif target.is_logical_observable_id() and target.val == 0:
                # Only observable L0 is tracked by this prototype.
                toggles_l0 ^= 1

        faults.append(
            Fault(
                q=1.0 - probability,
                p=probability,
                delta_scale=(-probability if toggles_l0 else probability),
                det_mask=dets,
                likelihood_cost=_likelihood_cost(probability),
            )
        )
        union_det_mask |= dets
        fault_index = len(faults) - 1
        for det_id in _detectors_from_mask(dets):
            final_touch[det_id] = fault_index

    retiring = [0] * len(faults)
    for det_id, fault_index in final_touch.items():
        retiring[fault_index] |= 1 << det_id

    # Forward pass: which detectors are still live after each fault, and the
    # widest simultaneous set (bounds the beam-state popcount).
    live_after = [0] * (len(faults) + 1)
    live = 0
    widest = 0
    for index, fault in enumerate(faults):
        live |= fault.det_mask
        widest = max(widest, live.bit_count())
        live &= ~retiring[index]
        live_after[index + 1] = live

    fault_tuple = tuple(faults)
    return DecoderModel(
        faults=fault_tuple,
        retiring_masks=tuple(retiring),
        live_masks_after=tuple(live_after),
        future_detcost=_future_detcost_by_detector(fault_tuple, circuit.num_detectors),
        all_possible_dets_mask=union_det_mask,
        max_width=widest,
    )
def _detcost_penalty(mismatch_mask: int, future_detcost: tuple[float, ...]) -> float:
    """Sum of per-detector future costs over the set bits of ``mismatch_mask``.

    Returns inf as soon as any mismatched detector can no longer be fixed by
    a remaining fault (its future cost is inf).
    """
    penalty = 0.0
    remaining = mismatch_mask
    while remaining:
        lsb = remaining & -remaining
        remaining ^= lsb
        cost = future_detcost[lsb.bit_length() - 1]
        if cost == math.inf:
            return math.inf
        penalty += cost
    return penalty


def _accumulate_collapsed_state(
    collapsed_probs: dict[int, list[float]],
    *,
    state: int,
    total: float,
    delta: float,
) -> float:
    """Fold (total, delta) into the bucket for ``state``; returns the mass added.

    Zero or negative mass is ignored (contributes 0).  Buckets are mutable
    two-element lists [total_mass, signed_delta].
    """
    if total <= 0.0:
        return 0.0

    try:
        bucket = collapsed_probs[state]
    except KeyError:
        collapsed_probs[state] = [total, delta]
    else:
        bucket[0] += total
        bucket[1] += delta
    return total


def _prune_ranked_states(
    ranked_states: list[tuple[float, float, int, float]],
    *,
    total_mass: float,
    pruning: BeamPruningConfig,
) -> tuple[list[tuple[int, float, float]], float]:
    """Sort (score, total, state, delta) entries best-first and prune per config.

    Returns the surviving beam as (state, total, delta) tuples plus the mass
    that was pruned away.  NOTE: sorts ``ranked_states`` in place.
    """
    if not ranked_states:
        return [], total_mass

    ranked_states.sort(reverse=True)

    if pruning.mode == "fixed":
        survivors = ranked_states[: pruning.beam_width]
    elif pruning.mode == "delta":
        assert pruning.score_delta is not None
        top_score = ranked_states[0][0]
        if top_score == -math.inf:
            # Every branch is hopeless; keep them all rather than none.
            survivors = ranked_states
        else:
            threshold = top_score - pruning.score_delta
            survivors = [entry for entry in ranked_states if entry[0] >= threshold]
            if not survivors:
                survivors = ranked_states[:1]
        if pruning.hard_cap is not None and len(survivors) > pruning.hard_cap:
            survivors = survivors[: pruning.hard_cap]
    elif pruning.mode == "mass":
        assert pruning.mass_epsilon is not None
        # Keep best-first entries until (1 - epsilon) of the mass is retained.
        retained_target = (1.0 - pruning.mass_epsilon) * total_mass
        running = 0.0
        survivors = []
        for entry in ranked_states:
            survivors.append(entry)
            running += entry[1]
            if running >= retained_target:
                break
        if not survivors:
            survivors = ranked_states[:1]
        if pruning.hard_cap is not None and len(survivors) > pruning.hard_cap:
            survivors = survivors[: pruning.hard_cap]
    else:
        raise ValueError(f"Unsupported pruning mode: {pruning.mode!r}")

    retained_mass = sum(total for _, total, _, _ in survivors)
    pruned_mass = total_mass - retained_mass
    beam = [(state, total, delta) for _, total, state, delta in survivors]
    return beam, pruned_mass


def _summarize_int_values(values: tuple[int, ...] | list[int]) -> IntegerSeriesSummary:
    """One-shot min/median/mean/max summary of a small integer series."""
    if not values:
        return IntegerSeriesSummary(
            count=0,
            minimum=None,
            median=None,
            mean=None,
            maximum=None,
        )

    ordered = sorted(values)
    size = len(ordered)
    # Median as the average of the two middle ranks (equal for odd sizes).
    low_mid = ordered[(size - 1) // 2]
    high_mid = ordered[size // 2]
    return IntegerSeriesSummary(
        count=size,
        minimum=ordered[0],
        median=(low_mid + high_mid) / 2.0,
        mean=sum(ordered) / size,
        maximum=ordered[-1],
    )
def _format_optional_int(value: int | None) -> str:
    """Render an optional int, using "none" for missing values."""
    if value is None:
        return "none"
    return str(value)


def _format_pruning_value(value: float | None) -> str:
    """Render an optional pruning threshold with 6 significant digits."""
    return "n/a" if value is None else f"{value:.6g}"


def _format_summary_int(value: int | None) -> str:
    """Render an optional summary statistic, "n/a" when absent."""
    if value is None:
        return "n/a"
    return str(value)


def _format_summary_float(value: float | None, *, digits: int = 2) -> str:
    """Render an optional float with a fixed number of decimals ("n/a" when absent)."""
    if value is None:
        return "n/a"
    return f"{value:.{digits}f}"


def _print_pruning_configuration(*, pruning: BeamPruningConfig, log_stream) -> None:
    """Describe the active beam-pruning settings on ``log_stream``."""

    def emit(line: str) -> None:
        print(line, file=log_stream)

    emit(f"Beam Prune Mode: {pruning.mode}")
    if pruning.mode == "fixed":
        emit(f"Beam Width: {pruning.beam_width}")
    elif pruning.mode == "delta":
        emit(f"Beam Score Delta: {_format_pruning_value(pruning.score_delta)}")
        emit(f"Beam Hard Cap: {_format_optional_int(pruning.hard_cap)}")
    elif pruning.mode == "mass":
        assert pruning.mass_epsilon is not None
        emit(f"Beam Mass Epsilon: {_format_pruning_value(pruning.mass_epsilon)}")
        emit(f"Beam Retained Mass: {_format_pruning_value(1.0 - pruning.mass_epsilon)}")
        emit(f"Beam Hard Cap: {_format_optional_int(pruning.hard_cap)}")
    else:
        raise ValueError(f"Unsupported pruning mode: {pruning.mode!r}")


def _beam_pruning_config_from_args(args: argparse.Namespace) -> BeamPruningConfig:
    """Bundle the CLI pruning flags into a BeamPruningConfig."""
    fields = dict(
        mode=args.beam_prune_mode,
        beam_width=args.beam,
        score_delta=args.beam_score_delta,
        mass_epsilon=args.beam_mass_epsilon,
        hard_cap=args.beam_hard_cap,
    )
    return BeamPruningConfig(**fields)


def _as_bool_2d(data: np.ndarray, *, expected_cols: int, description: str) -> np.ndarray:
    """Validate a 2-D array with ``expected_cols`` columns; coerce to bool dtype."""
    out = np.asarray(data)
    if out.ndim != 2:
        raise ValueError(f"Expected {description} to be a 2D array but got shape {out.shape!r}.")
    if out.shape[1] != expected_cols:
        raise ValueError(
            f"Expected {description} to have {expected_cols} columns but got {out.shape[1]}."
        )
    return out if out.dtype == np.bool_ else out.astype(np.bool_, copy=False)


def _sample_shot_arrays(
    circuit: stim.Circuit,
    *,
    shots: int,
    seed: int | None,
) -> tuple[np.ndarray, np.ndarray]:
    """Sample (detectors, observables) shot tables directly from the circuit."""
    detector_data, observable_data = circuit.compile_detector_sampler(seed=seed).sample(
        shots=shots, separate_observables=True
    )
    return (
        _as_bool_2d(
            detector_data, expected_cols=circuit.num_detectors, description="sampled detector data"
        ),
        _as_bool_2d(
            observable_data,
            expected_cols=circuit.num_observables,
            description="sampled observable data",
        ),
    )


def _read_detector_shot_arrays(
    *,
    path: str,
    fmt: str,
    num_detectors: int,
    num_observables: int,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Read detector shots (and optional appended observables) from a file.

    Falls back to reading a combined detector+observable table and splitting
    the columns when this stim version's reader rejects
    ``separate_observables`` (raising TypeError).
    """
    reader_kwargs = dict(
        path=path,
        format=fmt,
        bit_packed=False,
        num_measurements=0,
        num_detectors=num_detectors,
        num_observables=num_observables,
    )

    if not num_observables:
        flat = stim.read_shot_data_file(**reader_kwargs)
        return _as_bool_2d(flat, expected_cols=num_detectors, description="input detector data"), None

    try:
        det_part, obs_part = stim.read_shot_data_file(**reader_kwargs, separate_observables=True)
        return (
            _as_bool_2d(det_part, expected_cols=num_detectors, description="input detector data"),
            _as_bool_2d(
                obs_part, expected_cols=num_observables, description="appended observable data"
            ),
        )
    except TypeError:
        combined = stim.read_shot_data_file(**reader_kwargs)
        combined = _as_bool_2d(
            combined,
            expected_cols=num_detectors + num_observables,
            description="combined detector/observable input data",
        )
        return combined[:, :num_detectors], combined[:, num_detectors:]
def _read_observable_shot_array(*, path: str, fmt: str, num_observables: int) -> np.ndarray:
    """Read an observables-only shot file as a bool matrix."""
    raw = stim.read_shot_data_file(
        path=path,
        format=fmt,
        bit_packed=False,
        num_measurements=0,
        num_detectors=0,
        num_observables=num_observables,
    )
    return _as_bool_2d(raw, expected_cols=num_observables, description="observable input data")


def _apply_shot_range(
    dets: np.ndarray,
    obs: np.ndarray | None,
    *,
    shot_range_begin: int,
    shot_range_end: int,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Slice both shot tables to [begin, end); no-op when both bounds are 0."""
    if not shot_range_begin and not shot_range_end:
        return dets, obs

    if shot_range_end < shot_range_begin:
        raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.")
    if shot_range_end > len(dets):
        raise ValueError(
            f"Shot range end {shot_range_end} is past the end of the shot data (size {len(dets)})."
        )

    sliced_obs = None if obs is None else obs[shot_range_begin:shot_range_end]
    return dets[shot_range_begin:shot_range_end], sliced_obs


def _shots_from_arrays(dets: np.ndarray, obs: np.ndarray | None) -> list[DecodingShot]:
    """Pack each detector row (plus optional observable-0 bit) into a DecodingShot."""
    have_truth = obs is not None
    return [
        DecodingShot(
            det_mask=_mask_from_bool_row(dets[row]),
            actual_logical=bool(obs[row, 0]) if have_truth else None,
        )
        for row in range(dets.shape[0])
    ]


def _resolve_stdin_path_if_needed(path: str, *, temp_dir: str, stem: str) -> str:
    """Spool stdin to a temp file when ``path`` is "-"; otherwise pass it through."""
    if path != "-":
        return path
    spooled = str(Path(temp_dir) / f"{stem}.bin")
    with open(spooled, "wb") as sink:
        sink.write(sys.stdin.buffer.read())
    return spooled


def _resolve_stdout_path_if_needed(path: str, *, temp_dir: str, stem: str) -> tuple[str, bool]:
    """Map "-" to a temp-file path flagged for a later copy to stdout."""
    if path == "-":
        return str(Path(temp_dir) / f"{stem}.bin"), True
    return path, False


def _copy_file_to_stdout(path: str) -> None:
    """Stream the file's raw bytes to the binary stdout buffer."""
    sys.stdout.flush()  # keep text output ordered ahead of the binary payload
    with open(path, "rb") as source:
        shutil.copyfileobj(source, sys.stdout.buffer)
    sys.stdout.buffer.flush()
def _load_shots(
    circuit: stim.Circuit,
    args: argparse.Namespace,
    *,
    temp_dir: str,
) -> list[DecodingShot]:
    """Load shots from --in / --obs-in (spooling stdin if needed) or sample them.

    The optional shot range is applied after loading, so range bounds index
    the original file/sample order.
    """
    if args.in_file:
        in_path = _resolve_stdin_path_if_needed(args.in_file, temp_dir=temp_dir, stem="shots_in")
        # Observables may ride along appended to the detector file.
        appended_obs_count = circuit.num_observables if args.in_includes_appended_observables else 0
        dets, obs = _read_detector_shot_arrays(
            path=in_path,
            fmt=args.in_format,
            num_detectors=circuit.num_detectors,
            num_observables=appended_obs_count,
        )

        if args.obs_in_file:
            # A separate observables file overrides any appended observables.
            obs_in_path = _resolve_stdin_path_if_needed(args.obs_in_file, temp_dir=temp_dir, stem="obs_in")
            obs = _read_observable_shot_array(
                path=obs_in_path,
                fmt=args.obs_in_format,
                num_observables=circuit.num_observables,
            )
            if len(obs) != len(dets):
                raise ValueError("Observable input ended before, or after, the detector shot data.")
    else:
        dets, obs = _sample_shot_arrays(circuit, shots=args.sample_num_shots, seed=args.sample_seed)

    dets, obs = _apply_shot_range(
        dets,
        obs,
        shot_range_begin=args.shot_range_begin,
        shot_range_end=args.shot_range_end,
    )
    return _shots_from_arrays(dets, obs)


def decode_beam_search_detcost_ranked(
    model: DecoderModel,
    actual_dets_mask: int,
    pruning: BeamPruningConfig,
) -> BeamDecodeResult:
    """Decode one shot with a detcost-ranked, pruned trellis beam search.

    Each beam entry is (residual_state, total_mass, signed_delta) where delta
    is (mass with L0=0) - (mass with L0=1).  After every fault the beam is
    collapsed on equal states, ranked by log(mass) - detcost penalty, pruned
    per ``pruning``, and renormalized; pruned mass accumulates into
    ``discarded_mass`` so the final margin can be certified against it.
    """
    start_time = time.perf_counter()
    retained_state_counts: list[int] = []

    # A detector that no fault can ever toggle cannot be explained at all.
    if (actual_dets_mask & ~model.all_possible_dets_mask) != 0:
        return BeamDecodeResult(
            predicted_logical=None,
            certified=False,
            margin=0.0,
            discarded_mass=0.0,
            kept_state_counts=(0,) * len(model.faults),
            max_width=model.max_width,
            elapsed_seconds=time.perf_counter() - start_time,
        )

    # Initial beam: empty residual state with unit mass, delta = +1 (L0=0).
    beam = [(0, 1.0, 1.0)]
    discarded_mass = 0.0

    for i, fault in enumerate(model.faults):
        collapsed_probs: dict[int, list[float]] = {}
        total_mass = 0.0
        retiring_mask = model.retiring_masks[i]

        if retiring_mask == 0:
            # No detector retires here: branch on fault absent/present only.
            for state, total, delta in beam:
                total_mass += _accumulate_collapsed_state(
                    collapsed_probs,
                    state=state,
                    total=total * fault.q,
                    delta=delta * fault.q,
                )

                total_mass += _accumulate_collapsed_state(
                    collapsed_probs,
                    state=state ^ fault.det_mask,
                    total=total * fault.p,
                    delta=delta * fault.delta_scale,
                )
        else:
            # Retiring detectors must now match the observed syndrome exactly;
            # mismatching branches are dropped, matching bits are cleared.
            expected_bits = actual_dets_mask & retiring_mask
            keep_mask = ~retiring_mask
            for state, total, delta in beam:
                absent_total = total * fault.q
                if absent_total > 0.0 and (state & retiring_mask) == expected_bits:
                    total_mass += _accumulate_collapsed_state(
                        collapsed_probs,
                        state=state & keep_mask,
                        total=absent_total,
                        delta=delta * fault.q,
                    )

                toggled = state ^ fault.det_mask
                present_total = total * fault.p
                if present_total > 0.0 and (toggled & retiring_mask) == expected_bits:
                    total_mass += _accumulate_collapsed_state(
                        collapsed_probs,
                        state=toggled & keep_mask,
                        total=present_total,
                        delta=delta * fault.delta_scale,
                    )

        if total_mass == 0.0:
            # Every branch died (syndrome unexplainable under this beam).
            retained_state_counts.append(0)
            retained_state_counts.extend([0] * (len(model.faults) - i - 1))
            return BeamDecodeResult(
                predicted_logical=None,
                certified=False,
                margin=0.0,
                discarded_mass=discarded_mass,
                kept_state_counts=tuple(retained_state_counts),
                max_width=model.max_width,
                elapsed_seconds=time.perf_counter() - start_time,
            )

        # Rank surviving states by log-mass minus the detcost lower bound on
        # the remaining mismatch against the still-live observed detectors.
        ranked_states: list[tuple[float, float, int, float]] = []
        live_target_mask = actual_dets_mask & model.live_masks_after[i + 1]
        next_future_detcost = model.future_detcost[i + 1]
        for state, (total, delta) in collapsed_probs.items():
            if total <= 0.0:
                continue
            mismatch_mask = state ^ live_target_mask
            penalty = _detcost_penalty(mismatch_mask=mismatch_mask, future_detcost=next_future_detcost)
            if penalty == math.inf:
                rank_score = -math.inf
            else:
                rank_score = math.log(total) - penalty
            ranked_states.append((rank_score, total, state, delta))

        beam, dropped_mass = _prune_ranked_states(
            ranked_states,
            total_mass=total_mass,
            pruning=pruning,
        )
        retained_state_counts.append(len(beam))

        # Renormalize so the beam always carries unit mass; discarded mass is
        # carried forward in the same normalized units.
        inv_total_mass = 1.0 / total_mass
        discarded_mass = (discarded_mass + dropped_mass) * inv_total_mass
        beam = [
            (state, total * inv_total_mass, delta * inv_total_mass)
            for state, total, delta in beam
        ]

    # Only the fully-explained (state == 0) entry predicts the observable.
    _, _, final_delta = next((entry for entry in beam if entry[0] == 0), (0, 0.0, 0.0))
    margin = abs(final_delta)
    # Certified iff no amount of pruned mass could have flipped the sign.
    certified = margin > discarded_mass

    if final_delta == 0.0:
        return BeamDecodeResult(
            predicted_logical=None,
            certified=False,
            margin=margin,
            discarded_mass=discarded_mass,
            kept_state_counts=tuple(retained_state_counts),
            max_width=model.max_width,
            elapsed_seconds=time.perf_counter() - start_time,
        )
    return BeamDecodeResult(
        predicted_logical=final_delta < 0.0,
        certified=certified,
        margin=margin,
        discarded_mass=discarded_mass,
        kept_state_counts=tuple(retained_state_counts),
        max_width=model.max_width,
        elapsed_seconds=time.perf_counter() - start_time,
    )


def _print_run_header(
    *,
    circuit: stim.Circuit,
    args: argparse.Namespace,
    pruning: BeamPruningConfig,
    num_faults: int,
    num_shots: int,
    log_stream,
) -> None:
    """Print the run configuration banner to ``log_stream``."""
    print(f"Running on circuit {args.circuit}", file=log_stream)
    print(f"Total Detectors: {circuit.num_detectors}", file=log_stream)
    print(f"Total Observables: {circuit.num_observables}", file=log_stream)
    print(f"Total Faults: {num_faults}", file=log_stream)
    _print_pruning_configuration(pruning=pruning, log_stream=log_stream)
    if args.in_file:
        print(f"Shot Input: {args.in_file}", file=log_stream)
        print(f"Shot Input Format: {args.in_format}", file=log_stream)
        if args.in_includes_appended_observables:
            print("Observable Input: appended to --in", file=log_stream)
        elif args.obs_in_file:
            print(f"Observable Input: {args.obs_in_file}", file=log_stream)
            print(f"Observable Format: {args.obs_in_format}", file=log_stream)
        else:
            print("Observable Input: none", file=log_stream)
    else:
        print(f"Sample Seed: {args.sample_seed}", file=log_stream)
        print(f"Requested Shots: {args.sample_num_shots}", file=log_stream)
    if args.shot_range_begin or args.shot_range_end:
        print(
            f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})",
            file=log_stream,
        )
    print(f"Num Shots: {num_shots}", file=log_stream)
+ print(f"Sample Seed: {args.sample_seed}", file=log_stream) + print(f"Requested Shots: {args.sample_num_shots}", file=log_stream) + if args.shot_range_begin or args.shot_range_end: + print( + f"Shot Range: [{args.shot_range_begin}, {args.shot_range_end})", + file=log_stream, + ) + print(f"Num Shots: {num_shots}", file=log_stream) + + +def run_experiment(args: argparse.Namespace) -> ExperimentSummary: + circuit = stim.Circuit.from_file(args.circuit) + if circuit.num_observables != 1: + raise ValueError( + "This decoder currently supports exactly one logical observable, because it only tracks L0. " + f"The circuit has {circuit.num_observables} observables." + ) + + model = _build_decoder_model(circuit) + pruning = _beam_pruning_config_from_args(args) + log_stream = sys.stderr if args.out_file == "-" else sys.stdout + + with tempfile.TemporaryDirectory() as temp_dir: + shots = _load_shots(circuit, args, temp_dir=temp_dir) + _print_run_header( + circuit=circuit, + args=args, + pruning=pruning, + num_faults=len(model.faults), + num_shots=len(shots), + log_stream=log_stream, + ) + + num_errors = 0 + num_low_confidence = 0 + num_certified = 0 + num_truth_shots = 0 + num_scored_shots = 0 + total_elapsed = 0.0 + total_triggered = 0 + max_width_seen = 0 + predictions: list[bool | None] = [] + kept_state_accumulator = IntegerHistogramAccumulator() + + for shot_index, shot in enumerate(shots): + result = decode_beam_search_detcost_ranked(model, shot.det_mask, pruning) + predictions.append(result.predicted_logical) + kept_state_accumulator.add_many(result.kept_state_counts) + kept_state_summary = _summarize_int_values(result.kept_state_counts) + + success: bool | None + if shot.actual_logical is None or result.predicted_logical is None: + success = None + else: + success = result.predicted_logical == shot.actual_logical + + if result.predicted_logical is None: + num_low_confidence += 1 + if shot.actual_logical is not None: + num_truth_shots += 1 + if success is not None: + 
num_scored_shots += 1 + if not success: + num_errors += 1 + if result.certified: + num_certified += 1 + + total_elapsed += result.elapsed_seconds + triggered_dets = shot.det_mask.bit_count() + total_triggered += triggered_dets + max_width_seen = max(max_width_seen, result.max_width) + + shots_done = shot_index + 1 + error_rate_so_far = num_errors / num_scored_shots if num_scored_shots else 0.0 + print( + f"progress shots_done={shots_done}/{len(shots)} errors_so_far={num_errors} " + f"low_conf_so_far={num_low_confidence} scored_shots_so_far={num_scored_shots} " + f"error_rate_so_far={error_rate_so_far:.6f} elapsed_total_seconds={total_elapsed:.6f} " + f"kept_states_min={_format_summary_int(kept_state_summary.minimum)} " + f"kept_states_median={_format_summary_float(kept_state_summary.median)} " + f"kept_states_mean={_format_summary_float(kept_state_summary.mean)} " + f"kept_states_max={_format_summary_int(kept_state_summary.maximum)}", + file=log_stream, + ) + + if args.print_per_shot: + print( + f"shot={shot_index} triggered_detectors={triggered_dets} " + f"predicted_logical={result.predicted_logical} actual_logical={shot.actual_logical} " + f"success={success} certified={result.certified} " + f"margin={result.margin:.6e} discarded_mass={result.discarded_mass:.6e} " + f"kept_states_min={_format_summary_int(kept_state_summary.minimum)} " + f"kept_states_median={_format_summary_float(kept_state_summary.median)} " + f"kept_states_mean={_format_summary_float(kept_state_summary.mean)} " + f"kept_states_max={_format_summary_int(kept_state_summary.maximum)} " + f"elapsed_seconds={result.elapsed_seconds:.6f}", + file=log_stream, + ) + + if args.out_file: + output_path, copy_to_stdout = _resolve_stdout_path_if_needed( + args.out_file, + temp_dir=temp_dir, + stem="predictions_out", + ) + prediction_data = np.zeros((len(predictions), circuit.num_observables), dtype=np.bool_) + for shot_index, predicted_logical in enumerate(predictions): + prediction_data[shot_index, 0] = 
bool(predicted_logical) if predicted_logical is not None else False + + if args.out_format == "ptb64" and len(prediction_data) % 64 != 0: + raise ValueError("The ptb64 format requires the number of shots to be a multiple of 64.") + + stim.write_shot_data_file( + data=prediction_data, + path=output_path, + format=args.out_format, + num_measurements=0, + num_detectors=0, + num_observables=circuit.num_observables, + ) + if copy_to_stdout: + _copy_file_to_stdout(output_path) + if num_low_confidence: + print( + f"warning: wrote {num_low_confidence} low-confidence predictions as L0=0 because Stim result " + "files can only store bits, not unknown values.", + file=log_stream, + ) + + kept_state_summary = kept_state_accumulator.summary() + + print(f"Mean Triggered Dets: {total_triggered / max(1, len(shots)):.2f}", file=log_stream) + print(f"Max Width: {max_width_seen}", file=log_stream) + print(f"{'Kept States/Fault Min:':<26}{_format_summary_int(kept_state_summary.minimum)}", file=log_stream) + print(f"{'Kept States/Fault Median:':<26}{_format_summary_float(kept_state_summary.median)}", file=log_stream) + print(f"{'Kept States/Fault Mean:':<26}{_format_summary_float(kept_state_summary.mean)}", file=log_stream) + print(f"{'Kept States/Fault Max:':<26}{_format_summary_int(kept_state_summary.maximum)}", file=log_stream) + print(f"Certified Shots: {num_certified}", file=log_stream) + print(f"Low Confidence: {num_low_confidence}", file=log_stream) + print(f"Truth-Labeled Shots: {num_truth_shots}", file=log_stream) + print(f"Scored Shots: {num_scored_shots}", file=log_stream) + if num_truth_shots: + print(f"Logical Errors: {num_errors}", file=log_stream) + else: + print("Logical Errors: n/a", file=log_stream) + print(f"Total Seconds: {total_elapsed:.6f}", file=log_stream) + print(f"Mean Seconds/Shot: {total_elapsed / max(1, len(shots)):.6f}", file=log_stream) + + return ExperimentSummary( + predictions=predictions, + num_certified=num_certified, + 
num_low_confidence=num_low_confidence, + num_errors=num_errors, + num_truth_shots=num_truth_shots, + num_scored_shots=num_scored_shots, + total_elapsed=total_elapsed, + total_triggered=total_triggered, + max_width_seen=max_width_seen, + kept_state_summary=kept_state_summary, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Run trellis beam decoding ranked by mass minus a detcost-style future penalty, " + "with optional adaptive threshold pruning and Stim-compatible shot-data I/O options." + ), + allow_abbrev=False, + ) + parser.add_argument("--circuit", required=True, help="Path to the .stim circuit file.") + parser.add_argument( + "--beam", + type=int, + default=1000, + help="Beam width cutoff used when --beam-prune-mode=fixed.", + ) + parser.add_argument( + "--beam-prune-mode", + choices=BEAM_PRUNE_MODES, + default="fixed", + help=( + "Beam pruning rule: fixed keeps the top --beam states, delta keeps all states within " + "--beam-score-delta of the best rank score, and mass keeps a rank-sorted prefix whose " + "retained normalized mass reaches 1-epsilon." + ), + ) + parser.add_argument( + "--beam-score-delta", + type=float, + default=None, + help=( + "For --beam-prune-mode=delta, keep every state whose rank score is within this additive " + "gap of the best state's rank score." + ), + ) + parser.add_argument( + "--beam-mass-epsilon", + type=float, + default=None, + help=( + "For --beam-prune-mode=mass, keep the smallest rank-sorted prefix whose retained " + "normalized mass is at least 1 - epsilon." + ), + ) + parser.add_argument( + "--beam-retained-mass", + type=float, + default=None, + help=( + "For --beam-prune-mode=mass, equivalent to setting --beam-mass-epsilon to " + "1 - retained_mass." + ), + ) + parser.add_argument( + "--beam-hard-cap", + type=int, + default=None, + help=( + "Optional hard cap on the number of states retained after delta or mass thresholding. " + "Ignored in fixed mode." 
+ ), + ) + parser.add_argument( + "--sample-num-shots", + type=int, + default=None, + help="Number of sampled shots. Defaults to 1 unless --in is provided.", + ) + parser.add_argument("--sample-seed", type=int, default=None, help="Stim sampler seed.") + parser.add_argument( + "--shot-range-begin", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--shot-range-end", + type=int, + default=0, + help=( + "If both --shot-range-begin and --shot-range-end are 0, decode all available shots. " + "Otherwise only decode shots in [begin, end)." + ), + ) + parser.add_argument( + "--in", + dest="in_file", + default="", + help="File to read detection events from (use - for stdin).", + ) + parser.add_argument( + "--in-format", + "--in_format", + dest="in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--in-includes-appended-observables", + "--in_includes_appended_observables", + dest="in_includes_appended_observables", + action="store_true", + help="Assume the observable flips are appended to each shot in --in.", + ) + parser.add_argument( + "--obs-in", + "--obs_in", + dest="obs_in_file", + default="", + help="File to read observable flips from (use - for stdin).", + ) + parser.add_argument( + "--obs-in-format", + "--obs_in_format", + dest="obs_in_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file read by --obs-in ({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--out", + dest="out_file", + default="", + help="File to write predicted observable flips to (use - for stdout).", + ) + parser.add_argument( + "--out-format", + "--out_format", + dest="out_format", + choices=STIM_RESULT_FORMATS, + default="01", + help=f"Format of the file written by --out 
({STIM_RESULT_FORMATS_HELP}).", + ) + parser.add_argument( + "--print-per-shot", + action="store_true", + help="Print a detailed line per decoded shot.", + ) + args = parser.parse_args() + + if args.sample_num_shots is None: + # Preserve the original script's one-shot default while still allowing + # file input without requiring --sample-num-shots 0. + args.sample_num_shots = 0 if args.in_file else 1 + + if args.beam <= 0: + raise ValueError("--beam must be positive.") + if args.beam_hard_cap is not None and args.beam_hard_cap <= 0: + raise ValueError("--beam-hard-cap must be positive when provided.") + if args.beam_score_delta is not None: + if math.isnan(args.beam_score_delta) or args.beam_score_delta < 0.0: + raise ValueError("--beam-score-delta must be a non-negative number.") + if args.beam_mass_epsilon is not None: + if math.isnan(args.beam_mass_epsilon) or not (0.0 <= args.beam_mass_epsilon < 1.0): + raise ValueError("--beam-mass-epsilon must satisfy 0 <= epsilon < 1.") + if args.beam_retained_mass is not None: + if math.isnan(args.beam_retained_mass) or not (0.0 <= args.beam_retained_mass <= 1.0): + raise ValueError("--beam-retained-mass must satisfy 0 <= retained_mass <= 1.") + if args.beam_mass_epsilon is not None: + raise ValueError("Choose at most one of --beam-mass-epsilon and --beam-retained-mass.") + args.beam_mass_epsilon = 1.0 - args.beam_retained_mass + if args.sample_num_shots < 0: + raise ValueError("--sample-num-shots must be non-negative.") + if args.sample_seed is not None and args.sample_seed < 0: + raise ValueError("--sample-seed must be non-negative.") + if args.shot_range_begin < 0 or args.shot_range_end < 0: + raise ValueError("--shot-range-begin and --shot-range-end must be non-negative.") + if args.shot_range_end < args.shot_range_begin: + raise ValueError("Provided shot range must satisfy --shot-range-end >= --shot-range-begin.") + if args.in_includes_appended_observables and args.obs_in_file: + raise ValueError( + "Choose either 
--in-includes-appended-observables or --obs-in, not both." + ) + if args.obs_in_file and not args.in_file: + raise ValueError("Cannot load observable flips from --obs-in without also providing --in.") + if args.in_file == "-" and args.obs_in_file == "-": + raise ValueError("At most one of --in and --obs-in may read from stdin.") + + if args.beam_prune_mode == "fixed": + if args.beam_score_delta is not None: + raise ValueError("--beam-score-delta is only valid with --beam-prune-mode=delta.") + if args.beam_mass_epsilon is not None: + raise ValueError( + "--beam-mass-epsilon/--beam-retained-mass are only valid with --beam-prune-mode=mass." + ) + if args.beam_hard_cap is not None: + raise ValueError("--beam-hard-cap is only meaningful with adaptive pruning modes.") + elif args.beam_prune_mode == "delta": + if args.beam_score_delta is None: + raise ValueError("--beam-prune-mode=delta requires --beam-score-delta.") + if args.beam_mass_epsilon is not None: + raise ValueError( + "--beam-mass-epsilon/--beam-retained-mass are not valid with --beam-prune-mode=delta." + ) + elif args.beam_prune_mode == "mass": + if args.beam_mass_epsilon is None: + raise ValueError( + "--beam-prune-mode=mass requires --beam-mass-epsilon or --beam-retained-mass." 
+ ) + if args.beam_score_delta is not None: + raise ValueError("--beam-score-delta is not valid with --beam-prune-mode=mass.") + else: + raise ValueError(f"Unsupported --beam-prune-mode {args.beam_prune_mode!r}.") + + num_shot_sources = int(args.sample_num_shots > 0) + int(bool(args.in_file)) + if num_shot_sources != 1: + raise ValueError("Requires exactly one source of shots: either --sample-num-shots > 0 or --in.") + + return args + + +if __name__ == "__main__": + run_experiment(_parse_args()) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 76eb3e2..5de656f 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -75,6 +75,12 @@ struct BranchPenaltyUpdate { double present_penalty = 0.0; }; +struct FinalizeKeptStateStatsOnExit { + TesseractTrellisDecoder* decoder; + + ~FinalizeKeptStateStatsOnExit(); +}; + std::vector parse_faults(const std::vector& errors, size_t num_observables) { std::vector faults; faults.reserve(errors.size()); @@ -662,6 +668,127 @@ double branch_score(const WidePackedMass& item, TesseractTrellisRankingMode rank return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); } +void reset_kept_state_stats(TesseractTrellisDecoder* decoder) { + decoder->kept_state_sample_count = 0; + decoder->kept_state_min = 0; + decoder->kept_state_median = 0; + decoder->kept_state_mean = 0; + decoder->kept_state_max = 0; + if (!decoder->config.track_kept_state_stats) { + return; + } + + const size_t histogram_size = decoder->config.beam_width + 1; + if (decoder->kept_state_histogram_scratch.size() != histogram_size) { + decoder->kept_state_histogram_scratch.assign(histogram_size, 0); + } else { + std::fill(decoder->kept_state_histogram_scratch.begin(), + decoder->kept_state_histogram_scratch.end(), 0); + } +} + +void record_kept_state_count(TesseractTrellisDecoder* decoder, size_t kept_states) { + if (!decoder->config.track_kept_state_stats) { + return; + } + + kept_states = std::min(kept_states, 
decoder->config.beam_width); + if (decoder->kept_state_sample_count == 0) { + decoder->kept_state_min = kept_states; + decoder->kept_state_max = kept_states; + } else { + decoder->kept_state_min = std::min(decoder->kept_state_min, kept_states); + decoder->kept_state_max = std::max(decoder->kept_state_max, kept_states); + } + ++decoder->kept_state_sample_count; + decoder->kept_state_mean += kept_states; + ++decoder->kept_state_histogram_scratch[kept_states]; +} + +void finalize_kept_state_stats(TesseractTrellisDecoder* decoder) { + if (!decoder->config.track_kept_state_stats || decoder->kept_state_sample_count == 0) { + return; + } + + decoder->kept_state_mean /= decoder->kept_state_sample_count; + const size_t lower_target = (decoder->kept_state_sample_count - 1) >> 1; + const size_t upper_target = decoder->kept_state_sample_count >> 1; + size_t seen = 0; + size_t lower = 0; + size_t upper = 0; + bool lower_found = false; + for (size_t kept_states = 0; kept_states < decoder->kept_state_histogram_scratch.size(); + ++kept_states) { + seen += decoder->kept_state_histogram_scratch[kept_states]; + if (!lower_found && seen > lower_target) { + lower = kept_states; + lower_found = true; + } + if (seen > upper_target) { + upper = kept_states; + break; + } + } + decoder->kept_state_median = 0.5 * (lower + upper); +} + +FinalizeKeptStateStatsOnExit::~FinalizeKeptStateStatsOnExit() { + finalize_kept_state_stats(decoder); +} + +bool small_state_group_score_greater(const SmallStateGroup& a, const SmallStateGroup& b) { + if (a.score != b.score) { + return a.score > b.score; + } + return a.state < b.state; +} + +size_t trim_small_state_groups_by_beam_and_mass(std::vector* groups, + size_t beam_width, double beam_eps) { + if (groups->empty()) { + return 0; + } + + double total_mass = 0.0; + if (beam_eps > 0.0) { + for (const auto& group : *groups) { + total_mass += group.mass; + } + } + + if (groups->size() > beam_width) { + std::nth_element(groups->begin(), groups->begin() + 
beam_width, groups->end(), + [](const SmallStateGroup& a, const SmallStateGroup& b) { + return a.score > b.score; + }); + groups->resize(beam_width); + } else if (beam_eps <= 0.0) { + return groups->size(); + } + + if (beam_eps <= 0.0 || total_mass <= 0.0) { + return groups->size(); + } + + std::sort(groups->begin(), groups->end(), small_state_group_score_greater); + const double retained_target_mass = total_mass * (1.0 - beam_eps); + double retained_mass = 0.0; + size_t keep_count = 0; + while (keep_count < groups->size()) { + retained_mass += (*groups)[keep_count].mass; + ++keep_count; + if (retained_mass >= retained_target_mass) { + break; + } + } + groups->resize(keep_count); + std::sort(groups->begin(), groups->end(), + [](const SmallStateGroup& a, const SmallStateGroup& b) { + return a.begin < b.begin; + }); + return groups->size(); +} + std::vector collect_small_state_groups( const std::vector& entries, TesseractTrellisRankingMode ranking_mode) { std::vector groups; @@ -690,21 +817,14 @@ double state_score(const WideStateMass& item, TesseractTrellisRankingMode rankin return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); } -size_t keep_top_states(std::vector& entries, size_t beam_width, +size_t keep_top_states(std::vector& entries, size_t beam_width, double beam_eps, TesseractTrellisRankingMode ranking_mode) { if (entries.empty()) { return 0; } auto groups = collect_small_state_groups(entries, ranking_mode); - if (groups.size() <= beam_width) { - return groups.size(); - } - - std::nth_element(groups.begin(), groups.begin() + beam_width, groups.end(), - [](const SmallStateGroup& a, const SmallStateGroup& b) { - return a.score > b.score; - }); - groups.resize(beam_width); + const size_t kept_group_count = + trim_small_state_groups_by_beam_and_mass(&groups, beam_width, beam_eps); std::vector kept; size_t kept_entries = 0; @@ -718,23 +838,55 @@ size_t keep_top_states(std::vector& entries, size_t beam_width, } } entries = std::move(kept); - return 
groups.size(); + return kept_group_count; } -size_t keep_top_states(std::vector& entries, size_t beam_width, +size_t keep_top_states(std::vector& entries, size_t beam_width, double beam_eps, TesseractTrellisRankingMode ranking_mode) { if (entries.empty()) { return 0; } auto totals = accumulate_state_masses_from_entries(entries); - if (totals.size() <= beam_width) { + double total_mass = 0.0; + if (beam_eps > 0.0) { + for (const auto& item : totals) { + total_mass += item.mass; + } + } + + if (totals.size() > beam_width) { + std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(), + [ranking_mode](const WideStateMass& a, const WideStateMass& b) { + return state_score(a, ranking_mode) > state_score(b, ranking_mode); + }); + totals.resize(beam_width); + } else if (beam_eps <= 0.0) { return totals.size(); } - std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(), - [ranking_mode](const WideStateMass& a, const WideStateMass& b) { - return state_score(a, ranking_mode) > state_score(b, ranking_mode); - }); - totals.resize(beam_width); + + if (beam_eps > 0.0 && total_mass > 0.0) { + std::sort(totals.begin(), totals.end(), + [ranking_mode](const WideStateMass& a, const WideStateMass& b) { + double sa = state_score(a, ranking_mode); + double sb = state_score(b, ranking_mode); + if (sa != sb) { + return sa > sb; + } + return wide_state_less(a.state_words, b.state_words); + }); + const double retained_target_mass = total_mass * (1.0 - beam_eps); + double retained_mass = 0.0; + size_t keep_count = 0; + while (keep_count < totals.size()) { + retained_mass += totals[keep_count].mass; + ++keep_count; + if (retained_mass >= retained_target_mass) { + break; + } + } + totals.resize(keep_count); + } + std::sort(totals.begin(), totals.end(), [](const WideStateMass& a, const WideStateMass& b) { return wide_state_less(a.state_words, b.state_words); }); @@ -953,6 +1105,7 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection 
num_states_merged = 0; max_beam_size_seen = 0; max_frontier_width_seen = 0; + reset_kept_state_stats(this); time_expand_seconds = 0; time_collapse_seconds = 0; time_truncate_seconds = 0; @@ -960,6 +1113,7 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection predicted_obs_mask = 0; total_mass_obs0 = 0; total_mass_obs1 = 0; + FinalizeKeptStateStatsOnExit kept_state_stats_guard{this}; if (has_small_layer_templates) { std::fill(scratch_small_current_target_bits.begin(), scratch_small_current_target_bits.end(), @@ -1069,7 +1223,8 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection size_t kept_states = 0; if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - kept_states = keep_top_states(beam_entries, config.beam_width, config.ranking_mode); + kept_states = + keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || @@ -1077,6 +1232,12 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); } normalize_items(beam_entries); + const size_t kept_state_sample = + beam_entries.empty() + ? 0 + : (config.prune_mode == TesseractTrellisPruneMode::MergedStates ? 
kept_states + : beam_entries.size()); + record_kept_state_count(this, kept_state_sample); if (beam_entries.empty()) { low_confidence_flag = true; return; @@ -1205,7 +1366,8 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection size_t kept_states = 0; if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - kept_states = keep_top_states(beam_entries, config.beam_width, config.ranking_mode); + kept_states = + keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || @@ -1213,6 +1375,12 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); } normalize_items(beam_entries); + const size_t kept_state_sample = + beam_entries.empty() + ? 0 + : (config.prune_mode == TesseractTrellisPruneMode::MergedStates ? 
kept_states + : beam_entries.size()); + record_kept_state_count(this, kept_state_sample); if (beam_entries.empty()) { low_confidence_flag = true; return; diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h index 897d489..38bd2e8 100644 --- a/src/tesseract_trellis.h +++ b/src/tesseract_trellis.h @@ -76,8 +76,10 @@ struct TesseractTrellisWideLayerTemplate { struct TesseractTrellisConfig { stim::DetectorErrorModel dem; size_t beam_width = 1024; + double beam_eps = 0.0; size_t merge_interval = 1; bool verbose = false; + bool track_kept_state_stats = false; TesseractTrellisPruneMode prune_mode = TesseractTrellisPruneMode::MergedStates; TesseractTrellisRankingMode ranking_mode = TesseractTrellisRankingMode::MassOnly; }; @@ -96,6 +98,11 @@ struct TesseractTrellisDecoder { size_t num_states_merged = 0; size_t max_beam_size_seen = 0; size_t max_frontier_width_seen = 0; + size_t kept_state_sample_count = 0; + size_t kept_state_min = 0; + double kept_state_median = 0; + double kept_state_mean = 0; + size_t kept_state_max = 0; double time_expand_seconds = 0; double time_collapse_seconds = 0; double time_truncate_seconds = 0; @@ -117,6 +124,7 @@ struct TesseractTrellisDecoder { std::vector scratch_small_expected_retiring_bits; std::vector wide_layer_templates; std::vector initial_future_detcost; + std::vector kept_state_histogram_scratch; }; #endif // TESSERACT_TRELLIS_DECODER_H diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc index f7febfb..8410ae9 100644 --- a/src/tesseract_trellis_main.cc +++ b/src/tesseract_trellis_main.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -66,6 +67,7 @@ struct Args { size_t num_threads = 1; size_t beam_width = 1024; + double beam_eps = 0.0; size_t merge_interval = 1; std::string prune_mode = "merged"; std::string ranking_mode = "mass"; @@ -119,11 +121,17 @@ struct Args { if (beam_width == 0) { throw std::invalid_argument("--beam must be at least 1."); } + if 
(!std::isfinite(beam_eps) || beam_eps < 0.0 || beam_eps >= 1.0) { + throw std::invalid_argument("--beam-eps must satisfy 0 <= beam-eps < 1."); + } if (merge_interval == 0) { throw std::invalid_argument("--merge-interval must be at least 1."); } parse_prune_mode(prune_mode); parse_ranking_mode(ranking_mode); + if (beam_eps != 0.0 && prune_mode != "merged") { + throw std::invalid_argument("--beam-eps is currently only supported with --prune-mode=merged."); + } } void extract(TesseractTrellisConfig& config, std::vector& shots, @@ -156,8 +164,10 @@ struct Args { } config.beam_width = beam_width; + config.beam_eps = beam_eps; config.merge_interval = merge_interval; config.verbose = verbose; + config.track_kept_state_stats = print_stats; config.prune_mode = parse_prune_mode(prune_mode); config.ranking_mode = parse_ranking_mode(ranking_mode); @@ -283,6 +293,13 @@ int main(int argc, char* argv[]) { std::thread::hardware_concurrency() == 0 ? 1 : std::thread::hardware_concurrency())) .store_into(args.num_threads); program.add_argument("--beam").default_value(size_t(1024)).store_into(args.beam_width); + program.add_argument("--beam-eps") + .help( + "With --prune-mode=merged, keep at most --beam states and also drop the suffix once " + "the kept prefix has accumulated at least (1 - beam-eps) of the total merged-state " + "mass. 
Use 0 to disable the mass-threshold cutoff.") + .default_value(0.0) + .store_into(args.beam_eps); program.add_argument("--merge-interval").default_value(size_t(1)).store_into(args.merge_interval); program.add_argument("--prune-mode") .help( @@ -325,6 +342,10 @@ int main(int argc, char* argv[]) { std::vector num_states_merged_per_shot(shots.size()); std::vector max_beam_size_per_shot(shots.size()); std::vector max_frontier_width_per_shot(shots.size()); + std::vector kept_state_min_per_shot(shots.size()); + std::vector kept_state_median_per_shot(shots.size()); + std::vector kept_state_mean_per_shot(shots.size()); + std::vector kept_state_max_per_shot(shots.size()); std::vector time_expand_per_shot(shots.size()); std::vector time_collapse_per_shot(shots.size()); std::vector time_truncate_per_shot(shots.size()); @@ -360,6 +381,10 @@ int main(int argc, char* argv[]) { num_states_merged_per_shot[shot_index] = decoder.num_states_merged; max_beam_size_per_shot[shot_index] = decoder.max_beam_size_seen; max_frontier_width_per_shot[shot_index] = decoder.max_frontier_width_seen; + kept_state_min_per_shot[shot_index] = decoder.kept_state_min; + kept_state_median_per_shot[shot_index] = decoder.kept_state_median; + kept_state_mean_per_shot[shot_index] = decoder.kept_state_mean; + kept_state_max_per_shot[shot_index] = decoder.kept_state_max; time_expand_per_shot[shot_index] = decoder.time_expand_seconds; time_collapse_per_shot[shot_index] = decoder.time_collapse_seconds; time_truncate_per_shot[shot_index] = decoder.time_truncate_seconds; @@ -385,6 +410,11 @@ int main(int argc, char* argv[]) { << " max_beam = " << max_beam_size_per_shot[shot_index] << " frontier_width = " << max_frontier_width_per_shot[shot_index] << " total_time_seconds = " << total_time_seconds << '\n'; + std::cout << "kept_states" + << " min=" << kept_state_min_per_shot[shot_index] + << " median=" << kept_state_median_per_shot[shot_index] + << " mean=" << kept_state_mean_per_shot[shot_index] + << " max=" << 
kept_state_max_per_shot[shot_index] << '\n'; std::cout << "branch_masses" << " obs0=" << mass0_predicted[shot_index] << " obs1=" << mass1_predicted[shot_index] << '\n'; std::cout << "phase_times_seconds" << " expand=" << time_expand_per_shot[shot_index] @@ -405,6 +435,7 @@ int main(int argc, char* argv[]) { nlohmann::json stats_json = {{"circuit_path", args.circuit_path}, {"dem_path", args.dem_path}, {"beam_width", args.beam_width}, + {"beam_eps", args.beam_eps}, {"sample_seed", args.sample_seed}, {"sample_num_shots", args.sample_num_shots}, {"num_threads", args.num_threads}, From 70afff3738b27c5965d018644de7011c0c7d8768 Mon Sep 17 00:00:00 2001 From: noajshu Date: Sun, 19 Apr 2026 18:21:38 +0000 Subject: [PATCH 19/25] clang format and treat low conf as logical error --- src/py/astar/plot_log.py | 41 ++++++++++++++++++++++++---------- src/tesseract_trellis.cc | 42 ++++++++++++++++++----------------- src/tesseract_trellis_main.cc | 6 ++--- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/py/astar/plot_log.py b/src/py/astar/plot_log.py index 2ffd9ad..56e7687 100644 --- a/src/py/astar/plot_log.py +++ b/src/py/astar/plot_log.py @@ -8,7 +8,10 @@ def analyze_log(filename): errors = [] current_errors = 0 + current_low_conf = 0 + pending_error_diff = None + pending_low_conf_diff = None # Parse the log file line by line with open(filename, 'r') as f: @@ -18,33 +21,47 @@ def analyze_log(filename): continue if parts[0] == "num_shots": - # Find 'num_errors' and grab the value two indices over (past the '=') - idx = parts.index("num_errors") - errs = int(parts[idx + 2]) + # Find 'num_errors' and 'num_low_confidence' and grab the values + idx_err = parts.index("num_errors") + errs = int(parts[idx_err + 2]) + + idx_lc = parts.index("num_low_confidence") + lc = int(parts[idx_lc + 2]) - # Calculate errors for this specific shot and store it as pending + # Calculate diffs for this specific shot pending_error_diff = errs - current_errors + pending_low_conf_diff = 
lc - current_low_conf + current_errors = errs + current_low_conf = lc elif parts[0] == "branch_masses": - # Parse out the values from obs0=... and obs1=... obs0 = float(parts[1].split("=")[1]) obs1 = float(parts[2].split("=")[1]) - norm = obs0 + obs1 - if norm == 0: - obs0 = 0.5 - obs1 = 0.5 + + # Override if it was flagged as a low confidence shot + if pending_low_conf_diff is not None and pending_low_conf_diff > 0: + obs0 = 0.5 + obs1 = 0.5 + # Count the low confidence increment as additional logical errors + pending_error_diff += pending_low_conf_diff else: - obs0 /= norm - obs1 /= norm + norm = obs0 + obs1 + if norm == 0: + obs0 = 0.5 + obs1 = 0.5 + else: + obs0 /= norm + obs1 /= norm # Only append if we just successfully parsed a num_shots line if pending_error_diff is not None: min_masses.append(min(obs0, obs1)) errors.append(pending_error_diff) - # Reset pending diff to ensure we don't double-count + # Reset pending diffs to ensure we don't double-count pending_error_diff = None + pending_low_conf_diff = None min_masses = np.array(min_masses) errors = np.array(errors) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 5de656f..cfbf42a 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -757,10 +757,9 @@ size_t trim_small_state_groups_by_beam_and_mass(std::vector* gr } if (groups->size() > beam_width) { - std::nth_element(groups->begin(), groups->begin() + beam_width, groups->end(), - [](const SmallStateGroup& a, const SmallStateGroup& b) { - return a.score > b.score; - }); + std::nth_element( + groups->begin(), groups->begin() + beam_width, groups->end(), + [](const SmallStateGroup& a, const SmallStateGroup& b) { return a.score > b.score; }); groups->resize(beam_width); } else if (beam_eps <= 0.0) { return groups->size(); @@ -783,14 +782,12 @@ size_t trim_small_state_groups_by_beam_and_mass(std::vector* gr } groups->resize(keep_count); std::sort(groups->begin(), groups->end(), - [](const SmallStateGroup& a, const 
SmallStateGroup& b) { - return a.begin < b.begin; - }); + [](const SmallStateGroup& a, const SmallStateGroup& b) { return a.begin < b.begin; }); return groups->size(); } -std::vector collect_small_state_groups( - const std::vector& entries, TesseractTrellisRankingMode ranking_mode) { +std::vector collect_small_state_groups(const std::vector& entries, + TesseractTrellisRankingMode ranking_mode) { std::vector groups; if (entries.empty()) { return groups; @@ -805,9 +802,9 @@ std::vector collect_small_state_groups( mass += entries[end].mass; ++end; } - groups.push_back( - {state, mass, score_mass_and_penalty(mass, entries[begin].penalty, ranking_mode), begin, - end}); + groups.push_back({state, mass, + score_mass_and_penalty(mass, entries[begin].penalty, ranking_mode), begin, + end}); begin = end; } return groups; @@ -1233,10 +1230,10 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection } normalize_items(beam_entries); const size_t kept_state_sample = - beam_entries.empty() - ? 0 - : (config.prune_mode == TesseractTrellisPruneMode::MergedStates ? kept_states - : beam_entries.size()); + beam_entries.empty() ? 0 + : (config.prune_mode == TesseractTrellisPruneMode::MergedStates + ? 
kept_states + : beam_entries.size()); record_kept_state_count(this, kept_state_sample); if (beam_entries.empty()) { low_confidence_flag = true; @@ -1316,6 +1313,11 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection next_entries.clear(); next_entries.reserve(beam_entries.size() * 2); + if (config.verbose) { + std::cout << "expanding layer " << layer_index << " / " << (wide_layer_templates.size() - 1) + << std::endl; + std::cout << "states to expand = " << beam_entries.size() << std::endl; + } for (const auto& item : beam_entries) { ++num_states_expanded; BranchPenaltyUpdate update = compute_wide_branch_update( @@ -1376,10 +1378,10 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection } normalize_items(beam_entries); const size_t kept_state_sample = - beam_entries.empty() - ? 0 - : (config.prune_mode == TesseractTrellisPruneMode::MergedStates ? kept_states - : beam_entries.size()); + beam_entries.empty() ? 0 + : (config.prune_mode == TesseractTrellisPruneMode::MergedStates + ? 
kept_states + : beam_entries.size()); record_kept_state_count(this, kept_state_sample); if (beam_entries.empty()) { low_confidence_flag = true; diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc index 8410ae9..224ee26 100644 --- a/src/tesseract_trellis_main.cc +++ b/src/tesseract_trellis_main.cc @@ -130,7 +130,8 @@ struct Args { parse_prune_mode(prune_mode); parse_ranking_mode(ranking_mode); if (beam_eps != 0.0 && prune_mode != "merged") { - throw std::invalid_argument("--beam-eps is currently only supported with --prune-mode=merged."); + throw std::invalid_argument( + "--beam-eps is currently only supported with --prune-mode=merged."); } } @@ -410,8 +411,7 @@ int main(int argc, char* argv[]) { << " max_beam = " << max_beam_size_per_shot[shot_index] << " frontier_width = " << max_frontier_width_per_shot[shot_index] << " total_time_seconds = " << total_time_seconds << '\n'; - std::cout << "kept_states" - << " min=" << kept_state_min_per_shot[shot_index] + std::cout << "kept_states" << " min=" << kept_state_min_per_shot[shot_index] << " median=" << kept_state_median_per_shot[shot_index] << " mean=" << kept_state_mean_per_shot[shot_index] << " max=" << kept_state_max_per_shot[shot_index] << '\n'; From 56996facf54c25e6c08fed19d8902f40e1971f55 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 19 Apr 2026 12:01:54 -0700 Subject: [PATCH 20/25] simplify by removing the merge interval and prune mode from tesseract trellis --- src/tesseract_trellis.cc | 229 ++-------------------------------- src/tesseract_trellis.h | 9 -- src/tesseract_trellis_main.cc | 39 +----- 3 files changed, 15 insertions(+), 262 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index cfbf42a..5628070 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -660,14 +660,6 @@ double score_mass_and_penalty(double mass, double penalty, return std::log(mass) - penalty; } -double branch_score(const PackedMass& item, 
TesseractTrellisRankingMode ranking_mode) { - return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); -} - -double branch_score(const WidePackedMass& item, TesseractTrellisRankingMode ranking_mode) { - return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); -} - void reset_kept_state_stats(TesseractTrellisDecoder* decoder) { decoder->kept_state_sample_count = 0; decoder->kept_state_min = 0; @@ -903,133 +895,6 @@ size_t keep_top_states(std::vector& entries, size_t beam_width, return totals.size(); } -void keep_best_state_representatives(std::vector& entries, size_t beam_width, - TesseractTrellisRankingMode ranking_mode) { - if (entries.empty()) { - return; - } - if (beam_width == 0) { - entries.clear(); - return; - } - - std::vector representative_indices; - representative_indices.reserve(entries.size()); - size_t begin = 0; - while (begin < entries.size()) { - uint64_t state = unpack_small_state(entries[begin].key); - size_t best = begin; - double best_score = branch_score(entries[begin], ranking_mode); - size_t end = begin + 1; - while (end < entries.size() && unpack_small_state(entries[end].key) == state) { - double score = branch_score(entries[end], ranking_mode); - if (score > best_score) { - best = end; - best_score = score; - } - ++end; - } - representative_indices.push_back(best); - begin = end; - } - - if (representative_indices.size() > beam_width) { - std::nth_element(representative_indices.begin(), representative_indices.begin() + beam_width, - representative_indices.end(), [&entries, ranking_mode](size_t a, size_t b) { - double sa = branch_score(entries[a], ranking_mode); - double sb = branch_score(entries[b], ranking_mode); - if (sa != sb) { - return sa > sb; - } - return a < b; - }); - representative_indices.resize(beam_width); - } - std::sort(representative_indices.begin(), representative_indices.end()); - - std::vector kept; - kept.reserve(representative_indices.size()); - for (size_t idx : representative_indices) { - 
kept.push_back(entries[idx]); - } - entries = std::move(kept); -} - -void keep_best_state_representatives(std::vector& entries, size_t beam_width, - TesseractTrellisRankingMode ranking_mode) { - if (entries.empty()) { - return; - } - if (beam_width == 0) { - entries.clear(); - return; - } - - std::vector representative_indices; - representative_indices.reserve(entries.size()); - size_t begin = 0; - while (begin < entries.size()) { - size_t best = begin; - double best_score = branch_score(entries[begin], ranking_mode); - size_t end = begin + 1; - while (end < entries.size() && entries[end].state_words == entries[begin].state_words) { - double score = branch_score(entries[end], ranking_mode); - if (score > best_score) { - best = end; - best_score = score; - } - ++end; - } - representative_indices.push_back(best); - begin = end; - } - - if (representative_indices.size() > beam_width) { - std::nth_element(representative_indices.begin(), representative_indices.begin() + beam_width, - representative_indices.end(), [&entries, ranking_mode](size_t a, size_t b) { - double sa = branch_score(entries[a], ranking_mode); - double sb = branch_score(entries[b], ranking_mode); - if (sa != sb) { - return sa > sb; - } - return a < b; - }); - representative_indices.resize(beam_width); - } - std::sort(representative_indices.begin(), representative_indices.end()); - - std::vector kept; - kept.reserve(representative_indices.size()); - for (size_t idx : representative_indices) { - kept.push_back(std::move(entries[idx])); - } - entries = std::move(kept); -} - -void keep_top_branch_entries(std::vector& entries, size_t beam_width, - TesseractTrellisRankingMode ranking_mode) { - if (entries.size() <= beam_width) { - return; - } - std::nth_element(entries.begin(), entries.begin() + beam_width, entries.end(), - [ranking_mode](const PackedMass& a, const PackedMass& b) { - return branch_score(a, ranking_mode) > branch_score(b, ranking_mode); - }); - entries.resize(beam_width); -} - -void 
keep_top_branch_entries(std::vector& entries, size_t beam_width, - TesseractTrellisRankingMode ranking_mode) { - if (entries.size() <= beam_width) { - return; - } - std::nth_element(entries.begin(), entries.begin() + beam_width, entries.end(), - [ranking_mode](const WidePackedMass& a, const WidePackedMass& b) { - return branch_score(a, ranking_mode) > branch_score(b, ranking_mode); - }); - entries.resize(beam_width); -} - void prepare_projected_fault_masks(std::vector* layers) { for (auto& layer : *layers) { layer.projected_fault_mask = 0; @@ -1198,57 +1063,22 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection std::chrono::duration_cast(t1 - t0).count() / 1e6; beam_entries.swap(next_entries); - bool at_checkpoint = ((layer_index + 1) % config.merge_interval == 0) || - (layer_index + 1 == small_layer_templates.size()); - if (!at_checkpoint) { - normalize_items(beam_entries); - max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); - if (beam_entries.empty()) { - low_confidence_flag = true; - return; - } - continue; - } - auto t2a = std::chrono::high_resolution_clock::now(); - if (config.prune_mode != TesseractTrellisPruneMode::NoMerge) { - merge_equal_keys_inplace(beam_entries); - } + merge_equal_keys_inplace(beam_entries); auto t2 = std::chrono::high_resolution_clock::now(); time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; - size_t kept_states = 0; - if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - kept_states = - keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); - } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { - keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); - } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || - config.prune_mode == TesseractTrellisPruneMode::NoMerge) { - keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); - 
} + size_t kept_states = + keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); normalize_items(beam_entries); - const size_t kept_state_sample = - beam_entries.empty() ? 0 - : (config.prune_mode == TesseractTrellisPruneMode::MergedStates - ? kept_states - : beam_entries.size()); - record_kept_state_count(this, kept_state_sample); + record_kept_state_count(this, beam_entries.empty() ? 0 : kept_states); if (beam_entries.empty()) { low_confidence_flag = true; return; } - if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - num_states_merged += kept_states; - max_beam_size_seen = std::max(max_beam_size_seen, kept_states); - } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { - num_states_merged += beam_entries.size(); - max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); - } else { - num_states_merged += beam_entries.size(); - max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); - } + num_states_merged += kept_states; + max_beam_size_seen = std::max(max_beam_size_seen, kept_states); auto t3 = std::chrono::high_resolution_clock::now(); time_truncate_seconds += std::chrono::duration_cast(t3 - t2).count() / 1e6; @@ -1346,57 +1176,22 @@ void TesseractTrellisDecoder::decode_shot(const std::vector& detection std::chrono::duration_cast(t1 - t0).count() / 1e6; beam_entries.swap(next_entries); - bool at_checkpoint = ((layer_index + 1) % config.merge_interval == 0) || - (layer_index + 1 == wide_layer_templates.size()); - if (!at_checkpoint) { - normalize_items(beam_entries); - max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); - if (beam_entries.empty()) { - low_confidence_flag = true; - return; - } - continue; - } - auto t2a = std::chrono::high_resolution_clock::now(); - if (config.prune_mode != TesseractTrellisPruneMode::NoMerge) { - merge_equal_keys_inplace(beam_entries); - } + merge_equal_keys_inplace(beam_entries); auto t2 = 
std::chrono::high_resolution_clock::now(); time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; - size_t kept_states = 0; - if (config.prune_mode == TesseractTrellisPruneMode::MergedStates) { - kept_states = - keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); - } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { - keep_best_state_representatives(beam_entries, config.beam_width, config.ranking_mode); - } else if (config.prune_mode == TesseractTrellisPruneMode::BranchEntries || - config.prune_mode == TesseractTrellisPruneMode::NoMerge) { - keep_top_branch_entries(beam_entries, config.beam_width, config.ranking_mode); - } + size_t kept_states = + keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); normalize_items(beam_entries); - const size_t kept_state_sample = - beam_entries.empty() ? 0 - : (config.prune_mode == TesseractTrellisPruneMode::MergedStates - ? kept_states - : beam_entries.size()); - record_kept_state_count(this, kept_state_sample); + record_kept_state_count(this, beam_entries.empty() ? 
0 : kept_states); if (beam_entries.empty()) { low_confidence_flag = true; return; } - if (config.prune_mode == TesseractTrellisPruneMode::NoMerge) { - num_states_merged += beam_entries.size(); - max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); - } else if (config.prune_mode == TesseractTrellisPruneMode::KeepBest) { - num_states_merged += beam_entries.size(); - max_beam_size_seen = std::max(max_beam_size_seen, beam_entries.size()); - } else { - num_states_merged += kept_states; - max_beam_size_seen = std::max(max_beam_size_seen, kept_states); - } + num_states_merged += kept_states; + max_beam_size_seen = std::max(max_beam_size_seen, kept_states); auto t3 = std::chrono::high_resolution_clock::now(); time_truncate_seconds += std::chrono::duration_cast(t3 - t2).count() / 1e6; diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h index 38bd2e8..233f528 100644 --- a/src/tesseract_trellis.h +++ b/src/tesseract_trellis.h @@ -22,13 +22,6 @@ #include "common.h" #include "stim.h" -enum class TesseractTrellisPruneMode { - MergedStates, - KeepBest, - BranchEntries, - NoMerge, -}; - enum class TesseractTrellisRankingMode { MassOnly, FutureDetcostRanked, @@ -77,10 +70,8 @@ struct TesseractTrellisConfig { stim::DetectorErrorModel dem; size_t beam_width = 1024; double beam_eps = 0.0; - size_t merge_interval = 1; bool verbose = false; bool track_kept_state_stats = false; - TesseractTrellisPruneMode prune_mode = TesseractTrellisPruneMode::MergedStates; TesseractTrellisRankingMode ranking_mode = TesseractTrellisRankingMode::MassOnly; }; diff --git a/src/tesseract_trellis_main.cc b/src/tesseract_trellis_main.cc index 224ee26..28e2386 100644 --- a/src/tesseract_trellis_main.cc +++ b/src/tesseract_trellis_main.cc @@ -27,14 +27,6 @@ namespace { -TesseractTrellisPruneMode parse_prune_mode(const std::string& value) { - if (value == "merged") return TesseractTrellisPruneMode::MergedStates; - if (value == "keep-best") return TesseractTrellisPruneMode::KeepBest; 
- if (value == "branch") return TesseractTrellisPruneMode::BranchEntries; - if (value == "none") return TesseractTrellisPruneMode::NoMerge; - throw std::invalid_argument("Unknown trellis prune mode: " + value); -} - TesseractTrellisRankingMode parse_ranking_mode(const std::string& value) { if (value == "mass") return TesseractTrellisRankingMode::MassOnly; if (value == "future-detcost") return TesseractTrellisRankingMode::FutureDetcostRanked; @@ -68,8 +60,6 @@ struct Args { size_t num_threads = 1; size_t beam_width = 1024; double beam_eps = 0.0; - size_t merge_interval = 1; - std::string prune_mode = "merged"; std::string ranking_mode = "mass"; bool verbose = false; @@ -124,15 +114,7 @@ struct Args { if (!std::isfinite(beam_eps) || beam_eps < 0.0 || beam_eps >= 1.0) { throw std::invalid_argument("--beam-eps must satisfy 0 <= beam-eps < 1."); } - if (merge_interval == 0) { - throw std::invalid_argument("--merge-interval must be at least 1."); - } - parse_prune_mode(prune_mode); parse_ranking_mode(ranking_mode); - if (beam_eps != 0.0 && prune_mode != "merged") { - throw std::invalid_argument( - "--beam-eps is currently only supported with --prune-mode=merged."); - } } void extract(TesseractTrellisConfig& config, std::vector& shots, @@ -166,10 +148,8 @@ struct Args { config.beam_width = beam_width; config.beam_eps = beam_eps; - config.merge_interval = merge_interval; config.verbose = verbose; config.track_kept_state_stats = print_stats; - config.prune_mode = parse_prune_mode(prune_mode); config.ranking_mode = parse_ranking_mode(ranking_mode); if (sample_num_shots > 0) { @@ -296,24 +276,11 @@ int main(int argc, char* argv[]) { program.add_argument("--beam").default_value(size_t(1024)).store_into(args.beam_width); program.add_argument("--beam-eps") .help( - "With --prune-mode=merged, keep at most --beam states and also drop the suffix once " - "the kept prefix has accumulated at least (1 - beam-eps) of the total merged-state " - "mass. 
Use 0 to disable the mass-threshold cutoff.") + "Keep at most --beam merged states and also drop the suffix once the kept prefix has " + "accumulated at least (1 - beam-eps) of the total merged-state mass. Use 0 to disable " + "the mass-threshold cutoff.") .default_value(0.0) .store_into(args.beam_eps); - program.add_argument("--merge-interval").default_value(size_t(1)).store_into(args.merge_interval); - program.add_argument("--prune-mode") - .help( - "Trellis pruning mode: merged, keep-best, branch, or none. " - "merged sums probabilities of all branches with the same residual detection events. " - "keep-best keeps only the single highest-probability branch for each residual detection " - "state. " - "branch ranks branches individually, but still merges exact duplicate (state, " - "observable) entries first. " - "none skips even that exact-duplicate merge, so identical branches may occupy multiple " - "beam slots.") - .default_value(std::string("merged")) - .store_into(args.prune_mode); program.add_argument("--ranking-mode") .help("Trellis ranking mode: mass or future-detcost") .default_value(std::string("mass")) From 87c2f3b0543174603cbe215c065553edfd8dc871 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 19 Apr 2026 12:54:24 -0700 Subject: [PATCH 21/25] update trellis implementation with bit packing and a template --- src/tesseract_trellis.cc | 1159 ++++++++++++++------------------------ src/tesseract_trellis.h | 30 +- 2 files changed, 420 insertions(+), 769 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 5628070..e8dd9ec 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -15,6 +15,7 @@ #include "tesseract_trellis.h" #include +#include #include #include #include @@ -26,12 +27,27 @@ #include #include #include +#include #include #include "utils.h" +struct TesseractTrellisWideKernelBase { + virtual ~TesseractTrellisWideKernelBase() = default; + virtual void decode_shot(TesseractTrellisDecoder* decoder, + 
const std::vector& detections) const = 0; +}; + namespace { +constexpr size_t kMaxCompiledWideStateWords = 4; + +#if defined(__GNUC__) || defined(__clang__) +#define TESSERACT_ALWAYS_INLINE inline __attribute__((always_inline)) +#else +#define TESSERACT_ALWAYS_INLINE inline +#endif + struct Fault { size_t error_index; double likelihood_cost; @@ -41,31 +57,40 @@ struct Fault { std::vector detectors; }; -struct PackedMass { - uint64_t key; - double mass; - double penalty; -}; - -struct SmallStateGroup { - uint64_t state; +struct WideStateGroup { double mass; double score; size_t begin; size_t end; }; -struct WidePackedMass { - std::vector state_words; - uint64_t obs_mask; - double mass; - double penalty; +template +using FixedWideStateWords = std::array; + +template +struct FixedWidePackedMass { + FixedWideStateWords state_words{}; + uint64_t obs_mask = 0; + double mass = 0.0; + double penalty = 0.0; }; -struct WideStateMass { - std::vector state_words; - double mass; - double penalty; +template +struct CompiledWideLayerTemplate { + double q = 0.0; + double p = 0.0; + uint64_t obs_mask = 0; + std::array surviving_masks{}; + std::array projection_dst_words{}; + std::array projection_dst_offsets{}; + std::array projected_fault_mask_words{}; + std::vector fault_detector_indices; + std::vector fault_word_indices; + std::vector fault_bit_masks; + std::vector fault_was_active_before; + std::vector next_local_indices; + std::vector current_costs; + std::vector next_costs; }; struct BranchPenaltyUpdate { @@ -111,94 +136,6 @@ std::vector parse_faults(const std::vector& errors, size_t return faults; } -bool build_small_layer_templates(const std::vector& faults, size_t num_detectors, - std::vector* layers, - size_t* max_frontier_width_seen) { - std::vector last_seen(num_detectors, std::numeric_limits::max()); - for (size_t i = 0; i < faults.size(); ++i) { - for (int d : faults[i].detectors) { - last_seen[d] = i; - } - } - - std::vector active_detectors; - 
active_detectors.reserve(num_detectors); - std::vector global_to_local(num_detectors, -1); - layers->clear(); - layers->reserve(faults.size()); - *max_frontier_width_seen = 0; - - for (size_t i = 0; i < faults.size(); ++i) { - const size_t previous_width = active_detectors.size(); - for (int d : faults[i].detectors) { - if (global_to_local[d] == -1) { - global_to_local[d] = active_detectors.size(); - active_detectors.push_back(d); - } - } - - *max_frontier_width_seen = std::max(*max_frontier_width_seen, active_detectors.size()); - if (*max_frontier_width_seen > 63) { - return false; - } - - TesseractTrellisSmallLayerTemplate layer{ - .q = std::exp(faults[i].log_q), - .p = std::exp(faults[i].log_p), - .obs_flip_bit = faults[i].obs_mask & 1, - .local_det_mask = 0, - .retiring_mask = 0, - .surviving_mask = 0, - .projected_fault_mask = 0, - .previous_width = previous_width, - .surviving_local_indices = {}, - .current_active_detectors = active_detectors, - .next_frontier_costs = {}, - .detcost_transition = {}, - }; - for (int d : faults[i].detectors) { - layer.local_det_mask ^= uint64_t{1} << global_to_local[d]; - } - for (size_t local = 0; local < active_detectors.size(); ++local) { - const int d = active_detectors[local]; - if (last_seen[d] == i) { - layer.retiring_mask ^= uint64_t{1} << local; - } else { - layer.surviving_local_indices.push_back((uint8_t)local); - } - } - uint64_t live_mask = (uint64_t{1} << active_detectors.size()) - 1; - layer.surviving_mask = live_mask & ~layer.retiring_mask; - - std::vector next_active; - next_active.reserve(layer.surviving_local_indices.size()); - std::fill(global_to_local.begin(), global_to_local.end(), -1); - for (size_t next_local = 0; next_local < layer.surviving_local_indices.size(); ++next_local) { - int d = active_detectors[layer.surviving_local_indices[next_local]]; - global_to_local[d] = next_local; - next_active.push_back(d); - } - active_detectors = std::move(next_active); - layers->push_back(std::move(layer)); - } - 
- return true; -} - -void build_small_detector_layer_refs( - const std::vector& layers, size_t num_detectors, - std::vector>* refs) { - refs->assign(num_detectors, {}); - for (size_t layer_index = 0; layer_index < layers.size(); ++layer_index) { - const auto& layer = layers[layer_index]; - for (size_t local = 0; local < layer.current_active_detectors.size(); ++local) { - int detector = layer.current_active_detectors[local]; - (*refs)[(size_t)detector].push_back( - {static_cast(layer_index), static_cast(local)}); - } - } -} - void build_wide_layer_templates(const std::vector& faults, size_t num_detectors, std::vector* layers, size_t* max_frontier_width_seen) { @@ -318,7 +255,7 @@ size_t num_state_words(size_t num_bits) { return (num_bits + 63) >> 6; } -uint64_t compact_bits_u64(uint64_t value, uint64_t mask) { +TESSERACT_ALWAYS_INLINE uint64_t compact_bits_u64(uint64_t value, uint64_t mask) { #if defined(__BMI2__) && \ (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) return _pext_u64(value, mask); @@ -337,81 +274,6 @@ uint64_t compact_bits_u64(uint64_t value, uint64_t mask) { #endif } -bool get_state_bit(const std::vector& state_words, size_t bit, size_t logical_width) { - if (bit >= logical_width) { - return false; - } - size_t word = bit >> 6; - if (word >= state_words.size()) { - return false; - } - return (state_words[word] >> (bit & 63)) & 1ULL; -} - -void xor_state_words(std::vector& state_words, const std::vector& mask_words) { - if (state_words.size() < mask_words.size()) { - state_words.resize(mask_words.size(), 0); - } - for (size_t k = 0; k < mask_words.size(); ++k) { - state_words[k] ^= mask_words[k]; - } -} - -std::vector project_wide_state(const std::vector& state_words, - size_t logical_width, - const std::vector& surviving_local_indices) { - std::vector out(num_state_words(surviving_local_indices.size()), 0); - for (size_t next_local = 0; next_local < surviving_local_indices.size(); ++next_local) { - size_t 
current_local = (size_t)surviving_local_indices[next_local]; - if (get_state_bit(state_words, current_local, logical_width)) { - out[next_local >> 6] ^= uint64_t{1} << (next_local & 63); - } - } - return out; -} - -bool wide_state_less(const std::vector& a, const std::vector& b) { - if (a.size() != b.size()) { - return a.size() < b.size(); - } - for (size_t k = a.size(); k-- > 0;) { - if (a[k] != b[k]) { - return a[k] < b[k]; - } - } - return false; -} - -bool wide_state_zero(const std::vector& state_words) { - for (uint64_t word : state_words) { - if (word != 0) { - return false; - } - } - return true; -} - -uint64_t project_small_state(uint64_t state, uint64_t surviving_mask) { - return compact_bits_u64(state, surviving_mask); -} - -double compute_initial_penalty_for_target_bits(uint64_t target_bits, - const std::vector& active_detectors, - const std::vector& initial_future_detcost) { - double total = 0.0; - while (target_bits) { - uint64_t low_bit = target_bits & -target_bits; - size_t local = (size_t)std::countr_zero(low_bit); - target_bits ^= low_bit; - double best = initial_future_detcost[(size_t)active_detectors[local]]; - if (best == INF) { - return INF; - } - total += best; - } - return total; -} - double compute_initial_penalty_for_active_detectors( const std::vector& active_detectors, const boost::dynamic_bitset<>& actual_dets, const std::vector& initial_future_detcost) { @@ -429,226 +291,6 @@ double compute_initial_penalty_for_active_detectors( return total; } -BranchPenaltyUpdate compute_small_branch_update(uint64_t base_state, size_t previous_width, - double current_penalty, - uint64_t current_target_bits, - const TesseractTrellisDetcostTransition& transition, - bool compute_penalties) { - BranchPenaltyUpdate update; - update.absent_penalty = compute_penalties ? current_penalty : 0.0; - update.present_penalty = compute_penalties ? 
current_penalty : 0.0; - - for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) { - size_t local = transition.fault_local_indices[k]; - bool state_bit = local < previous_width && ((base_state >> local) & 1ULL); - bool target_bit = (current_target_bits >> local) & 1ULL; - bool mismatch = state_bit ^ target_bit; - int32_t next_local = transition.next_local_indices[k]; - - if (next_local < 0) { - if (mismatch) { - update.absent_valid = false; - } else { - update.present_valid = false; - } - } - - if (!compute_penalties) { - continue; - } - - double prev_contrib = (local < previous_width && mismatch) ? transition.current_costs[k] : 0.0; - double absent_contrib = (next_local >= 0 && mismatch) ? transition.next_costs[k] : 0.0; - double present_contrib = (next_local >= 0 && !mismatch) ? transition.next_costs[k] : 0.0; - update.absent_penalty += absent_contrib - prev_contrib; - update.present_penalty += present_contrib - prev_contrib; - } - - return update; -} - -BranchPenaltyUpdate compute_wide_branch_update(const std::vector& base_state_words, - size_t previous_width, double current_penalty, - const std::vector& current_active_detectors, - const boost::dynamic_bitset<>& actual_dets, - const TesseractTrellisDetcostTransition& transition, - bool compute_penalties) { - BranchPenaltyUpdate update; - update.absent_penalty = compute_penalties ? current_penalty : 0.0; - update.present_penalty = compute_penalties ? 
current_penalty : 0.0; - - for (size_t k = 0; k < transition.fault_local_indices.size(); ++k) { - size_t local = transition.fault_local_indices[k]; - bool state_bit = get_state_bit(base_state_words, local, previous_width); - bool target_bit = actual_dets[(size_t)current_active_detectors[local]]; - bool mismatch = state_bit ^ target_bit; - int32_t next_local = transition.next_local_indices[k]; - - if (next_local < 0) { - if (mismatch) { - update.absent_valid = false; - } else { - update.present_valid = false; - } - } - - if (!compute_penalties) { - continue; - } - - double prev_contrib = (local < previous_width && mismatch) ? transition.current_costs[k] : 0.0; - double absent_contrib = (next_local >= 0 && mismatch) ? transition.next_costs[k] : 0.0; - double present_contrib = (next_local >= 0 && !mismatch) ? transition.next_costs[k] : 0.0; - update.absent_penalty += absent_contrib - prev_contrib; - update.present_penalty += present_contrib - prev_contrib; - } - - return update; -} - -uint64_t pack_small_key(uint64_t state, uint64_t obs_flip_bit) { - return (state << 1) | (obs_flip_bit & 1ULL); -} - -uint64_t unpack_small_state(uint64_t packed_key) { - return packed_key >> 1; -} - -uint64_t unpack_small_obs(uint64_t packed_key) { - return packed_key & 1ULL; -} - -void normalize_items(std::vector& items) { - double total_mass = 0.0; - for (const auto& item : items) { - total_mass += item.mass; - } - if (total_mass == 0.0) { - items.clear(); - return; - } - double inv = 1.0 / total_mass; - for (auto& item : items) { - item.mass *= inv; - } -} - -void normalize_items(std::vector& items) { - double total_mass = 0.0; - for (const auto& item : items) { - total_mass += item.mass; - } - if (total_mass == 0.0) { - items.clear(); - return; - } - double inv = 1.0 / total_mass; - for (auto& item : items) { - item.mass *= inv; - } -} - -void radix_sort_packed_masses_by_key(std::vector& items) { - if (items.size() <= 1) { - return; - } - - thread_local std::vector buffer; - 
buffer.resize(items.size()); - - PackedMass* src = items.data(); - PackedMass* dst = buffer.data(); - constexpr size_t RADIX = 256; - std::array counts; - - for (size_t shift = 0; shift < 64; shift += 8) { - counts.fill(0); - for (size_t k = 0; k < items.size(); ++k) { - ++counts[(src[k].key >> shift) & 0xFF]; - } - - size_t total = 0; - for (size_t k = 0; k < RADIX; ++k) { - size_t count = counts[k]; - counts[k] = total; - total += count; - } - - for (size_t k = 0; k < items.size(); ++k) { - dst[counts[(src[k].key >> shift) & 0xFF]++] = src[k]; - } - std::swap(src, dst); - } -} - -void merge_equal_keys_inplace(std::vector& items) { - if (items.empty()) { - return; - } - radix_sort_packed_masses_by_key(items); - size_t out = 0; - for (size_t i = 1; i < items.size(); ++i) { - if (items[i].key == items[out].key) { - items[out].mass += items[i].mass; - } else { - ++out; - if (out != i) { - items[out] = std::move(items[i]); - } - } - } - items.resize(out + 1); -} - -void merge_equal_keys_inplace(std::vector& items) { - if (items.empty()) { - return; - } - std::sort(items.begin(), items.end(), [](const WidePackedMass& a, const WidePackedMass& b) { - if (wide_state_less(a.state_words, b.state_words)) { - return true; - } - if (wide_state_less(b.state_words, a.state_words)) { - return false; - } - return a.obs_mask < b.obs_mask; - }); - - size_t out = 0; - for (size_t i = 1; i < items.size(); ++i) { - if (items[i].obs_mask == items[out].obs_mask && - items[i].state_words == items[out].state_words) { - items[out].mass += items[i].mass; - } else { - ++out; - if (out != i) { - items[out] = std::move(items[i]); - } - } - } - items.resize(out + 1); -} - -std::vector accumulate_state_masses_from_entries( - const std::vector& entries) { - std::vector totals; - if (entries.empty()) { - return totals; - } - totals.reserve(entries.size()); - WideStateMass current{entries[0].state_words, entries[0].mass, entries[0].penalty}; - for (size_t i = 1; i < entries.size(); ++i) { - if 
(entries[i].state_words == current.state_words) { - current.mass += entries[i].mass; - } else { - totals.push_back(std::move(current)); - current = {entries[i].state_words, entries[i].mass, entries[i].penalty}; - } - } - totals.push_back(std::move(current)); - return totals; -} - double score_mass_and_penalty(double mass, double penalty, TesseractTrellisRankingMode ranking_mode) { if (ranking_mode == TesseractTrellisRankingMode::MassOnly) { @@ -728,15 +370,167 @@ FinalizeKeptStateStatsOnExit::~FinalizeKeptStateStatsOnExit() { finalize_kept_state_stats(decoder); } -bool small_state_group_score_greater(const SmallStateGroup& a, const SmallStateGroup& b) { +void prepare_projected_fault_masks(std::vector* layers) { + for (auto& layer : *layers) { + layer.projected_fault_mask_words.assign(num_state_words(layer.surviving_local_indices.size()), + 0); + for (int32_t next_local : layer.detcost_transition.next_local_indices) { + if (next_local >= 0) { + size_t local = (size_t)next_local; + layer.projected_fault_mask_words[local >> 6] ^= uint64_t{1} << (local & 63); + } + } + } +} + +template +bool fixed_wide_state_less(const FixedWideStateWords& a, const FixedWideStateWords& b) { + for (size_t k = Words; k-- > 0;) { + if (a[k] != b[k]) { + return a[k] < b[k]; + } + } + return false; +} + +template +bool fixed_wide_state_zero(const FixedWideStateWords& state_words) { + for (size_t k = 0; k < Words; ++k) { + if (state_words[k] != 0) { + return false; + } + } + return true; +} + +template +void xor_compiled_wide_state(FixedWideStateWords* state_words, + const std::array& mask_words) { + for (size_t k = 0; k < Words; ++k) { + (*state_words)[k] ^= mask_words[k]; + } +} + +template +FixedWideStateWords project_compiled_wide_state( + const FixedWideStateWords& state_words, const CompiledWideLayerTemplate& layer) { + FixedWideStateWords out{}; + for (size_t src_word = 0; src_word < Words; ++src_word) { + const uint64_t mask = layer.surviving_masks[src_word]; + if (mask == 0) { + 
continue; + } + const uint64_t packed = compact_bits_u64(state_words[src_word], mask); + const size_t dst_word = layer.projection_dst_words[src_word]; + const uint8_t shift = layer.projection_dst_offsets[src_word]; + out[dst_word] |= packed << shift; + if constexpr (Words > 1) { + if (shift != 0 && dst_word + 1 < Words) { + out[dst_word + 1] |= packed >> (64 - shift); + } + } + } + return out; +} + +template +BranchPenaltyUpdate compute_compiled_wide_branch_update( + const FixedWideStateWords& base_state_words, double current_penalty, + const boost::dynamic_bitset<>& actual_dets, const CompiledWideLayerTemplate& layer, + bool compute_penalties) { + BranchPenaltyUpdate update; + update.absent_penalty = compute_penalties ? current_penalty : 0.0; + update.present_penalty = compute_penalties ? current_penalty : 0.0; + + for (size_t k = 0; k < layer.fault_detector_indices.size(); ++k) { + const bool state_bit = + layer.fault_was_active_before[k] && + ((base_state_words[layer.fault_word_indices[k]] & layer.fault_bit_masks[k]) != 0); + const bool target_bit = actual_dets[layer.fault_detector_indices[k]]; + const bool mismatch = state_bit ^ target_bit; + const int32_t next_local = layer.next_local_indices[k]; + + if (next_local < 0) { + if (mismatch) { + update.absent_valid = false; + } else { + update.present_valid = false; + } + } + + if (!compute_penalties) { + continue; + } + + const double prev_contrib = + (layer.fault_was_active_before[k] && mismatch) ? layer.current_costs[k] : 0.0; + const double absent_contrib = (next_local >= 0 && mismatch) ? layer.next_costs[k] : 0.0; + const double present_contrib = (next_local >= 0 && !mismatch) ? 
layer.next_costs[k] : 0.0; + update.absent_penalty += absent_contrib - prev_contrib; + update.present_penalty += present_contrib - prev_contrib; + } + + return update; +} + +template +void normalize_compiled_items(std::vector>* items) { + double total = 0.0; + for (const auto& item : *items) { + total += item.mass; + } + if (total == 0.0) { + return; + } + for (auto& item : *items) { + item.mass /= total; + } +} + +template +void merge_equal_compiled_keys_inplace(std::vector>* items) { + if (items->empty()) { + return; + } + std::sort(items->begin(), items->end(), + [](const FixedWidePackedMass& a, const FixedWidePackedMass& b) { + if (fixed_wide_state_less(a.state_words, b.state_words)) { + return true; + } + if (fixed_wide_state_less(b.state_words, a.state_words)) { + return false; + } + return a.obs_mask < b.obs_mask; + }); + + size_t out = 0; + for (size_t i = 1; i < items->size(); ++i) { + if ((*items)[i].obs_mask == (*items)[out].obs_mask && + (*items)[i].state_words == (*items)[out].state_words) { + (*items)[out].mass += (*items)[i].mass; + } else { + ++out; + if (out != i) { + (*items)[out] = std::move((*items)[i]); + } + } + } + items->resize(out + 1); +} + +template +bool compiled_wide_state_group_score_greater(const std::vector>& entries, + const WideStateGroup& a, const WideStateGroup& b) { if (a.score != b.score) { return a.score > b.score; } - return a.state < b.state; + return fixed_wide_state_less(entries[a.begin].state_words, entries[b.begin].state_words); } -size_t trim_small_state_groups_by_beam_and_mass(std::vector* groups, - size_t beam_width, double beam_eps) { +template +size_t trim_compiled_wide_state_groups_by_beam_and_mass( + const std::vector>& entries, std::vector* groups, + size_t beam_width, double beam_eps) { if (groups->empty()) { return 0; } @@ -749,9 +543,10 @@ size_t trim_small_state_groups_by_beam_and_mass(std::vector* gr } if (groups->size() > beam_width) { - std::nth_element( - groups->begin(), groups->begin() + beam_width, 
groups->end(), - [](const SmallStateGroup& a, const SmallStateGroup& b) { return a.score > b.score; }); + std::nth_element(groups->begin(), groups->begin() + beam_width, groups->end(), + [&entries](const WideStateGroup& a, const WideStateGroup& b) { + return compiled_wide_state_group_score_greater(entries, a, b); + }); groups->resize(beam_width); } else if (beam_eps <= 0.0) { return groups->size(); @@ -761,7 +556,10 @@ size_t trim_small_state_groups_by_beam_and_mass(std::vector* gr return groups->size(); } - std::sort(groups->begin(), groups->end(), small_state_group_score_greater); + std::sort(groups->begin(), groups->end(), + [&entries](const WideStateGroup& a, const WideStateGroup& b) { + return compiled_wide_state_group_score_greater(entries, a, b); + }); const double retained_target_mass = total_mass * (1.0 - beam_eps); double retained_mass = 0.0; size_t keep_count = 0; @@ -774,48 +572,46 @@ size_t trim_small_state_groups_by_beam_and_mass(std::vector* gr } groups->resize(keep_count); std::sort(groups->begin(), groups->end(), - [](const SmallStateGroup& a, const SmallStateGroup& b) { return a.begin < b.begin; }); + [](const WideStateGroup& a, const WideStateGroup& b) { return a.begin < b.begin; }); return groups->size(); } -std::vector collect_small_state_groups(const std::vector& entries, - TesseractTrellisRankingMode ranking_mode) { - std::vector groups; +template +std::vector collect_compiled_wide_state_groups( + const std::vector>& entries, + TesseractTrellisRankingMode ranking_mode) { + std::vector groups; if (entries.empty()) { return groups; } groups.reserve(entries.size()); size_t begin = 0; while (begin < entries.size()) { - uint64_t state = unpack_small_state(entries[begin].key); double mass = 0.0; size_t end = begin; - while (end < entries.size() && unpack_small_state(entries[end].key) == state) { + while (end < entries.size() && entries[end].state_words == entries[begin].state_words) { mass += entries[end].mass; ++end; } - groups.push_back({state, 
mass, - score_mass_and_penalty(mass, entries[begin].penalty, ranking_mode), begin, - end}); + groups.push_back( + {mass, score_mass_and_penalty(mass, entries[begin].penalty, ranking_mode), begin, end}); begin = end; } return groups; } -double state_score(const WideStateMass& item, TesseractTrellisRankingMode ranking_mode) { - return score_mass_and_penalty(item.mass, item.penalty, ranking_mode); -} - -size_t keep_top_states(std::vector& entries, size_t beam_width, double beam_eps, - TesseractTrellisRankingMode ranking_mode) { - if (entries.empty()) { +template +size_t keep_top_compiled_states(std::vector>* entries, + size_t beam_width, double beam_eps, + TesseractTrellisRankingMode ranking_mode) { + if (entries->empty()) { return 0; } - auto groups = collect_small_state_groups(entries, ranking_mode); + auto groups = collect_compiled_wide_state_groups(*entries, ranking_mode); const size_t kept_group_count = - trim_small_state_groups_by_beam_and_mass(&groups, beam_width, beam_eps); + trim_compiled_wide_state_groups_by_beam_and_mass(*entries, &groups, beam_width, beam_eps); - std::vector kept; + std::vector> kept; size_t kept_entries = 0; for (const auto& group : groups) { kept_entries += group.end - group.begin; @@ -823,401 +619,276 @@ size_t keep_top_states(std::vector& entries, size_t beam_width, doub kept.reserve(kept_entries); for (const auto& group : groups) { for (size_t k = group.begin; k < group.end; ++k) { - kept.push_back(std::move(entries[k])); + kept.push_back(std::move((*entries)[k])); } } - entries = std::move(kept); + *entries = std::move(kept); return kept_group_count; } -size_t keep_top_states(std::vector& entries, size_t beam_width, double beam_eps, - TesseractTrellisRankingMode ranking_mode) { - if (entries.empty()) { - return 0; - } - auto totals = accumulate_state_masses_from_entries(entries); - double total_mass = 0.0; - if (beam_eps > 0.0) { - for (const auto& item : totals) { - total_mass += item.mass; - } - } - - if (totals.size() > 
beam_width) { - std::nth_element(totals.begin(), totals.begin() + beam_width, totals.end(), - [ranking_mode](const WideStateMass& a, const WideStateMass& b) { - return state_score(a, ranking_mode) > state_score(b, ranking_mode); - }); - totals.resize(beam_width); - } else if (beam_eps <= 0.0) { - return totals.size(); - } - - if (beam_eps > 0.0 && total_mass > 0.0) { - std::sort(totals.begin(), totals.end(), - [ranking_mode](const WideStateMass& a, const WideStateMass& b) { - double sa = state_score(a, ranking_mode); - double sb = state_score(b, ranking_mode); - if (sa != sb) { - return sa > sb; - } - return wide_state_less(a.state_words, b.state_words); - }); - const double retained_target_mass = total_mass * (1.0 - beam_eps); - double retained_mass = 0.0; - size_t keep_count = 0; - while (keep_count < totals.size()) { - retained_mass += totals[keep_count].mass; - ++keep_count; - if (retained_mass >= retained_target_mass) { - break; - } +template +std::vector> compile_wide_layers( + const std::vector& layers) { + std::vector> compiled_layers; + compiled_layers.reserve(layers.size()); + for (const auto& layer : layers) { + if (num_state_words(layer.current_active_detectors.size()) > Words || + layer.projected_fault_mask_words.size() > Words) { + throw std::invalid_argument("Compiled wide kernel word count is smaller than the frontier."); } - totals.resize(keep_count); - } - std::sort(totals.begin(), totals.end(), [](const WideStateMass& a, const WideStateMass& b) { - return wide_state_less(a.state_words, b.state_words); - }); + CompiledWideLayerTemplate compiled; + compiled.q = layer.q; + compiled.p = layer.p; + compiled.obs_mask = layer.obs_mask; - std::vector kept; - kept.reserve(entries.size()); - size_t ti = 0; - for (auto& item : entries) { - while (ti < totals.size() && wide_state_less(totals[ti].state_words, item.state_words)) { - ++ti; + std::array surviving_masks{}; + for (uint32_t current_local : layer.surviving_local_indices) { + 
surviving_masks[current_local >> 6] |= uint64_t{1} << (current_local & 63); } - if (ti < totals.size() && item.state_words == totals[ti].state_words) { - kept.push_back(std::move(item)); + size_t next_offset = 0; + for (size_t src_word = 0; src_word < Words; ++src_word) { + compiled.surviving_masks[src_word] = surviving_masks[src_word]; + compiled.projection_dst_words[src_word] = static_cast(next_offset >> 6); + compiled.projection_dst_offsets[src_word] = static_cast(next_offset & 63); + next_offset += std::popcount(surviving_masks[src_word]); } - } - entries = std::move(kept); - return totals.size(); -} -void prepare_projected_fault_masks(std::vector* layers) { - for (auto& layer : *layers) { - layer.projected_fault_mask = 0; - for (int32_t next_local : layer.detcost_transition.next_local_indices) { - if (next_local >= 0) { - layer.projected_fault_mask ^= uint64_t{1} << next_local; - } - } - } -} - -void prepare_projected_fault_masks(std::vector* layers) { - for (auto& layer : *layers) { - layer.projected_fault_mask_words.assign(num_state_words(layer.surviving_local_indices.size()), - 0); - for (int32_t next_local : layer.detcost_transition.next_local_indices) { - if (next_local >= 0) { - size_t local = (size_t)next_local; - layer.projected_fault_mask_words[local >> 6] ^= uint64_t{1} << (local & 63); - } + for (size_t k = 0; k < layer.projected_fault_mask_words.size(); ++k) { + compiled.projected_fault_mask_words[k] = layer.projected_fault_mask_words[k]; } - } -} - -} // namespace -TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) - : config(std::move(config_)) { - std::vector dem_error_map(config.dem.flattened().count_errors()); - std::iota(dem_error_map.begin(), dem_error_map.end(), 0); - dem_error_to_error = std::move(dem_error_map); - error_to_dem_error = common::invert_error_map(dem_error_to_error, config.dem.count_errors()); - errors = get_errors_from_dem(config.dem.flattened()); - num_detectors = config.dem.count_detectors(); 
- num_observables = config.dem.count_observables(); - - all_possible_detectors = boost::dynamic_bitset<>(num_detectors); - for (const auto& error : errors) { - for (int d : error.symptom.detectors) { - all_possible_detectors[(size_t)d] = true; + const auto& transition = layer.detcost_transition; + compiled.fault_detector_indices.reserve(transition.fault_local_indices.size()); + compiled.fault_word_indices.reserve(transition.fault_local_indices.size()); + compiled.fault_bit_masks.reserve(transition.fault_local_indices.size()); + compiled.fault_was_active_before.reserve(transition.fault_local_indices.size()); + compiled.next_local_indices = transition.next_local_indices; + compiled.current_costs = transition.current_costs; + compiled.next_costs = transition.next_costs; + for (uint32_t local : transition.fault_local_indices) { + compiled.fault_detector_indices.push_back((uint32_t)layer.current_active_detectors[local]); + compiled.fault_word_indices.push_back(static_cast(local >> 6)); + compiled.fault_bit_masks.push_back(uint64_t{1} << (local & 63)); + compiled.fault_was_active_before.push_back(local < layer.previous_width); } - } - - auto faults = parse_faults(errors, num_observables); - - size_t wide_frontier_width = 0; - build_wide_layer_templates(faults, num_detectors, &wide_layer_templates, &wide_frontier_width); - build_future_detcost_transitions(faults, num_detectors, &wide_layer_templates, - &initial_future_detcost); - prepare_projected_fault_masks(&wide_layer_templates); - size_t small_frontier_width = 0; - has_small_layer_templates = - num_observables <= 1 && - build_small_layer_templates(faults, num_detectors, &small_layer_templates, - &small_frontier_width); - if (has_small_layer_templates) { - build_future_detcost_transitions(faults, num_detectors, &small_layer_templates, nullptr); - prepare_projected_fault_masks(&small_layer_templates); - build_small_detector_layer_refs(small_layer_templates, num_detectors, - &small_detector_layer_refs); - 
scratch_small_current_target_bits.assign(small_layer_templates.size(), 0); - scratch_small_expected_retiring_bits.assign(small_layer_templates.size(), 0); + compiled_layers.push_back(std::move(compiled)); } + return compiled_layers; } -void TesseractTrellisDecoder::decode_shot(const std::vector& detections) { - low_confidence_flag = false; - num_states_expanded = 0; - num_states_merged = 0; - max_beam_size_seen = 0; - max_frontier_width_seen = 0; - reset_kept_state_stats(this); - time_expand_seconds = 0; - time_collapse_seconds = 0; - time_truncate_seconds = 0; - time_reconstruct_seconds = 0; - predicted_obs_mask = 0; - total_mass_obs0 = 0; - total_mass_obs1 = 0; - FinalizeKeptStateStatsOnExit kept_state_stats_guard{this}; - - if (has_small_layer_templates) { - std::fill(scratch_small_current_target_bits.begin(), scratch_small_current_target_bits.end(), - 0); - std::fill(scratch_small_expected_retiring_bits.begin(), - scratch_small_expected_retiring_bits.end(), 0); - - for (uint64_t d : detections) { - if (d >= num_detectors || !all_possible_detectors[d]) { - low_confidence_flag = true; - return; - } - for (const auto& ref : small_detector_layer_refs[(size_t)d]) { - scratch_small_current_target_bits[ref.layer_index] ^= uint64_t{1} << ref.local_index; - } - } - - for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { - const auto& layer = small_layer_templates[layer_index]; - max_frontier_width_seen = - std::max(max_frontier_width_seen, layer.current_active_detectors.size()); - scratch_small_expected_retiring_bits[layer_index] = - scratch_small_current_target_bits[layer_index] & layer.retiring_mask; - } - - double initial_penalty = 0.0; - if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && - !small_layer_templates.empty()) { - initial_penalty = compute_initial_penalty_for_target_bits( - scratch_small_current_target_bits.front(), - small_layer_templates.front().current_active_detectors, 
initial_future_detcost); - } - - std::vector beam_entries; - std::vector next_entries; - beam_entries.reserve(config.beam_width * 2 + 2); - next_entries.reserve(config.beam_width * 4 + 4); - beam_entries.push_back({pack_small_key(0, 0), 1.0, initial_penalty}); - max_beam_size_seen = 1; - - const bool compute_penalties = - config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked; - for (size_t layer_index = 0; layer_index < small_layer_templates.size(); ++layer_index) { - const auto& layer = small_layer_templates[layer_index]; - const uint64_t current_target_bits = scratch_small_current_target_bits[layer_index]; - const uint64_t expected_retiring_bits = scratch_small_expected_retiring_bits[layer_index]; - - auto t0 = std::chrono::high_resolution_clock::now(); - next_entries.clear(); - next_entries.reserve(beam_entries.size() * 2); - for (const auto& item : beam_entries) { - ++num_states_expanded; - const uint64_t base_state = unpack_small_state(item.key); - const uint64_t base_obs = unpack_small_obs(item.key); - - BranchPenaltyUpdate update; - if (compute_penalties) { - update = compute_small_branch_update(base_state, layer.previous_width, item.penalty, - current_target_bits, layer.detcost_transition, true); - } else { - update.absent_valid = - (((base_state ^ expected_retiring_bits) & layer.retiring_mask) == 0); - uint64_t toggled_state = base_state ^ layer.local_det_mask; - update.present_valid = - (((toggled_state ^ expected_retiring_bits) & layer.retiring_mask) == 0); - } - - if (!update.absent_valid && !update.present_valid) { - continue; - } +template +struct CompiledWideKernel final : TesseractTrellisWideKernelBase { + explicit CompiledWideKernel(std::vector> layers_, + std::vector initial_active_detectors_, + size_t max_frontier_width_) + : layers(std::move(layers_)), + initial_active_detectors(std::move(initial_active_detectors_)), + max_frontier_width(max_frontier_width_) {} - uint64_t projected_state = project_small_state(base_state, 
layer.surviving_mask); - if (update.absent_valid && layer.q != 0.0) { - next_entries.push_back({pack_small_key(projected_state, base_obs), item.mass * layer.q, - update.absent_penalty}); - } - if (update.present_valid && layer.p != 0.0) { - next_entries.push_back({pack_small_key(projected_state ^ layer.projected_fault_mask, - base_obs ^ layer.obs_flip_bit), - item.mass * layer.p, update.present_penalty}); - } - } - auto t1 = std::chrono::high_resolution_clock::now(); - time_expand_seconds += - std::chrono::duration_cast(t1 - t0).count() / 1e6; - - beam_entries.swap(next_entries); - auto t2a = std::chrono::high_resolution_clock::now(); - merge_equal_keys_inplace(beam_entries); - auto t2 = std::chrono::high_resolution_clock::now(); - time_collapse_seconds += - std::chrono::duration_cast(t2 - t2a).count() / 1e6; - - size_t kept_states = - keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); - normalize_items(beam_entries); - record_kept_state_count(this, beam_entries.empty() ? 0 : kept_states); - if (beam_entries.empty()) { - low_confidence_flag = true; - return; - } - num_states_merged += kept_states; - max_beam_size_seen = std::max(max_beam_size_seen, kept_states); - auto t3 = std::chrono::high_resolution_clock::now(); - time_truncate_seconds += - std::chrono::duration_cast(t3 - t2).count() / 1e6; - } - - auto tr0 = std::chrono::high_resolution_clock::now(); - for (const auto& [packed_key, mass, penalty] : beam_entries) { - (void)penalty; - if (unpack_small_state(packed_key) != 0) { - continue; - } - if (unpack_small_obs(packed_key) == 0) { - total_mass_obs0 += mass; - } else { - total_mass_obs1 += mass; - } - } - if (total_mass_obs0 == 0.0 && total_mass_obs1 == 0.0) { - low_confidence_flag = true; - return; - } - predicted_obs_mask = total_mass_obs1 > total_mass_obs0 ? 
1 : 0; - auto tr1 = std::chrono::high_resolution_clock::now(); - time_reconstruct_seconds += - std::chrono::duration_cast(tr1 - tr0).count() / 1e6; - } else { - boost::dynamic_bitset<> actual_dets(num_detectors); + void decode_shot(TesseractTrellisDecoder* decoder, + const std::vector& detections) const override { + boost::dynamic_bitset<> actual_dets(decoder->num_detectors); for (uint64_t d : detections) { - if (d >= num_detectors || !all_possible_detectors[d]) { - low_confidence_flag = true; + if (d >= decoder->num_detectors || !decoder->all_possible_detectors[d]) { + decoder->low_confidence_flag = true; return; } actual_dets.flip((size_t)d); } - max_frontier_width_seen = 0; - for (const auto& layer : wide_layer_templates) { - max_frontier_width_seen = - std::max(max_frontier_width_seen, layer.current_active_detectors.size()); - } + + decoder->max_frontier_width_seen = max_frontier_width; double initial_penalty = 0.0; - if (config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && - !wide_layer_templates.empty()) { + if (decoder->config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && + !layers.empty()) { initial_penalty = compute_initial_penalty_for_active_detectors( - wide_layer_templates.front().current_active_detectors, actual_dets, - initial_future_detcost); + initial_active_detectors, actual_dets, decoder->initial_future_detcost); } - std::vector beam_entries; - std::vector next_entries; - beam_entries.reserve(config.beam_width * 2 + 2); - next_entries.reserve(config.beam_width * 4 + 4); + std::vector> beam_entries; + std::vector> next_entries; + beam_entries.reserve(decoder->config.beam_width * 2 + 2); + next_entries.reserve(decoder->config.beam_width * 4 + 4); beam_entries.push_back({{}, 0, 1.0, initial_penalty}); - max_beam_size_seen = 1; + decoder->max_beam_size_seen = 1; - for (size_t layer_index = 0; layer_index < wide_layer_templates.size(); ++layer_index) { - const auto& layer = wide_layer_templates[layer_index]; 
- const bool compute_penalties = - config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked; + const bool compute_penalties = + decoder->config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked; + for (size_t layer_index = 0; layer_index < layers.size(); ++layer_index) { + const auto& layer = layers[layer_index]; auto t0 = std::chrono::high_resolution_clock::now(); next_entries.clear(); next_entries.reserve(beam_entries.size() * 2); - if (config.verbose) { - std::cout << "expanding layer " << layer_index << " / " << (wide_layer_templates.size() - 1) + if (decoder->config.verbose) { + std::cout << "expanding layer " << layer_index << " / " << (layers.size() - 1) << std::endl; std::cout << "states to expand = " << beam_entries.size() << std::endl; } for (const auto& item : beam_entries) { - ++num_states_expanded; - BranchPenaltyUpdate update = compute_wide_branch_update( - item.state_words, layer.previous_width, item.penalty, layer.current_active_detectors, - actual_dets, layer.detcost_transition, compute_penalties); + ++decoder->num_states_expanded; + BranchPenaltyUpdate update = compute_compiled_wide_branch_update( + item.state_words, item.penalty, actual_dets, layer, compute_penalties); if (!update.absent_valid && !update.present_valid) { continue; } - std::vector projected_state = project_wide_state( - item.state_words, layer.previous_width, layer.surviving_local_indices); - if (update.absent_valid && layer.q != 0.0) { + FixedWideStateWords projected_state = + project_compiled_wide_state(item.state_words, layer); + const bool keep_absent = update.absent_valid && layer.q != 0.0; + const bool keep_present = update.present_valid && layer.p != 0.0; + if (keep_absent && keep_present) { + FixedWideStateWords projected_toggled = projected_state; + xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); next_entries.push_back( - {projected_state, item.obs_mask, item.mass * layer.q, update.absent_penalty}); - } - if 
(update.present_valid && layer.p != 0.0) { - std::vector projected_toggled = projected_state; - xor_state_words(projected_toggled, layer.projected_fault_mask_words); + {std::move(projected_state), item.obs_mask, item.mass * layer.q, update.absent_penalty}); next_entries.push_back({std::move(projected_toggled), item.obs_mask ^ layer.obs_mask, item.mass * layer.p, update.present_penalty}); + } else if (keep_absent) { + next_entries.push_back( + {std::move(projected_state), item.obs_mask, item.mass * layer.q, update.absent_penalty}); + } else if (keep_present) { + xor_compiled_wide_state(&projected_state, layer.projected_fault_mask_words); + next_entries.push_back({std::move(projected_state), item.obs_mask ^ layer.obs_mask, + item.mass * layer.p, update.present_penalty}); } } auto t1 = std::chrono::high_resolution_clock::now(); - time_expand_seconds += + decoder->time_expand_seconds += std::chrono::duration_cast(t1 - t0).count() / 1e6; beam_entries.swap(next_entries); auto t2a = std::chrono::high_resolution_clock::now(); - merge_equal_keys_inplace(beam_entries); + merge_equal_compiled_keys_inplace(&beam_entries); auto t2 = std::chrono::high_resolution_clock::now(); - time_collapse_seconds += + decoder->time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; - size_t kept_states = - keep_top_states(beam_entries, config.beam_width, config.beam_eps, config.ranking_mode); - normalize_items(beam_entries); - record_kept_state_count(this, beam_entries.empty() ? 0 : kept_states); + const size_t kept_states = keep_top_compiled_states( + &beam_entries, decoder->config.beam_width, decoder->config.beam_eps, + decoder->config.ranking_mode); + normalize_compiled_items(&beam_entries); + record_kept_state_count(decoder, beam_entries.empty() ? 
0 : kept_states); if (beam_entries.empty()) { - low_confidence_flag = true; + decoder->low_confidence_flag = true; return; } - num_states_merged += kept_states; - max_beam_size_seen = std::max(max_beam_size_seen, kept_states); + decoder->num_states_merged += kept_states; + decoder->max_beam_size_seen = std::max(decoder->max_beam_size_seen, kept_states); auto t3 = std::chrono::high_resolution_clock::now(); - time_truncate_seconds += + decoder->time_truncate_seconds += std::chrono::duration_cast(t3 - t2).count() / 1e6; } auto tr0 = std::chrono::high_resolution_clock::now(); for (const auto& item : beam_entries) { - if (!wide_state_zero(item.state_words)) { + if (!fixed_wide_state_zero(item.state_words)) { continue; } if (item.obs_mask == 0) { - total_mass_obs0 += item.mass; + decoder->total_mass_obs0 += item.mass; } else if (item.obs_mask == 1) { - total_mass_obs1 += item.mass; + decoder->total_mass_obs1 += item.mass; } } - if (total_mass_obs0 == 0.0 && total_mass_obs1 == 0.0) { - low_confidence_flag = true; + if (decoder->total_mass_obs0 == 0.0 && decoder->total_mass_obs1 == 0.0) { + decoder->low_confidence_flag = true; return; } - predicted_obs_mask = total_mass_obs1 > total_mass_obs0 ? 1 : 0; + decoder->predicted_obs_mask = decoder->total_mass_obs1 > decoder->total_mass_obs0 ? 
1 : 0; auto tr1 = std::chrono::high_resolution_clock::now(); - time_reconstruct_seconds += + decoder->time_reconstruct_seconds += std::chrono::duration_cast(tr1 - tr0).count() / 1e6; } + std::vector> layers; + std::vector initial_active_detectors; + size_t max_frontier_width; +}; + +std::unique_ptr build_compiled_wide_kernel( + const std::vector& layers, size_t max_frontier_width) { + const size_t required_words = std::max(1, num_state_words(max_frontier_width)); + if (required_words > kMaxCompiledWideStateWords) { + throw std::invalid_argument("Wide trellis frontier requires " + std::to_string(required_words) + + " words, but only " + + std::to_string(kMaxCompiledWideStateWords) + + " compiled words are enabled."); + } + + const std::vector initial_active_detectors = + layers.empty() ? std::vector{} : layers.front().current_active_detectors; + switch (required_words) { + case 1: + return std::make_unique>( + compile_wide_layers<1>(layers), initial_active_detectors, max_frontier_width); + case 2: + return std::make_unique>( + compile_wide_layers<2>(layers), initial_active_detectors, max_frontier_width); + case 3: + return std::make_unique>( + compile_wide_layers<3>(layers), initial_active_detectors, max_frontier_width); + case 4: + return std::make_unique>( + compile_wide_layers<4>(layers), initial_active_detectors, max_frontier_width); + default: + throw std::invalid_argument("Unsupported compiled wide trellis word count."); + } +} + +} // namespace + +TesseractTrellisDecoder::~TesseractTrellisDecoder() = default; + +TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) + : config(std::move(config_)) { + std::vector dem_error_map(config.dem.flattened().count_errors()); + std::iota(dem_error_map.begin(), dem_error_map.end(), 0); + dem_error_to_error = std::move(dem_error_map); + error_to_dem_error = common::invert_error_map(dem_error_to_error, config.dem.count_errors()); + errors = get_errors_from_dem(config.dem.flattened()); + 
num_detectors = config.dem.count_detectors(); + num_observables = config.dem.count_observables(); + + all_possible_detectors = boost::dynamic_bitset<>(num_detectors); + for (const auto& error : errors) { + for (int d : error.symptom.detectors) { + all_possible_detectors[(size_t)d] = true; + } + } + + auto faults = parse_faults(errors, num_observables); + + size_t wide_frontier_width = 0; + build_wide_layer_templates(faults, num_detectors, &wide_layer_templates, &wide_frontier_width); + build_future_detcost_transitions(faults, num_detectors, &wide_layer_templates, + &initial_future_detcost); + prepare_projected_fault_masks(&wide_layer_templates); + wide_kernel = build_compiled_wide_kernel(wide_layer_templates, wide_frontier_width); +} + +__attribute__((hot)) void TesseractTrellisDecoder::decode_shot( + const std::vector& detections) { + low_confidence_flag = false; + num_states_expanded = 0; + num_states_merged = 0; + max_beam_size_seen = 0; + max_frontier_width_seen = 0; + reset_kept_state_stats(this); + time_expand_seconds = 0; + time_collapse_seconds = 0; + time_truncate_seconds = 0; + time_reconstruct_seconds = 0; + predicted_obs_mask = 0; + total_mass_obs0 = 0; + total_mass_obs1 = 0; + FinalizeKeptStateStatsOnExit kept_state_stats_guard{this}; + wide_kernel->decode_shot(this, detections); + if (config.verbose) { std::cout << "trellis beam_width=" << config.beam_width << " frontier_width=" << max_frontier_width_seen diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h index 233f528..8aed4ae 100644 --- a/src/tesseract_trellis.h +++ b/src/tesseract_trellis.h @@ -17,11 +17,14 @@ #include #include +#include #include #include "common.h" #include "stim.h" +struct TesseractTrellisWideKernelBase; + enum class TesseractTrellisRankingMode { MassOnly, FutureDetcostRanked, @@ -34,26 +37,6 @@ struct TesseractTrellisDetcostTransition { std::vector next_costs; }; -struct TesseractTrellisSmallLayerTemplate { - double q = 0; - double p = 0; - uint64_t obs_flip_bit = 0; 
- uint64_t local_det_mask = 0; - uint64_t retiring_mask = 0; - uint64_t surviving_mask = 0; - uint64_t projected_fault_mask = 0; - size_t previous_width = 0; - std::vector surviving_local_indices; - std::vector current_active_detectors; - std::vector next_frontier_costs; - TesseractTrellisDetcostTransition detcost_transition; -}; - -struct TesseractTrellisSmallDetectorLayerRef { - uint32_t layer_index = 0; - uint8_t local_index = 0; -}; - struct TesseractTrellisWideLayerTemplate { double q = 0; double p = 0; @@ -77,6 +60,7 @@ struct TesseractTrellisConfig { struct TesseractTrellisDecoder { explicit TesseractTrellisDecoder(TesseractTrellisConfig config); + ~TesseractTrellisDecoder(); void decode_shot(const std::vector& detections); std::vector decode(const std::vector& detections); @@ -108,12 +92,8 @@ struct TesseractTrellisDecoder { size_t num_observables = 0; size_t num_detectors = 0; boost::dynamic_bitset<> all_possible_detectors; - bool has_small_layer_templates = false; - std::vector small_layer_templates; - std::vector> small_detector_layer_refs; - std::vector scratch_small_current_target_bits; - std::vector scratch_small_expected_retiring_bits; std::vector wide_layer_templates; + std::unique_ptr wide_kernel; std::vector initial_future_detcost; std::vector kept_state_histogram_scratch; }; From 4bd166133b0192b52c49c89f9801734d3efb7b12 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 19 Apr 2026 13:42:20 -0700 Subject: [PATCH 22/25] remove boost dynamic bitset from trellis --- src/BUILD | 2 - src/tesseract_trellis.cc | 114 ++++++++++++++++++++++++++++----------- src/tesseract_trellis.h | 5 +- 3 files changed, 86 insertions(+), 35 deletions(-) diff --git a/src/BUILD b/src/BUILD index 09d559e..cb021d4 100644 --- a/src/BUILD +++ b/src/BUILD @@ -149,7 +149,6 @@ cc_library( deps = [ ":libcommon", ":libutils", - "@boost//:dynamic_bitset", "@stim//:stim_lib", ], ) @@ -163,7 +162,6 @@ cc_library( linkopts = OPT_LINKOPTS, deps = [ ":libtesseract", - 
"@boost//:dynamic_bitset", "@highs", ], ) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index e8dd9ec..d101e01 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -84,7 +84,8 @@ struct CompiledWideLayerTemplate { std::array projection_dst_words{}; std::array projection_dst_offsets{}; std::array projected_fault_mask_words{}; - std::vector fault_detector_indices; + std::vector fault_target_word_indices; + std::vector fault_target_bit_masks; std::vector fault_word_indices; std::vector fault_bit_masks; std::vector fault_was_active_before; @@ -255,6 +256,14 @@ size_t num_state_words(size_t num_bits) { return (num_bits + 63) >> 6; } +TESSERACT_ALWAYS_INLINE size_t detector_word_index(size_t detector) { + return detector >> 6; +} + +TESSERACT_ALWAYS_INLINE uint64_t detector_word_mask(size_t detector) { + return uint64_t{1} << (detector & 63); +} + TESSERACT_ALWAYS_INLINE uint64_t compact_bits_u64(uint64_t value, uint64_t mask) { #if defined(__BMI2__) && \ (defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86)) @@ -275,14 +284,17 @@ TESSERACT_ALWAYS_INLINE uint64_t compact_bits_u64(uint64_t value, uint64_t mask) } double compute_initial_penalty_for_active_detectors( - const std::vector& active_detectors, const boost::dynamic_bitset<>& actual_dets, - const std::vector& initial_future_detcost) { + const std::vector& active_detector_word_indices, + const std::vector& active_detector_bit_masks, + const std::vector& active_detector_costs, + const std::vector& actual_detector_words) { double total = 0.0; - for (int detector : active_detectors) { - if (!actual_dets[(size_t)detector]) { + for (size_t k = 0; k < active_detector_costs.size(); ++k) { + if ((actual_detector_words[active_detector_word_indices[k]] & active_detector_bit_masks[k]) == + 0) { continue; } - double best = initial_future_detcost[(size_t)detector]; + double best = active_detector_costs[k]; if (best == INF) { return INF; } @@ -436,17 +448,20 @@ 
FixedWideStateWords project_compiled_wide_state( template BranchPenaltyUpdate compute_compiled_wide_branch_update( const FixedWideStateWords& base_state_words, double current_penalty, - const boost::dynamic_bitset<>& actual_dets, const CompiledWideLayerTemplate& layer, + const std::vector& actual_detector_words, + const CompiledWideLayerTemplate& layer, bool compute_penalties) { BranchPenaltyUpdate update; update.absent_penalty = compute_penalties ? current_penalty : 0.0; update.present_penalty = compute_penalties ? current_penalty : 0.0; - for (size_t k = 0; k < layer.fault_detector_indices.size(); ++k) { + for (size_t k = 0; k < layer.fault_target_word_indices.size(); ++k) { const bool state_bit = layer.fault_was_active_before[k] && ((base_state_words[layer.fault_word_indices[k]] & layer.fault_bit_masks[k]) != 0); - const bool target_bit = actual_dets[layer.fault_detector_indices[k]]; + const bool target_bit = + (actual_detector_words[layer.fault_target_word_indices[k]] & + layer.fault_target_bit_masks[k]) != 0; const bool mismatch = state_bit ^ target_bit; const int32_t next_local = layer.next_local_indices[k]; @@ -659,7 +674,8 @@ std::vector> compile_wide_layers( } const auto& transition = layer.detcost_transition; - compiled.fault_detector_indices.reserve(transition.fault_local_indices.size()); + compiled.fault_target_word_indices.reserve(transition.fault_local_indices.size()); + compiled.fault_target_bit_masks.reserve(transition.fault_local_indices.size()); compiled.fault_word_indices.reserve(transition.fault_local_indices.size()); compiled.fault_bit_masks.reserve(transition.fault_local_indices.size()); compiled.fault_was_active_before.reserve(transition.fault_local_indices.size()); @@ -667,7 +683,9 @@ std::vector> compile_wide_layers( compiled.current_costs = transition.current_costs; compiled.next_costs = transition.next_costs; for (uint32_t local : transition.fault_local_indices) { - 
compiled.fault_detector_indices.push_back((uint32_t)layer.current_active_detectors[local]); + const uint32_t detector = (uint32_t)layer.current_active_detectors[local]; + compiled.fault_target_word_indices.push_back((uint32_t)detector_word_index(detector)); + compiled.fault_target_bit_masks.push_back(detector_word_mask(detector)); compiled.fault_word_indices.push_back(static_cast(local >> 6)); compiled.fault_bit_masks.push_back(uint64_t{1} << (local & 63)); compiled.fault_was_active_before.push_back(local < layer.previous_width); @@ -681,21 +699,32 @@ std::vector> compile_wide_layers( template struct CompiledWideKernel final : TesseractTrellisWideKernelBase { explicit CompiledWideKernel(std::vector> layers_, - std::vector initial_active_detectors_, + std::vector initial_detector_word_indices_, + std::vector initial_detector_bit_masks_, + std::vector initial_detector_costs_, size_t max_frontier_width_) : layers(std::move(layers_)), - initial_active_detectors(std::move(initial_active_detectors_)), + initial_detector_word_indices(std::move(initial_detector_word_indices_)), + initial_detector_bit_masks(std::move(initial_detector_bit_masks_)), + initial_detector_costs(std::move(initial_detector_costs_)), max_frontier_width(max_frontier_width_) {} void decode_shot(TesseractTrellisDecoder* decoder, const std::vector& detections) const override { - boost::dynamic_bitset<> actual_dets(decoder->num_detectors); + auto& actual_detector_words = decoder->actual_detector_words_scratch; + std::fill(actual_detector_words.begin(), actual_detector_words.end(), 0); for (uint64_t d : detections) { - if (d >= decoder->num_detectors || !decoder->all_possible_detectors[d]) { + if (d >= decoder->num_detectors) { decoder->low_confidence_flag = true; return; } - actual_dets.flip((size_t)d); + const size_t word = detector_word_index((size_t)d); + const uint64_t mask = detector_word_mask((size_t)d); + if ((decoder->all_possible_detector_words[word] & mask) == 0) { + 
decoder->low_confidence_flag = true; + return; + } + actual_detector_words[word] ^= mask; } decoder->max_frontier_width_seen = max_frontier_width; @@ -703,8 +732,10 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { double initial_penalty = 0.0; if (decoder->config.ranking_mode == TesseractTrellisRankingMode::FutureDetcostRanked && !layers.empty()) { - initial_penalty = compute_initial_penalty_for_active_detectors( - initial_active_detectors, actual_dets, decoder->initial_future_detcost); + initial_penalty = compute_initial_penalty_for_active_detectors(initial_detector_word_indices, + initial_detector_bit_masks, + initial_detector_costs, + actual_detector_words); } std::vector> beam_entries; @@ -731,7 +762,7 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { for (const auto& item : beam_entries) { ++decoder->num_states_expanded; BranchPenaltyUpdate update = compute_compiled_wide_branch_update( - item.state_words, item.penalty, actual_dets, layer, compute_penalties); + item.state_words, item.penalty, actual_detector_words, layer, compute_penalties); if (!update.absent_valid && !update.present_valid) { continue; @@ -806,12 +837,15 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { } std::vector> layers; - std::vector initial_active_detectors; + std::vector initial_detector_word_indices; + std::vector initial_detector_bit_masks; + std::vector initial_detector_costs; size_t max_frontier_width; }; std::unique_ptr build_compiled_wide_kernel( - const std::vector& layers, size_t max_frontier_width) { + const std::vector& layers, size_t max_frontier_width, + const std::vector& initial_future_detcost) { const size_t required_words = std::max(1, num_state_words(max_frontier_width)); if (required_words > kMaxCompiledWideStateWords) { throw std::invalid_argument("Wide trellis frontier requires " + std::to_string(required_words) + @@ -820,21 +854,37 @@ std::unique_ptr build_compiled_wide_kernel( " compiled words are 
enabled."); } - const std::vector initial_active_detectors = - layers.empty() ? std::vector{} : layers.front().current_active_detectors; + std::vector initial_detector_word_indices; + std::vector initial_detector_bit_masks; + std::vector initial_detector_costs; + if (!layers.empty()) { + const auto& initial_active_detectors = layers.front().current_active_detectors; + initial_detector_word_indices.reserve(initial_active_detectors.size()); + initial_detector_bit_masks.reserve(initial_active_detectors.size()); + initial_detector_costs.reserve(initial_active_detectors.size()); + for (int detector : initial_active_detectors) { + initial_detector_word_indices.push_back((uint32_t)detector_word_index((size_t)detector)); + initial_detector_bit_masks.push_back(detector_word_mask((size_t)detector)); + initial_detector_costs.push_back(initial_future_detcost[(size_t)detector]); + } + } switch (required_words) { case 1: return std::make_unique>( - compile_wide_layers<1>(layers), initial_active_detectors, max_frontier_width); + compile_wide_layers<1>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); case 2: return std::make_unique>( - compile_wide_layers<2>(layers), initial_active_detectors, max_frontier_width); + compile_wide_layers<2>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); case 3: return std::make_unique>( - compile_wide_layers<3>(layers), initial_active_detectors, max_frontier_width); + compile_wide_layers<3>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); case 4: return std::make_unique>( - compile_wide_layers<4>(layers), initial_active_detectors, max_frontier_width); + compile_wide_layers<4>(layers), initial_detector_word_indices, initial_detector_bit_masks, + initial_detector_costs, max_frontier_width); default: throw std::invalid_argument("Unsupported compiled wide 
trellis word count."); } @@ -854,10 +904,12 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) num_detectors = config.dem.count_detectors(); num_observables = config.dem.count_observables(); - all_possible_detectors = boost::dynamic_bitset<>(num_detectors); + all_possible_detector_words.assign(num_state_words(num_detectors), 0); + actual_detector_words_scratch.assign(all_possible_detector_words.size(), 0); for (const auto& error : errors) { for (int d : error.symptom.detectors) { - all_possible_detectors[(size_t)d] = true; + all_possible_detector_words[detector_word_index((size_t)d)] |= + detector_word_mask((size_t)d); } } @@ -865,10 +917,12 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) size_t wide_frontier_width = 0; build_wide_layer_templates(faults, num_detectors, &wide_layer_templates, &wide_frontier_width); + std::vector initial_future_detcost; build_future_detcost_transitions(faults, num_detectors, &wide_layer_templates, &initial_future_detcost); prepare_projected_fault_masks(&wide_layer_templates); - wide_kernel = build_compiled_wide_kernel(wide_layer_templates, wide_frontier_width); + wide_kernel = + build_compiled_wide_kernel(wide_layer_templates, wide_frontier_width, initial_future_detcost); } __attribute__((hot)) void TesseractTrellisDecoder::decode_shot( diff --git a/src/tesseract_trellis.h b/src/tesseract_trellis.h index 8aed4ae..d3b6ee0 100644 --- a/src/tesseract_trellis.h +++ b/src/tesseract_trellis.h @@ -15,7 +15,6 @@ #ifndef TESSERACT_TRELLIS_DECODER_H #define TESSERACT_TRELLIS_DECODER_H -#include #include #include #include @@ -91,10 +90,10 @@ struct TesseractTrellisDecoder { std::vector errors; size_t num_observables = 0; size_t num_detectors = 0; - boost::dynamic_bitset<> all_possible_detectors; + std::vector all_possible_detector_words; + std::vector actual_detector_words_scratch; std::vector wide_layer_templates; std::unique_ptr wide_kernel; - std::vector 
initial_future_detcost; std::vector kept_state_histogram_scratch; }; From 83ade9df9fe07678b5c858d57e13f0e1a02511c6 Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 19 Apr 2026 13:56:27 -0700 Subject: [PATCH 23/25] simplify wide path --- src/tesseract_trellis.cc | 193 +++++++++++++++------------------------ 1 file changed, 76 insertions(+), 117 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index d101e01..b38320c 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -57,29 +57,23 @@ struct Fault { std::vector detectors; }; -struct WideStateGroup { - double mass; - double score; - size_t begin; - size_t end; -}; - template using FixedWideStateWords = std::array; template -struct FixedWidePackedMass { +struct FixedWideStateEntry { FixedWideStateWords state_words{}; - uint64_t obs_mask = 0; - double mass = 0.0; + double mass0 = 0.0; + double mass1 = 0.0; double penalty = 0.0; + double score = -INF; }; template struct CompiledWideLayerTemplate { double q = 0.0; double p = 0.0; - uint64_t obs_mask = 0; + bool toggles_observable = false; std::array surviving_masks{}; std::array projection_dst_words{}; std::array projection_dst_offsets{}; @@ -314,6 +308,11 @@ double score_mass_and_penalty(double mass, double penalty, return std::log(mass) - penalty; } +template +TESSERACT_ALWAYS_INLINE double total_entry_mass(const FixedWideStateEntry& entry) { + return entry.mass0 + entry.mass1; +} + void reset_kept_state_stats(TesseractTrellisDecoder* decoder) { decoder->kept_state_sample_count = 0; decoder->kept_state_min = 0; @@ -489,40 +488,35 @@ BranchPenaltyUpdate compute_compiled_wide_branch_update( } template -void normalize_compiled_items(std::vector>* items) { +void normalize_compiled_items(std::vector>* items) { double total = 0.0; for (const auto& item : *items) { - total += item.mass; + total += total_entry_mass(item); } if (total == 0.0) { return; } for (auto& item : *items) { - item.mass /= total; + item.mass0 /= total; 
+ item.mass1 /= total; } } template -void merge_equal_compiled_keys_inplace(std::vector>* items) { +void merge_equal_compiled_keys_inplace(std::vector>* items) { if (items->empty()) { return; } std::sort(items->begin(), items->end(), - [](const FixedWidePackedMass& a, const FixedWidePackedMass& b) { - if (fixed_wide_state_less(a.state_words, b.state_words)) { - return true; - } - if (fixed_wide_state_less(b.state_words, a.state_words)) { - return false; - } - return a.obs_mask < b.obs_mask; + [](const FixedWideStateEntry& a, const FixedWideStateEntry& b) { + return fixed_wide_state_less(a.state_words, b.state_words); }); size_t out = 0; for (size_t i = 1; i < items->size(); ++i) { - if ((*items)[i].obs_mask == (*items)[out].obs_mask && - (*items)[i].state_words == (*items)[out].state_words) { - (*items)[out].mass += (*items)[i].mass; + if ((*items)[i].state_words == (*items)[out].state_words) { + (*items)[out].mass0 += (*items)[i].mass0; + (*items)[out].mass1 += (*items)[i].mass1; } else { ++out; if (out != i) { @@ -534,111 +528,61 @@ void merge_equal_compiled_keys_inplace(std::vector>* } template -bool compiled_wide_state_group_score_greater(const std::vector>& entries, - const WideStateGroup& a, const WideStateGroup& b) { +bool compiled_state_score_greater(const FixedWideStateEntry& a, + const FixedWideStateEntry& b) { if (a.score != b.score) { return a.score > b.score; } - return fixed_wide_state_less(entries[a.begin].state_words, entries[b.begin].state_words); + return fixed_wide_state_less(a.state_words, b.state_words); } template -size_t trim_compiled_wide_state_groups_by_beam_and_mass( - const std::vector>& entries, std::vector* groups, - size_t beam_width, double beam_eps) { - if (groups->empty()) { +size_t keep_top_compiled_states(std::vector>* entries, + size_t beam_width, double beam_eps, + TesseractTrellisRankingMode ranking_mode) { + if (entries->empty()) { return 0; } double total_mass = 0.0; - if (beam_eps > 0.0) { - for (const auto& group : *groups) 
{ - total_mass += group.mass; + for (auto& entry : *entries) { + const double mass = total_entry_mass(entry); + entry.score = score_mass_and_penalty(mass, entry.penalty, ranking_mode); + if (beam_eps > 0.0) { + total_mass += mass; } } - if (groups->size() > beam_width) { - std::nth_element(groups->begin(), groups->begin() + beam_width, groups->end(), - [&entries](const WideStateGroup& a, const WideStateGroup& b) { - return compiled_wide_state_group_score_greater(entries, a, b); + if (entries->size() > beam_width) { + std::nth_element(entries->begin(), entries->begin() + beam_width, entries->end(), + [](const FixedWideStateEntry& a, const FixedWideStateEntry& b) { + return compiled_state_score_greater(a, b); }); - groups->resize(beam_width); + entries->resize(beam_width); } else if (beam_eps <= 0.0) { - return groups->size(); + return entries->size(); } if (beam_eps <= 0.0 || total_mass <= 0.0) { - return groups->size(); + return entries->size(); } - std::sort(groups->begin(), groups->end(), - [&entries](const WideStateGroup& a, const WideStateGroup& b) { - return compiled_wide_state_group_score_greater(entries, a, b); + std::sort(entries->begin(), entries->end(), + [](const FixedWideStateEntry& a, const FixedWideStateEntry& b) { + return compiled_state_score_greater(a, b); }); const double retained_target_mass = total_mass * (1.0 - beam_eps); double retained_mass = 0.0; size_t keep_count = 0; - while (keep_count < groups->size()) { - retained_mass += (*groups)[keep_count].mass; + while (keep_count < entries->size()) { + retained_mass += total_entry_mass((*entries)[keep_count]); ++keep_count; if (retained_mass >= retained_target_mass) { break; } } - groups->resize(keep_count); - std::sort(groups->begin(), groups->end(), - [](const WideStateGroup& a, const WideStateGroup& b) { return a.begin < b.begin; }); - return groups->size(); -} - -template -std::vector collect_compiled_wide_state_groups( - const std::vector>& entries, - TesseractTrellisRankingMode ranking_mode) 
{ - std::vector groups; - if (entries.empty()) { - return groups; - } - groups.reserve(entries.size()); - size_t begin = 0; - while (begin < entries.size()) { - double mass = 0.0; - size_t end = begin; - while (end < entries.size() && entries[end].state_words == entries[begin].state_words) { - mass += entries[end].mass; - ++end; - } - groups.push_back( - {mass, score_mass_and_penalty(mass, entries[begin].penalty, ranking_mode), begin, end}); - begin = end; - } - return groups; -} - -template -size_t keep_top_compiled_states(std::vector>* entries, - size_t beam_width, double beam_eps, - TesseractTrellisRankingMode ranking_mode) { - if (entries->empty()) { - return 0; - } - auto groups = collect_compiled_wide_state_groups(*entries, ranking_mode); - const size_t kept_group_count = - trim_compiled_wide_state_groups_by_beam_and_mass(*entries, &groups, beam_width, beam_eps); - - std::vector> kept; - size_t kept_entries = 0; - for (const auto& group : groups) { - kept_entries += group.end - group.begin; - } - kept.reserve(kept_entries); - for (const auto& group : groups) { - for (size_t k = group.begin; k < group.end; ++k) { - kept.push_back(std::move((*entries)[k])); - } - } - *entries = std::move(kept); - return kept_group_count; + entries->resize(keep_count); + return keep_count; } template @@ -655,7 +599,10 @@ std::vector> compile_wide_layers( CompiledWideLayerTemplate compiled; compiled.q = layer.q; compiled.p = layer.p; - compiled.obs_mask = layer.obs_mask; + if (layer.obs_mask > 1) { + throw std::invalid_argument("tesseract_trellis currently supports at most 1 observable"); + } + compiled.toggles_observable = layer.obs_mask != 0; std::array surviving_masks{}; for (uint32_t current_local : layer.surviving_local_indices) { @@ -738,11 +685,11 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { actual_detector_words); } - std::vector> beam_entries; - std::vector> next_entries; + std::vector> beam_entries; + std::vector> next_entries; 
beam_entries.reserve(decoder->config.beam_width * 2 + 2); next_entries.reserve(decoder->config.beam_width * 4 + 4); - beam_entries.push_back({{}, 0, 1.0, initial_penalty}); + beam_entries.push_back({{}, 1.0, 0.0, initial_penalty}); decoder->max_beam_size_seen = 1; const bool compute_penalties = @@ -776,16 +723,28 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { FixedWideStateWords projected_toggled = projected_state; xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); next_entries.push_back( - {std::move(projected_state), item.obs_mask, item.mass * layer.q, update.absent_penalty}); - next_entries.push_back({std::move(projected_toggled), item.obs_mask ^ layer.obs_mask, - item.mass * layer.p, update.present_penalty}); + {std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q, + update.absent_penalty}); + if (layer.toggles_observable) { + next_entries.push_back({std::move(projected_toggled), item.mass1 * layer.p, + item.mass0 * layer.p, update.present_penalty}); + } else { + next_entries.push_back({std::move(projected_toggled), item.mass0 * layer.p, + item.mass1 * layer.p, update.present_penalty}); + } } else if (keep_absent) { next_entries.push_back( - {std::move(projected_state), item.obs_mask, item.mass * layer.q, update.absent_penalty}); + {std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q, + update.absent_penalty}); } else if (keep_present) { xor_compiled_wide_state(&projected_state, layer.projected_fault_mask_words); - next_entries.push_back({std::move(projected_state), item.obs_mask ^ layer.obs_mask, - item.mass * layer.p, update.present_penalty}); + if (layer.toggles_observable) { + next_entries.push_back({std::move(projected_state), item.mass1 * layer.p, + item.mass0 * layer.p, update.present_penalty}); + } else { + next_entries.push_back({std::move(projected_state), item.mass0 * layer.p, + item.mass1 * layer.p, update.present_penalty}); + } } } auto t1 = 
std::chrono::high_resolution_clock::now(); @@ -820,11 +779,8 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { if (!fixed_wide_state_zero(item.state_words)) { continue; } - if (item.obs_mask == 0) { - decoder->total_mass_obs0 += item.mass; - } else if (item.obs_mask == 1) { - decoder->total_mass_obs1 += item.mass; - } + decoder->total_mass_obs0 += item.mass0; + decoder->total_mass_obs1 += item.mass1; } if (decoder->total_mass_obs0 == 0.0 && decoder->total_mass_obs1 == 0.0) { decoder->low_confidence_flag = true; @@ -903,6 +859,9 @@ TesseractTrellisDecoder::TesseractTrellisDecoder(TesseractTrellisConfig config_) errors = get_errors_from_dem(config.dem.flattened()); num_detectors = config.dem.count_detectors(); num_observables = config.dem.count_observables(); + if (num_observables > 1) { + throw std::invalid_argument("tesseract_trellis currently supports at most 1 observable"); + } all_possible_detector_words.assign(num_state_words(num_detectors), 0); actual_detector_words_scratch.assign(all_possible_detector_words.size(), 0); From f016a982284e23520df03ca4d8f5d5d5f033a05e Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 19 Apr 2026 14:34:41 -0700 Subject: [PATCH 24/25] 1.5x speedup on trellis by accumulating states directly into pair buckets --- src/tesseract_trellis.cc | 149 ++++++++++++++++++++++++++++++++------- 1 file changed, 122 insertions(+), 27 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index b38320c..6561709 100644 --- a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -69,6 +69,16 @@ struct FixedWideStateEntry { double score = -INF; }; +template +struct FixedWidePairBucket { + FixedWideStateWords key{}; + double mass0[2]{}; + double mass1[2]{}; + double penalty[2]{}; + uint8_t used_mask = 0; + bool occupied = false; +}; + template struct CompiledWideLayerTemplate { double q = 0.0; @@ -313,6 +323,13 @@ TESSERACT_ALWAYS_INLINE double total_entry_mass(const FixedWideStateEntry return 
entry.mass0 + entry.mass1; } +TESSERACT_ALWAYS_INLINE uint64_t mix_splitmix64(uint64_t value) { + value += 0x9e3779b97f4a7c15ULL; + value = (value ^ (value >> 30)) * 0xbf58476d1ce4e5b9ULL; + value = (value ^ (value >> 27)) * 0x94d049bb133111ebULL; + return value ^ (value >> 31); +} + void reset_kept_state_stats(TesseractTrellisDecoder* decoder) { decoder->kept_state_sample_count = 0; decoder->kept_state_min = 0; @@ -422,6 +439,72 @@ void xor_compiled_wide_state(FixedWideStateWords* state_words, } } +template +TESSERACT_ALWAYS_INLINE uint64_t hash_fixed_wide_state(const FixedWideStateWords& state_words) { + uint64_t hash = 0x123456789abcdef0ULL; + for (size_t k = 0; k < Words; ++k) { + hash ^= mix_splitmix64(state_words[k] + 0x9e3779b97f4a7c15ULL * (k + 1)); + hash = std::rotl(hash, 21); + } + return hash; +} + +template +void ensure_pair_bucket_capacity(std::vector>* buckets, + size_t num_parents) { + const size_t required = std::bit_ceil(std::max(16, num_parents * 2)); + if (buckets->size() < required) { + buckets->resize(required); + } +} + +template +void clear_pair_buckets(std::vector>* buckets, + std::vector* used_bucket_indices) { + for (size_t index : *used_bucket_indices) { + (*buckets)[index].occupied = false; + (*buckets)[index].used_mask = 0; + } + used_bucket_indices->clear(); +} + +template +TESSERACT_ALWAYS_INLINE size_t find_or_insert_pair_bucket( + std::vector>* buckets, std::vector* used_bucket_indices, + const FixedWideStateWords& key) { + const size_t mask = buckets->size() - 1; + size_t index = hash_fixed_wide_state(key) & mask; + while ((*buckets)[index].occupied) { + if ((*buckets)[index].key == key) { + return index; + } + index = (index + 1) & mask; + } + + auto& bucket = (*buckets)[index]; + bucket.occupied = true; + bucket.key = key; + bucket.used_mask = 0; + used_bucket_indices->push_back(index); + return index; +} + +template +TESSERACT_ALWAYS_INLINE void accumulate_pair_bucket_slot(FixedWidePairBucket* bucket, + uint8_t slot, double 
mass0, double mass1, + double penalty) { + const uint8_t bit = (uint8_t)(1u << slot); + if ((bucket->used_mask & bit) == 0) { + bucket->mass0[slot] = mass0; + bucket->mass1[slot] = mass1; + bucket->penalty[slot] = penalty; + bucket->used_mask |= bit; + } else { + bucket->mass0[slot] += mass0; + bucket->mass1[slot] += mass1; + } +} + template FixedWideStateWords project_compiled_wide_state( const FixedWideStateWords& state_words, const CompiledWideLayerTemplate& layer) { @@ -687,6 +770,8 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { std::vector> beam_entries; std::vector> next_entries; + std::vector> pair_buckets; + std::vector used_bucket_indices; beam_entries.reserve(decoder->config.beam_width * 2 + 2); next_entries.reserve(decoder->config.beam_width * 4 + 4); beam_entries.push_back({{}, 1.0, 0.0, initial_penalty}); @@ -697,9 +782,10 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { for (size_t layer_index = 0; layer_index < layers.size(); ++layer_index) { const auto& layer = layers[layer_index]; + ensure_pair_bucket_capacity(&pair_buckets, beam_entries.size()); + clear_pair_buckets(&pair_buckets, &used_bucket_indices); + auto t0 = std::chrono::high_resolution_clock::now(); - next_entries.clear(); - next_entries.reserve(beam_entries.size() * 2); if (decoder->config.verbose) { std::cout << "expanding layer " << layer_index << " / " << (layers.size() - 1) @@ -717,33 +803,29 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { FixedWideStateWords projected_state = project_compiled_wide_state(item.state_words, layer); + FixedWideStateWords projected_toggled = projected_state; + xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); + const bool projected_is_key = !fixed_wide_state_less(projected_toggled, projected_state); + const auto& bucket_key = projected_is_key ? projected_state : projected_toggled; + const uint8_t absent_slot = projected_is_key ? 
0 : 1; + const uint8_t present_slot = projected_toggled == bucket_key ? 0 : 1; + const size_t bucket_index = + find_or_insert_pair_bucket(&pair_buckets, &used_bucket_indices, bucket_key); + auto& bucket = pair_buckets[bucket_index]; const bool keep_absent = update.absent_valid && layer.q != 0.0; const bool keep_present = update.present_valid && layer.p != 0.0; - if (keep_absent && keep_present) { - FixedWideStateWords projected_toggled = projected_state; - xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); - next_entries.push_back( - {std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q, - update.absent_penalty}); - if (layer.toggles_observable) { - next_entries.push_back({std::move(projected_toggled), item.mass1 * layer.p, - item.mass0 * layer.p, update.present_penalty}); - } else { - next_entries.push_back({std::move(projected_toggled), item.mass0 * layer.p, - item.mass1 * layer.p, update.present_penalty}); - } - } else if (keep_absent) { - next_entries.push_back( - {std::move(projected_state), item.mass0 * layer.q, item.mass1 * layer.q, - update.absent_penalty}); - } else if (keep_present) { - xor_compiled_wide_state(&projected_state, layer.projected_fault_mask_words); + + if (keep_absent) { + accumulate_pair_bucket_slot(&bucket, absent_slot, item.mass0 * layer.q, + item.mass1 * layer.q, update.absent_penalty); + } + if (keep_present) { if (layer.toggles_observable) { - next_entries.push_back({std::move(projected_state), item.mass1 * layer.p, - item.mass0 * layer.p, update.present_penalty}); + accumulate_pair_bucket_slot(&bucket, present_slot, item.mass1 * layer.p, + item.mass0 * layer.p, update.present_penalty); } else { - next_entries.push_back({std::move(projected_state), item.mass0 * layer.p, - item.mass1 * layer.p, update.present_penalty}); + accumulate_pair_bucket_slot(&bucket, present_slot, item.mass0 * layer.p, + item.mass1 * layer.p, update.present_penalty); } } } @@ -751,9 +833,22 @@ struct 
CompiledWideKernel final : TesseractTrellisWideKernelBase { decoder->time_expand_seconds += std::chrono::duration_cast(t1 - t0).count() / 1e6; - beam_entries.swap(next_entries); auto t2a = std::chrono::high_resolution_clock::now(); - merge_equal_compiled_keys_inplace(&beam_entries); + next_entries.clear(); + next_entries.reserve(used_bucket_indices.size() * 2); + for (size_t index : used_bucket_indices) { + auto& bucket = pair_buckets[index]; + if ((bucket.used_mask & 1u) != 0) { + next_entries.push_back({bucket.key, bucket.mass0[0], bucket.mass1[0], bucket.penalty[0]}); + } + if ((bucket.used_mask & 2u) != 0) { + auto other_state = bucket.key; + xor_compiled_wide_state(&other_state, layer.projected_fault_mask_words); + next_entries.push_back( + {std::move(other_state), bucket.mass0[1], bucket.mass1[1], bucket.penalty[1]}); + } + } + beam_entries.swap(next_entries); auto t2 = std::chrono::high_resolution_clock::now(); decoder->time_collapse_seconds += std::chrono::duration_cast(t2 - t2a).count() / 1e6; From 24a75784ad0ac4adf9cd1f8c23433ae00142153a Mon Sep 17 00:00:00 2001 From: Noah Shutty Date: Sun, 19 Apr 2026 14:59:25 -0700 Subject: [PATCH 25/25] =?UTF-8?q?specialize=20future-detcost=20branch-upda?= =?UTF-8?q?te=20loop=20by=20pre-splitting=20detcost=20terms=20into=20?= =?UTF-8?q?=E2=80=9Csurviving=E2=80=9D=20and=20=E2=80=9Cretiring=E2=80=9D?= =?UTF-8?q?=20=20=20terms=20at=20layer=20compile=20time,=20then=20dispatch?= =?UTF-8?q?ing=20each=20layer=20to=20a=20templated=20fast=20path=20that=20?= =?UTF-8?q?skips=20retiring-detector=20validity=20=20=20checks=20entirely?= =?UTF-8?q?=20when=20that=20layer=20doesn=E2=80=99t=20need=20them?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tesseract_trellis.cc | 181 +++++++++++++++++++++++++-------------- 1 file changed, 117 insertions(+), 64 deletions(-) diff --git a/src/tesseract_trellis.cc b/src/tesseract_trellis.cc index 6561709..89dbe35 100644 --- 
a/src/tesseract_trellis.cc +++ b/src/tesseract_trellis.cc @@ -84,6 +84,8 @@ struct CompiledWideLayerTemplate { double q = 0.0; double p = 0.0; bool toggles_observable = false; + bool has_retiring_terms = false; + size_t surviving_term_count = 0; std::array surviving_masks{}; std::array projection_dst_words{}; std::array projection_dst_offsets{}; @@ -93,7 +95,6 @@ struct CompiledWideLayerTemplate { std::vector fault_word_indices; std::vector fault_bit_masks; std::vector fault_was_active_before; - std::vector next_local_indices; std::vector current_costs; std::vector next_costs; }; @@ -527,17 +528,15 @@ FixedWideStateWords project_compiled_wide_state( return out; } -template -BranchPenaltyUpdate compute_compiled_wide_branch_update( +template +TESSERACT_ALWAYS_INLINE BranchPenaltyUpdate compute_compiled_wide_branch_update( const FixedWideStateWords& base_state_words, double current_penalty, - const std::vector& actual_detector_words, - const CompiledWideLayerTemplate& layer, - bool compute_penalties) { + const std::vector& actual_detector_words, const CompiledWideLayerTemplate& layer) { BranchPenaltyUpdate update; - update.absent_penalty = compute_penalties ? current_penalty : 0.0; - update.present_penalty = compute_penalties ? current_penalty : 0.0; + update.absent_penalty = ComputePenalties ? current_penalty : 0.0; + update.present_penalty = ComputePenalties ? 
current_penalty : 0.0; - for (size_t k = 0; k < layer.fault_target_word_indices.size(); ++k) { + for (size_t k = 0; k < layer.surviving_term_count; ++k) { const bool state_bit = layer.fault_was_active_before[k] && ((base_state_words[layer.fault_word_indices[k]] & layer.fault_bit_masks[k]) != 0); @@ -545,29 +544,86 @@ BranchPenaltyUpdate compute_compiled_wide_branch_update( (actual_detector_words[layer.fault_target_word_indices[k]] & layer.fault_target_bit_masks[k]) != 0; const bool mismatch = state_bit ^ target_bit; - const int32_t next_local = layer.next_local_indices[k]; - if (next_local < 0) { + if constexpr (ComputePenalties) { + const double prev_contrib = + (layer.fault_was_active_before[k] && mismatch) ? layer.current_costs[k] : 0.0; + const double next_contrib = mismatch ? layer.next_costs[k] : 0.0; + update.absent_penalty += next_contrib - prev_contrib; + update.present_penalty += (layer.next_costs[k] - next_contrib) - prev_contrib; + } + } + + if constexpr (CheckRetiringTerms) { + for (size_t k = layer.surviving_term_count; k < layer.fault_target_word_indices.size(); ++k) { + const bool state_bit = + layer.fault_was_active_before[k] && + ((base_state_words[layer.fault_word_indices[k]] & layer.fault_bit_masks[k]) != 0); + const bool target_bit = + (actual_detector_words[layer.fault_target_word_indices[k]] & + layer.fault_target_bit_masks[k]) != 0; + const bool mismatch = state_bit ^ target_bit; + if (mismatch) { update.absent_valid = false; } else { update.present_valid = false; } + + if constexpr (ComputePenalties) { + const double prev_contrib = + (layer.fault_was_active_before[k] && mismatch) ? 
layer.current_costs[k] : 0.0; + update.absent_penalty -= prev_contrib; + update.present_penalty -= prev_contrib; + } } + } + + return update; +} - if (!compute_penalties) { +template +void expand_compiled_layer_into_pair_buckets( + const std::vector>& beam_entries, + std::vector>* pair_buckets, std::vector* used_bucket_indices, + const std::vector& actual_detector_words, const CompiledWideLayerTemplate& layer, + TesseractTrellisDecoder* decoder) { + for (const auto& item : beam_entries) { + ++decoder->num_states_expanded; + BranchPenaltyUpdate update = compute_compiled_wide_branch_update( + item.state_words, item.penalty, actual_detector_words, layer); + + if (!update.absent_valid && !update.present_valid) { continue; } - const double prev_contrib = - (layer.fault_was_active_before[k] && mismatch) ? layer.current_costs[k] : 0.0; - const double absent_contrib = (next_local >= 0 && mismatch) ? layer.next_costs[k] : 0.0; - const double present_contrib = (next_local >= 0 && !mismatch) ? layer.next_costs[k] : 0.0; - update.absent_penalty += absent_contrib - prev_contrib; - update.present_penalty += present_contrib - prev_contrib; + FixedWideStateWords projected_state = project_compiled_wide_state(item.state_words, layer); + FixedWideStateWords projected_toggled = projected_state; + xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); + const bool projected_is_key = !fixed_wide_state_less(projected_toggled, projected_state); + const auto& bucket_key = projected_is_key ? projected_state : projected_toggled; + const uint8_t absent_slot = projected_is_key ? 0 : 1; + const uint8_t present_slot = projected_toggled == bucket_key ? 
0 : 1; + const size_t bucket_index = + find_or_insert_pair_bucket(pair_buckets, used_bucket_indices, bucket_key); + auto& bucket = (*pair_buckets)[bucket_index]; + const bool keep_absent = update.absent_valid && layer.q != 0.0; + const bool keep_present = update.present_valid && layer.p != 0.0; + + if (keep_absent) { + accumulate_pair_bucket_slot(&bucket, absent_slot, item.mass0 * layer.q, item.mass1 * layer.q, + update.absent_penalty); + } + if (keep_present) { + if (layer.toggles_observable) { + accumulate_pair_bucket_slot(&bucket, present_slot, item.mass1 * layer.p, item.mass0 * layer.p, + update.present_penalty); + } else { + accumulate_pair_bucket_slot(&bucket, present_slot, item.mass0 * layer.p, item.mass1 * layer.p, + update.present_penalty); + } + } } - - return update; } template @@ -704,21 +760,36 @@ std::vector> compile_wide_layers( } const auto& transition = layer.detcost_transition; - compiled.fault_target_word_indices.reserve(transition.fault_local_indices.size()); - compiled.fault_target_bit_masks.reserve(transition.fault_local_indices.size()); - compiled.fault_word_indices.reserve(transition.fault_local_indices.size()); - compiled.fault_bit_masks.reserve(transition.fault_local_indices.size()); - compiled.fault_was_active_before.reserve(transition.fault_local_indices.size()); - compiled.next_local_indices = transition.next_local_indices; - compiled.current_costs = transition.current_costs; - compiled.next_costs = transition.next_costs; - for (uint32_t local : transition.fault_local_indices) { + const size_t term_count = transition.fault_local_indices.size(); + compiled.fault_target_word_indices.reserve(term_count); + compiled.fault_target_bit_masks.reserve(term_count); + compiled.fault_word_indices.reserve(term_count); + compiled.fault_bit_masks.reserve(term_count); + compiled.fault_was_active_before.reserve(term_count); + compiled.current_costs.reserve(term_count); + compiled.next_costs.reserve(term_count); + auto append_term = [&](size_t idx) { + 
const uint32_t local = transition.fault_local_indices[idx]; const uint32_t detector = (uint32_t)layer.current_active_detectors[local]; compiled.fault_target_word_indices.push_back((uint32_t)detector_word_index(detector)); compiled.fault_target_bit_masks.push_back(detector_word_mask(detector)); compiled.fault_word_indices.push_back(static_cast(local >> 6)); compiled.fault_bit_masks.push_back(uint64_t{1} << (local & 63)); compiled.fault_was_active_before.push_back(local < layer.previous_width); + compiled.current_costs.push_back(transition.current_costs[idx]); + compiled.next_costs.push_back(transition.next_costs[idx]); + }; + for (size_t idx = 0; idx < term_count; ++idx) { + if (transition.next_local_indices[idx] >= 0) { + append_term(idx); + } + } + compiled.surviving_term_count = compiled.fault_target_word_indices.size(); + compiled.has_retiring_terms = compiled.surviving_term_count != term_count; + for (size_t idx = 0; idx < term_count; ++idx) { + if (transition.next_local_indices[idx] < 0) { + append_term(idx); + } } compiled_layers.push_back(std::move(compiled)); @@ -792,42 +863,24 @@ struct CompiledWideKernel final : TesseractTrellisWideKernelBase { << std::endl; std::cout << "states to expand = " << beam_entries.size() << std::endl; } - for (const auto& item : beam_entries) { - ++decoder->num_states_expanded; - BranchPenaltyUpdate update = compute_compiled_wide_branch_update( - item.state_words, item.penalty, actual_detector_words, layer, compute_penalties); - - if (!update.absent_valid && !update.present_valid) { - continue; - } - - FixedWideStateWords projected_state = - project_compiled_wide_state(item.state_words, layer); - FixedWideStateWords projected_toggled = projected_state; - xor_compiled_wide_state(&projected_toggled, layer.projected_fault_mask_words); - const bool projected_is_key = !fixed_wide_state_less(projected_toggled, projected_state); - const auto& bucket_key = projected_is_key ? 
projected_state : projected_toggled; - const uint8_t absent_slot = projected_is_key ? 0 : 1; - const uint8_t present_slot = projected_toggled == bucket_key ? 0 : 1; - const size_t bucket_index = - find_or_insert_pair_bucket(&pair_buckets, &used_bucket_indices, bucket_key); - auto& bucket = pair_buckets[bucket_index]; - const bool keep_absent = update.absent_valid && layer.q != 0.0; - const bool keep_present = update.present_valid && layer.p != 0.0; - - if (keep_absent) { - accumulate_pair_bucket_slot(&bucket, absent_slot, item.mass0 * layer.q, - item.mass1 * layer.q, update.absent_penalty); - } - if (keep_present) { - if (layer.toggles_observable) { - accumulate_pair_bucket_slot(&bucket, present_slot, item.mass1 * layer.p, - item.mass0 * layer.p, update.present_penalty); - } else { - accumulate_pair_bucket_slot(&bucket, present_slot, item.mass0 * layer.p, - item.mass1 * layer.p, update.present_penalty); - } + if (compute_penalties) { + if (layer.has_retiring_terms) { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); + } else { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); } + } else if (layer.has_retiring_terms) { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); + } else { + expand_compiled_layer_into_pair_buckets( + beam_entries, &pair_buckets, &used_bucket_indices, actual_detector_words, layer, + decoder); } auto t1 = std::chrono::high_resolution_clock::now(); decoder->time_expand_seconds +=