From c49f3cd94759257d1d408570fdddcbaf0a39797e Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Thu, 28 May 2026 17:06:40 -0400 Subject: [PATCH 01/14] Enable decoder construction from Stim DEM strings Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 17 ++++ libs/qec/lib/CMakeLists.txt | 13 ++- libs/qec/lib/decoder_stim_dem.cpp | 125 ++++++++++++++++++++++++ libs/qec/python/bindings/py_decoder.cpp | 12 +++ libs/qec/unittests/test_decoders.cpp | 51 ++++++++++ 5 files changed, 217 insertions(+), 1 deletion(-) create mode 100644 libs/qec/lib/decoder_stim_dem.cpp diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index b9fa4b13..0eaa0f3e 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -427,4 +427,21 @@ inline void convert_vec_hard_to_soft(const std::vector> &in, std::unique_ptr get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, const cudaqx::heterogeneous_map options = {}); + +/// @brief Creator function for a decoder constructed from a Stim DEM string. +using stim_dem_decoder_creator = std::function( + const std::string &, const cudaqx::heterogeneous_map &)>; + +/// @brief Register a Stim-DEM-string creator for the named decoder. +void register_stim_dem_decoder_creator(const std::string &name, + stim_dem_decoder_creator creator); + +/// @brief Construct a decoder by name from a Stim detector error model text. +/// When no registered creator is found, the DEM is parsed and observables / +/// error rates are injected into \p options under keys \c "O" and +/// \c "error_rate_vec". +std::unique_ptr +get_decoder_from_stim_dem(const std::string &name, + const std::string &stim_dem_text, + const cudaqx::heterogeneous_map options = {}); } // namespace cudaq::qec diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index ef4acf18..c0fec3bf 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -39,12 +39,13 @@ endif() set(QEC_SOURCES code.cpp decoder.cpp + decoder_stim_dem.cpp detector_error_model.cpp experiments.cpp pcm_utils.cpp plugin_loader.cpp sparse_binary_matrix.cpp - stabilizer_utils.cpp + stabilizer_utils.cpp decoders/lut.cpp decoders/sliding_window.cpp version.cpp @@ -61,6 +62,16 @@ add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES}) add_subdirectory(decoders/plugins/example) add_subdirectory(decoders/plugins/pymatching) +# Exclude libstim symbols from libqec's exports (cf. CUDA-Q #3045). +if(NOT TARGET libstim) + message(FATAL_ERROR + "libstim target not available; required by cudaq-qec for Stim DEM parsing.") +endif() +target_link_libraries(${LIBRARY_NAME} PRIVATE libstim) +target_link_options(${LIBRARY_NAME} PRIVATE + $<$:-Wl,--exclude-libs,ALL> +) + # The TRT decoder plugin honors the tri-state `CUDAQ_QEC_BUILD_TRT_DECODER` # cache variable (AUTO/ON/OFF) declared in the parent CMakeLists.txt. Skip # descending entirely when the user explicitly opted out; otherwise let the diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp new file mode 100644 index 00000000..404248e1 --- /dev/null +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -0,0 +1,125 @@ +/******************************************************************************* + * Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "stim.h" +#include "cudaq/qec/decoder.h" + +#include +#include +#include +#include +#include + +namespace cudaq::qec { + +namespace { + +std::pair &> +get_stim_dem_registry() { + static std::recursive_mutex *mutex = new std::recursive_mutex(); + static auto *registry = + new std::unordered_map(); + return {*mutex, *registry}; +} + +struct extracted_dem { + cudaqx::tensor H; + cudaqx::tensor O; + std::vector error_rates; +}; + +extracted_dem extract_from_stim_dem(const std::string &dem_text) { + stim::DetectorErrorModel dem(dem_text); + const std::size_t num_detectors = + static_cast(dem.count_detectors()); + const std::size_t num_observables = + static_cast(dem.count_observables()); + + std::vector> detector_hits; + std::vector> observable_hits; + std::vector rates; + + dem.iter_flatten_error_instructions([&](const stim::DemInstruction &inst) { + if (inst.arg_data.size() == 0) + throw std::runtime_error( + "Stim DEM error instruction missing probability argument"); + const double prob = inst.arg_data[0]; + std::vector dets; + std::vector obs; + for (const auto &target : inst.target_data) { + if (target.is_separator()) + continue; + if (target.is_relative_detector_id()) { + dets.push_back(static_cast(target.val())); + } else if (target.is_observable_id()) { + obs.push_back(static_cast(target.val())); + } + } + detector_hits.push_back(std::move(dets)); + observable_hits.push_back(std::move(obs)); + rates.push_back(prob); + }); + + const std::size_t num_errors = rates.size(); + extracted_dem result; + result.H = cudaqx::tensor({num_detectors, num_errors}); + result.O = cudaqx::tensor({num_observables, num_errors}); + result.error_rates = std::move(rates); + + for (std::size_t err = 0; err < num_errors; ++err) { + for (auto det : detector_hits[err]) { + if (det >= num_detectors) + throw std::runtime_error( + "Stim DEM detector id out of range while extracting H"); + result.H.at({det, err}) = 1; + } + for (auto ob : observable_hits[err]) { + if (ob >= num_observables) + throw std::runtime_error( + "Stim DEM observable id out of range while extracting O"); + result.O.at({ob, err}) = 1; + } + } + + return result; +} + +} // namespace + +void register_stim_dem_decoder_creator(const std::string &name, + stim_dem_decoder_creator creator) { + auto [mutex, registry] = get_stim_dem_registry(); + std::lock_guard lock(mutex); + registry[name] = std::move(creator); +} + +std::unique_ptr +get_decoder_from_stim_dem(const std::string &name, + const std::string &stim_dem_text, + const cudaqx::heterogeneous_map options) { + { + auto [mutex, registry] = get_stim_dem_registry(); + std::lock_guard lock(mutex); + auto iter = registry.find(name); + if (iter != registry.end()) + return iter->second(stim_dem_text, options); + } + + auto extracted = extract_from_stim_dem(stim_dem_text); + + cudaqx::heterogeneous_map merged = options; + if (!merged.contains("O")) + merged.insert("O", extracted.O); + if (!merged.contains("error_rate_vec")) + merged.insert("error_rate_vec", extracted.error_rates); + + return decoder::get(name, extracted.H, merged); +} + +} // namespace cudaq::qec diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 6732cb0a..707d9fb2 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -766,6 +766,18 @@ void bindDecoder(nb::module_ &mod) { allocation. )pbdoc"); + qecmod.def( + "get_decoder_from_stim_dem", + [](const std::string &name, const std::string &dem_text, + const nb::kwargs options) -> std::unique_ptr { + return get_decoder_from_stim_dem(name, dem_text, + hetMapFromKwargs(options)); + }, + "Construct a decoder by name from a Stim detector error model string. " + "Observables and per-error rates from the DEM are injected into options " + "under keys \"O\" and \"error_rate_vec\" when no registered Stim-DEM " + "creator is found."); + qecmod.def( "get_sorted_pcm_column_indices", [](const nb::ndarray &H, diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index c21eecbc..42c0e636 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -768,3 +768,54 @@ TEST(DecoderRegistryTest, SingleParameterRegistryDirect) { // works This test passes if no exceptions are thrown, proving the // single-parameter registry is instantiated } + +TEST(StimDemDecoderFactory, ConstructsLutDecoderFromStimDemText) { + const std::string dem_text = R"(error(0.1) D0 L0 +error(0.1) D1 L0 +error(0.05) D0 D1 +)"; + + auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); + + std::vector syndrome = {1.0, 0.0}; + auto result = d->decode(syndrome); + EXPECT_EQ(result.result.size(), 3u); +} + +TEST(StimDemDecoderFactory, ThrowsOnMalformedStimDem) { + EXPECT_THROW(cudaq::qec::get_decoder_from_stim_dem("single_error_lut", + "not a valid DEM"), + std::exception); +} + +TEST(StimDemDecoderFactory, ThrowsOnUnknownDecoderName) { + const std::string dem_text = "error(0.1) D0\n"; + EXPECT_THROW( + cudaq::qec::get_decoder_from_stim_dem("__no_such_decoder__", dem_text), + std::runtime_error); +} + +TEST(StimDemDecoderFactory, RegisteredCreatorIsUsed) { + // static: the registry outlives this test, so the lambda must not capture + // a stack reference. + static bool registered_creator_was_called = false; + registered_creator_was_called = false; + cudaq::qec::register_stim_dem_decoder_creator( + "__stim_dem_test_decoder__", + [](const std::string &dem_text, const cudaqx::heterogeneous_map &) + -> std::unique_ptr { + registered_creator_was_called = true; + EXPECT_EQ(dem_text, "passthrough"); + cudaqx::tensor H({2u, 2u}); + cudaqx::heterogeneous_map empty; + return cudaq::qec::decoder::get("single_error_lut", H, empty); + }); + + auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_test_decoder__", + "passthrough"); + EXPECT_TRUE(registered_creator_was_called); + ASSERT_NE(d, nullptr); +} From 00acd546f86c7aaf964d5790325e1bcb3ad1700e Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 1 Jun 2026 17:39:37 -0400 Subject: [PATCH 02/14] update test case with syndrome patterns Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 21 +++- .../include/cudaq/qec/detector_error_model.h | 8 ++ libs/qec/lib/CMakeLists.txt | 3 +- libs/qec/lib/decoder_stim_dem.cpp | 116 ++++++------------ libs/qec/lib/detector_error_model.cpp | 68 ++++++++++ libs/qec/python/bindings/py_decoder.cpp | 36 +++++- libs/qec/python/tests/test_decoder.py | 35 ++++++ libs/qec/unittests/test_decoders.cpp | 102 ++++++++++++++- 8 files changed, 298 insertions(+), 91 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 0eaa0f3e..dee99f14 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -12,8 +12,11 @@ #include "cuda-qx/core/heterogeneous_map.h" #include "cuda-qx/core/tensor.h" #include "sparse_binary_matrix.h" +#include #include +#include #include +#include #include namespace cudaq::qec { @@ -433,13 +436,25 @@ using stim_dem_decoder_creator = std::function( const std::string &, const cudaqx::heterogeneous_map &)>; /// @brief Register a Stim-DEM-string creator for the named decoder. +/// @see get_decoder_from_stim_dem void register_stim_dem_decoder_creator(const std::string &name, stim_dem_decoder_creator creator); +/// @brief Unregister a previously registered Stim-DEM-string creator. No-op if +/// \p name has no registered creator. +/// @see register_stim_dem_decoder_creator +void unregister_stim_dem_decoder_creator(const std::string &name); + /// @brief Construct a decoder by name from a Stim detector error model text. -/// When no registered creator is found, the DEM is parsed and observables / -/// error rates are injected into \p options under keys \c "O" and -/// \c "error_rate_vec". +/// +/// When a Stim-DEM creator is registered for \p name it is used directly. +/// Otherwise the DEM is parsed and forwarded to the existing H-based path +/// after injecting two derived entries into \p options if they are not +/// already present: +/// - \c "O" : \c cudaqx::tensor observables_flips_matrix +/// - \c "error_rate_vec" : \c std::vector per-error probabilities +/// User-supplied values for either key win over the DEM-derived ones. +/// @see register_stim_dem_decoder_creator std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index 747f01bb..c0a03cb6 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -9,6 +9,7 @@ #include "cuda-qx/core/tensor.h" #include +#include namespace cudaq::qec { @@ -65,4 +66,11 @@ struct detector_error_model { void canonicalize_for_rounds(uint32_t num_syndromes_per_round); }; +/// @brief Parse a Stim detector error model text into a +/// \p cudaq::qec::detector_error_model. Each \c error instruction in the DEM +/// becomes a single column in \p detector_error_matrix and +/// \p observables_flips_matrix; suggested decomposition separators are +/// folded into the same column. +detector_error_model dem_from_stim_text(const std::string &dem_text); + } // namespace cudaq::qec diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index c0fec3bf..92991048 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -62,14 +62,13 @@ add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES}) add_subdirectory(decoders/plugins/example) add_subdirectory(decoders/plugins/pymatching) -# Exclude libstim symbols from libqec's exports (cf. CUDA-Q #3045). if(NOT TARGET libstim) message(FATAL_ERROR "libstim target not available; required by cudaq-qec for Stim DEM parsing.") endif() target_link_libraries(${LIBRARY_NAME} PRIVATE libstim) target_link_options(${LIBRARY_NAME} PRIVATE - $<$:-Wl,--exclude-libs,ALL> + $<$,$>:-Wl,--exclude-libs,libstim.a> ) # The TRT decoder plugin honors the tri-state `CUDAQ_QEC_BUILD_TRT_DECODER` diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index 404248e1..54560d90 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -6,120 +6,74 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -#include "stim.h" #include "cudaq/qec/decoder.h" +#include "cudaq/qec/detector_error_model.h" #include #include #include #include -#include namespace cudaq::qec { namespace { -std::pair &> -get_stim_dem_registry() { +struct stim_dem_registry { + std::recursive_mutex &mutex; + std::unordered_map ↦ +}; +stim_dem_registry get_stim_dem_registry() { static std::recursive_mutex *mutex = new std::recursive_mutex(); - static auto *registry = + static auto *map = new std::unordered_map(); - return {*mutex, *registry}; -} - -struct extracted_dem { - cudaqx::tensor H; - cudaqx::tensor O; - std::vector error_rates; -}; - -extracted_dem extract_from_stim_dem(const std::string &dem_text) { - stim::DetectorErrorModel dem(dem_text); - const std::size_t num_detectors = - static_cast(dem.count_detectors()); - const std::size_t num_observables = - static_cast(dem.count_observables()); - - std::vector> detector_hits; - std::vector> observable_hits; - std::vector rates; - - dem.iter_flatten_error_instructions([&](const stim::DemInstruction &inst) { - if (inst.arg_data.size() == 0) - throw std::runtime_error( - "Stim DEM error instruction missing probability argument"); - const double prob = inst.arg_data[0]; - std::vector dets; - std::vector obs; - for (const auto &target : inst.target_data) { - if (target.is_separator()) - continue; - if (target.is_relative_detector_id()) { - dets.push_back(static_cast(target.val())); - } else if (target.is_observable_id()) { - obs.push_back(static_cast(target.val())); - } - } - detector_hits.push_back(std::move(dets)); - observable_hits.push_back(std::move(obs)); - rates.push_back(prob); - }); - - const std::size_t num_errors = rates.size(); - extracted_dem result; - result.H = cudaqx::tensor({num_detectors, num_errors}); - result.O = cudaqx::tensor({num_observables, num_errors}); - result.error_rates = std::move(rates); - - for (std::size_t err = 0; err < num_errors; ++err) { - for (auto det : detector_hits[err]) { - if (det >= num_detectors) - throw std::runtime_error( - "Stim DEM detector id out of range while extracting H"); - result.H.at({det, err}) = 1; - } - for (auto ob : observable_hits[err]) { - if (ob >= num_observables) - throw std::runtime_error( - "Stim DEM observable id out of range while extracting O"); - result.O.at({ob, err}) = 1; - } - } - - return result; + return {*mutex, *map}; } } // namespace void register_stim_dem_decoder_creator(const std::string &name, stim_dem_decoder_creator creator) { - auto [mutex, registry] = get_stim_dem_registry(); - std::lock_guard lock(mutex); - registry[name] = std::move(creator); + auto reg = get_stim_dem_registry(); + std::lock_guard lock(reg.mutex); + reg.map[name] = std::move(creator); +} + +void unregister_stim_dem_decoder_creator(const std::string &name) { + auto reg = get_stim_dem_registry(); + std::lock_guard lock(reg.mutex); + reg.map.erase(name); } std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map options) { + stim_dem_decoder_creator creator; { - auto [mutex, registry] = get_stim_dem_registry(); - std::lock_guard lock(mutex); - auto iter = registry.find(name); - if (iter != registry.end()) - return iter->second(stim_dem_text, options); + auto reg = get_stim_dem_registry(); + std::lock_guard lock(reg.mutex); + auto iter = reg.map.find(name); + if (iter != reg.map.end()) + creator = iter->second; } + if (creator) + return creator(stim_dem_text, options); + + if (!decoder::is_registered(name)) + throw std::runtime_error( + "get_decoder_from_stim_dem: decoder \"" + name + + "\" is not registered. Run with CUDAQ_LOG_LEVEL=info to see plugin " + "diagnostics at startup."); - auto extracted = extract_from_stim_dem(stim_dem_text); + auto dem = dem_from_stim_text(stim_dem_text); cudaqx::heterogeneous_map merged = options; if (!merged.contains("O")) - merged.insert("O", extracted.O); + merged.insert("O", dem.observables_flips_matrix); if (!merged.contains("error_rate_vec")) - merged.insert("error_rate_vec", extracted.error_rates); + merged.insert("error_rate_vec", dem.error_rates); - return decoder::get(name, extracted.H, merged); + return decoder::get(name, dem.detector_error_matrix, merged); } } // namespace cudaq::qec diff --git a/libs/qec/lib/detector_error_model.cpp b/libs/qec/lib/detector_error_model.cpp index 14612ea4..5e2e6d58 100644 --- a/libs/qec/lib/detector_error_model.cpp +++ b/libs/qec/lib/detector_error_model.cpp @@ -10,8 +10,76 @@ #include "cudaq/qec/pcm_utils.h" #include "cudaq/runtime/logger/logger.h" +#include "stim.h" + namespace cudaq::qec { +detector_error_model dem_from_stim_text(const std::string &dem_text) { + stim::DetectorErrorModel dem(dem_text); + const std::size_t num_detectors = + static_cast(dem.count_detectors()); + const std::size_t num_observables = + static_cast(dem.count_observables()); + + std::vector> detector_hits; + std::vector> observable_hits; + std::vector rates; + std::size_t instruction_index = 0; + + dem.iter_flatten_error_instructions([&](const stim::DemInstruction &inst) { + if (inst.arg_data.size() == 0) + throw std::runtime_error( + "Stim DEM error instruction missing probability argument (index " + + std::to_string(instruction_index) + ")"); + const double prob = inst.arg_data[0]; + if (!(prob >= 0.0 && prob <= 1.0)) + throw std::runtime_error("Stim DEM error probability " + + std::to_string(prob) + + " out of range [0, 1] at instruction index " + + std::to_string(instruction_index)); + std::vector dets; + std::vector obs; + for (const auto &target : inst.target_data) { + if (target.is_separator()) + continue; + if (target.is_relative_detector_id()) { + dets.push_back(static_cast(target.val())); + } else if (target.is_observable_id()) { + obs.push_back(static_cast(target.val())); + } + } + detector_hits.push_back(std::move(dets)); + observable_hits.push_back(std::move(obs)); + rates.push_back(prob); + ++instruction_index; + }); + + const std::size_t num_errors = rates.size(); + detector_error_model result; + result.detector_error_matrix = + cudaqx::tensor({num_detectors, num_errors}); + result.observables_flips_matrix = + cudaqx::tensor({num_observables, num_errors}); + result.error_rates = std::move(rates); + + for (std::size_t err = 0; err < num_errors; ++err) { + for (auto det : detector_hits[err]) { + if (det >= num_detectors) + throw std::runtime_error( + "Stim DEM detector id out of range while extracting H"); + result.detector_error_matrix.at({det, err}) ^= 1; + } + for (auto ob : observable_hits[err]) { + if (ob >= num_observables) + throw std::runtime_error( + "Stim DEM observable id out of range while extracting O"); + result.observables_flips_matrix.at({ob, err}) ^= 1; + } + } + + return result; +} + std::size_t detector_error_model::num_detectors() const { auto shape = detector_error_matrix.shape(); if (shape.size() == 2) diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 707d9fb2..2e6e2d7c 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -769,14 +769,46 @@ void bindDecoder(nb::module_ &mod) { qecmod.def( "get_decoder_from_stim_dem", [](const std::string &name, const std::string &dem_text, - const nb::kwargs options) -> std::unique_ptr { + nb::kwargs options) + -> std::variant> { + if (PyDecoderRegistry::contains(name)) { + auto dem = dem_from_stim_text(dem_text); + + if (!options.contains("O")) { + const auto &O = dem.observables_flips_matrix; + size_t shape[2] = {O.shape()[0], O.shape()[1]}; + auto O_arr = nb::ndarray( + const_cast(O.data()), 2, shape, nb::none()); + options["O"] = nb::cast(O_arr).attr("copy")(); + } + if (!options.contains("error_rate_vec")) { + const auto &rates = dem.error_rates; + size_t rates_shape[1] = {rates.size()}; + auto rates_arr = nb::ndarray( + const_cast(rates.data()), 1, rates_shape, nb::none()); + options["error_rate_vec"] = nb::cast(rates_arr).attr("copy")(); + } + + const auto &H = dem.detector_error_matrix; + size_t H_shape[2] = {H.shape()[0], H.shape()[1]}; + auto H_arr = nb::ndarray( + const_cast(H.data()), 2, H_shape, nb::none()); + nb::object H_obj = nb::cast(H_arr).attr("copy")(); + return PyDecoderRegistry::get_decoder( + name, nb::cast>(H_obj), options); + } + return get_decoder_from_stim_dem(name, dem_text, hetMapFromKwargs(options)); }, "Construct a decoder by name from a Stim detector error model string. " "Observables and per-error rates from the DEM are injected into options " "under keys \"O\" and \"error_rate_vec\" when no registered Stim-DEM " - "creator is found."); + "creator is found. User-supplied values for either key win over the " + "DEM-derived ones. Python decoders registered via @qec.decoder receive " + "the parsed H and O as numpy.ndarray and error_rate_vec as a 1-D " + "numpy.ndarray of float64; to register a native DEM consumer, use the " + "C++ register_stim_dem_decoder_creator API."); qecmod.def( "get_sorted_pcm_column_indices", diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 25d31da4..3e6e39e1 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -752,5 +752,40 @@ def test_generate_random_pcm_signed_weight_rejects_negative(): seed=1) +def test_get_decoder_from_stim_dem(): + # 2 detectors, 1 observable, 3 errors. Matches the C++ + # StimDemDecoderFactory.ConstructsLutDecoderFromStimDemText DEM so the + # truth-data assertions stay in sync across language bindings. + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + + decoder = qec.get_decoder_from_stim_dem("single_error_lut", dem_text) + assert decoder is not None + assert decoder.get_syndrome_size() == 2 + assert decoder.get_block_size() == 3 + + cases = [ + ([0.0, 0.0], [0.0, 0.0, 0.0]), + ([1.0, 0.0], [1.0, 0.0, 0.0]), + ([0.0, 1.0], [0.0, 1.0, 0.0]), + ([1.0, 1.0], [0.0, 0.0, 1.0]), + ] + for syndrome, expected in cases: + result = decoder.decode(syndrome) + assert result.converged is True, f"syndrome {syndrome}" + assert list(result.result) == expected, f"syndrome {syndrome}" + + +def test_get_decoder_from_stim_dem_rejects_malformed_text(): + with pytest.raises(Exception): + qec.get_decoder_from_stim_dem("single_error_lut", "not a valid DEM") + + +def test_get_decoder_from_stim_dem_rejects_unknown_decoder(): + with pytest.raises(RuntimeError): + qec.get_decoder_from_stim_dem("__no_such_decoder__", "error(0.1) D0\n") + + if __name__ == "__main__": pytest.main() diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 42c0e636..680da665 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -780,9 +780,26 @@ error(0.05) D0 D1 EXPECT_EQ(d->get_syndrome_size(), 2u); EXPECT_EQ(d->get_block_size(), 3u); - std::vector syndrome = {1.0, 0.0}; - auto result = d->decode(syndrome); - EXPECT_EQ(result.result.size(), 3u); + struct Case { + std::vector syndrome; + std::vector expected; + }; + const std::vector cases = { + {{0.0, 0.0}, {0.0, 0.0, 0.0}}, + {{1.0, 0.0}, {1.0, 0.0, 0.0}}, + {{0.0, 1.0}, {0.0, 1.0, 0.0}}, + {{1.0, 1.0}, {0.0, 0.0, 1.0}}, + }; + for (const auto &c : cases) { + auto result = d->decode(c.syndrome); + EXPECT_TRUE(result.converged) + << "syndrome {" << c.syndrome[0] << ", " << c.syndrome[1] << "}"; + ASSERT_EQ(result.result.size(), 3u); + for (std::size_t i = 0; i < 3u; ++i) + EXPECT_FLOAT_EQ(result.result[i], c.expected[i]) + << "error " << i << " for syndrome {" << c.syndrome[0] << ", " + << c.syndrome[1] << "}"; + } } TEST(StimDemDecoderFactory, ThrowsOnMalformedStimDem) { @@ -818,4 +835,83 @@ TEST(StimDemDecoderFactory, RegisteredCreatorIsUsed) { "passthrough"); EXPECT_TRUE(registered_creator_was_called); ASSERT_NE(d, nullptr); + cudaq::qec::unregister_stim_dem_decoder_creator("__stim_dem_test_decoder__"); +} + +TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) { + const std::string dem_text = R"(error(0.1) D0 D0 +error(0.1) L0 L0 +)"; + + auto dem = cudaq::qec::dem_from_stim_text(dem_text); + ASSERT_EQ(dem.num_detectors(), 1u); + ASSERT_EQ(dem.num_observables(), 1u); + ASSERT_EQ(dem.num_error_mechanisms(), 2u); + EXPECT_EQ(dem.detector_error_matrix.at({0u, 0u}), 0u) + << "duplicate D0 in error 0 should XOR-cancel to 0"; + EXPECT_EQ(dem.observables_flips_matrix.at({0u, 1u}), 0u) + << "duplicate L0 in error 1 should XOR-cancel to 0"; +} + +TEST(StimDemDecoderFactory, ThrowsOnProbabilityOutOfRange) { + const std::string dem_text = "error(1.5) D0\n"; + EXPECT_THROW( + cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text), + std::runtime_error); +} + +TEST(StimDemDecoderFactory, RegisteredCreatorTakesPrecedenceOverFallback) { + static bool creator_was_called = false; + creator_was_called = false; + cudaq::qec::register_stim_dem_decoder_creator( + "single_error_lut", + [](const std::string &, const cudaqx::heterogeneous_map &) + -> std::unique_ptr { + creator_was_called = true; + cudaqx::tensor H({2u, 2u}); + cudaqx::heterogeneous_map empty; + return cudaq::qec::decoder::get("single_error_lut", H, empty); + }); + + const std::string dem_text = "error(0.1) D0 L0\n"; + auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text); + EXPECT_TRUE(creator_was_called); + ASSERT_NE(d, nullptr); + cudaq::qec::unregister_stim_dem_decoder_creator("single_error_lut"); +} + +TEST(StimDemDecoderFactory, UserOptionsAreNotOverwritten) { + const std::string dem_text = R"(error(0.1) D0 L0 +error(0.1) D1 L0 +error(0.05) D0 D1 +)"; + cudaqx::heterogeneous_map opts; + opts.insert("error_rate_vec", std::vector{0.5}); // wrong size + EXPECT_THROW( + cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text, opts), + std::runtime_error); +} + +TEST(StimDemDecoderFactory, RegisteredCreatorReceivesUserOptionsVerbatim) { + static std::vector observed_rates; + observed_rates.clear(); + cudaq::qec::register_stim_dem_decoder_creator( + "__stim_dem_echo__", + [](const std::string &, const cudaqx::heterogeneous_map &opts) + -> std::unique_ptr { + if (opts.contains("error_rate_vec")) + observed_rates = opts.get>("error_rate_vec"); + cudaqx::tensor H({2u, 2u}); + cudaqx::heterogeneous_map empty; + return cudaq::qec::decoder::get("single_error_lut", H, empty); + }); + + const std::vector user_rates = {0.42, 0.13, 0.07}; + cudaqx::heterogeneous_map opts; + opts.insert("error_rate_vec", user_rates); + auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_echo__", + "error(0.5) D0\n", opts); + ASSERT_NE(d, nullptr); + EXPECT_EQ(observed_rates, user_rates); + cudaq::qec::unregister_stim_dem_decoder_creator("__stim_dem_echo__"); } From dde324998e6aec292e757ec831e3292ee56da92b Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 1 Jun 2026 23:30:25 -0400 Subject: [PATCH 03/14] fix ci failure Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 27 +++- .../include/cudaq/qec/detector_error_model.h | 8 +- libs/qec/lib/CMakeLists.txt | 1 + libs/qec/lib/decoder_stim_dem.cpp | 38 ++++-- libs/qec/lib/detector_error_model.cpp | 23 +++- libs/qec/python/bindings/py_decoder.cpp | 44 ++++--- libs/qec/python/cudaq_qec/__init__.py | 1 + libs/qec/python/tests/test_decoder.py | 16 ++- libs/qec/unittests/CMakeLists.txt | 4 +- libs/qec/unittests/test_decoders.cpp | 115 +++++++++++++++--- 10 files changed, 222 insertions(+), 55 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index dee99f14..6232aa2a 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -431,11 +431,29 @@ std::unique_ptr get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, const cudaqx::heterogeneous_map options = {}); +struct detector_error_model; + +/// @brief DEM-derived defaults; pointers alias into the source `dem`. +struct dem_default_values { + const cudaqx::tensor *O = nullptr; + const std::vector *error_rate_vec = nullptr; +}; + +/// @brief Return DEM defaults for any key not already supplied by the user. +/// Shared by `get_decoder_from_stim_dem` and its Python binding. +dem_default_values dem_defaults_for_missing_keys( + const std::function &contains_user_key, + const detector_error_model &dem); + /// @brief Creator function for a decoder constructed from a Stim DEM string. using stim_dem_decoder_creator = std::function( const std::string &, const cudaqx::heterogeneous_map &)>; /// @brief Register a Stim-DEM-string creator for the named decoder. +/// @param name Decoder name; same name used by `get_decoder_from_stim_dem`. +/// @param creator Builds a decoder from the raw DEM string + options. Takes +/// precedence over the H/O fallback; must not re-enter the registry (the +/// factory copies it out before invoking). /// @see get_decoder_from_stim_dem void register_stim_dem_decoder_creator(const std::string &name, stim_dem_decoder_creator creator); @@ -451,12 +469,15 @@ void unregister_stim_dem_decoder_creator(const std::string &name); /// Otherwise the DEM is parsed and forwarded to the existing H-based path /// after injecting two derived entries into \p options if they are not /// already present: -/// - \c "O" : \c cudaqx::tensor observables_flips_matrix -/// - \c "error_rate_vec" : \c std::vector per-error probabilities +/// - `"O"` : `cudaqx::tensor` observables_flips_matrix +/// - `"error_rate_vec"` : `std::vector` per-error probabilities /// User-supplied values for either key win over the DEM-derived ones. +/// +/// @note Decoders that need full DEM metadata (e.g. Chromobius) must +/// register a Stim-DEM creator; the fallback only extracts H/O/rates. /// @see register_stim_dem_decoder_creator std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, - const cudaqx::heterogeneous_map options = {}); + const cudaqx::heterogeneous_map &options = {}); } // namespace cudaq::qec diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index c0a03cb6..6062083d 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -67,10 +67,16 @@ struct detector_error_model { }; /// @brief Parse a Stim detector error model text into a -/// \p cudaq::qec::detector_error_model. Each \c error instruction in the DEM +/// \p cudaq::qec::detector_error_model. Each `error` instruction in the DEM /// becomes a single column in \p detector_error_matrix and /// \p observables_flips_matrix; suggested decomposition separators are /// folded into the same column. +/// +/// @note Lossy: only detector/observable flips and error probabilities +/// are extracted. Annotations (`detector`, `logical_observable`), +/// suggested-decomposition separators, and \p error_ids are dropped. +/// Decoders that need the full DEM (e.g. Chromobius) must consume the +/// raw string via `register_stim_dem_decoder_creator`. detector_error_model dem_from_stim_text(const std::string &dem_text); } // namespace cudaq::qec diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index 92991048..2a60d52e 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -62,6 +62,7 @@ add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES}) add_subdirectory(decoders/plugins/example) add_subdirectory(decoders/plugins/pymatching) +# libstim comes from the parent build (CUDA-Q). if(NOT TARGET libstim) message(FATAL_ERROR "libstim target not available; required by cudaq-qec for Stim DEM parsing.") diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index 54560d90..e72eccdb 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -18,12 +18,16 @@ namespace cudaq::qec { namespace { +// std::mutex is enough: the factory copies the creator out before +// invoking, so creators cannot re-enter the registry. struct stim_dem_registry { - std::recursive_mutex &mutex; + std::mutex &mutex; std::unordered_map ↦ }; stim_dem_registry get_stim_dem_registry() { - static std::recursive_mutex *mutex = new std::recursive_mutex(); + // Heap-allocated to outlive static destructors (plugin dlclose unregister + // path); matches the cudaqx extension_point pattern. See extension_point.h. + static std::mutex *mutex = new std::mutex(); static auto *map = new std::unordered_map(); return {*mutex, *map}; @@ -31,27 +35,38 @@ stim_dem_registry get_stim_dem_registry() { } // namespace +dem_default_values dem_defaults_for_missing_keys( + const std::function &contains_user_key, + const detector_error_model &dem) { + dem_default_values out; + if (!contains_user_key("O")) + out.O = &dem.observables_flips_matrix; + if (!contains_user_key("error_rate_vec")) + out.error_rate_vec = &dem.error_rates; + return out; +} + void register_stim_dem_decoder_creator(const std::string &name, stim_dem_decoder_creator creator) { auto reg = get_stim_dem_registry(); - std::lock_guard lock(reg.mutex); + std::lock_guard lock(reg.mutex); reg.map[name] = std::move(creator); } void unregister_stim_dem_decoder_creator(const std::string &name) { auto reg = get_stim_dem_registry(); - std::lock_guard lock(reg.mutex); + std::lock_guard lock(reg.mutex); reg.map.erase(name); } std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, - const cudaqx::heterogeneous_map options) { + const cudaqx::heterogeneous_map &options) { stim_dem_decoder_creator creator; { auto reg = get_stim_dem_registry(); - std::lock_guard lock(reg.mutex); + std::lock_guard lock(reg.mutex); auto iter = reg.map.find(name); if (iter != reg.map.end()) creator = iter->second; @@ -68,10 +83,13 @@ get_decoder_from_stim_dem(const std::string &name, auto dem = dem_from_stim_text(stim_dem_text); cudaqx::heterogeneous_map merged = options; - if (!merged.contains("O")) - merged.insert("O", dem.observables_flips_matrix); - if (!merged.contains("error_rate_vec")) - merged.insert("error_rate_vec", dem.error_rates); + // Keep in sync with the Python binding in py_decoder.cpp. + auto defaults = dem_defaults_for_missing_keys( + [&](const std::string &key) { return merged.contains(key); }, dem); + if (defaults.O) + merged.insert("O", *defaults.O); + if (defaults.error_rate_vec) + merged.insert("error_rate_vec", *defaults.error_rate_vec); return decoder::get(name, dem.detector_error_matrix, merged); } diff --git a/libs/qec/lib/detector_error_model.cpp b/libs/qec/lib/detector_error_model.cpp index 5e2e6d58..0376cf78 100644 --- a/libs/qec/lib/detector_error_model.cpp +++ b/libs/qec/lib/detector_error_model.cpp @@ -15,7 +15,14 @@ namespace cudaq::qec { detector_error_model dem_from_stim_text(const std::string &dem_text) { - stim::DetectorErrorModel dem(dem_text); + auto dem = [&dem_text]() { + try { + return stim::DetectorErrorModel(dem_text); + } catch (const std::exception &e) { + throw std::runtime_error(std::string("Stim DEM parse failed: ") + + e.what()); + } + }(); const std::size_t num_detectors = static_cast(dem.count_detectors()); const std::size_t num_observables = @@ -46,6 +53,15 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { dets.push_back(static_cast(target.val())); } else if (target.is_observable_id()) { obs.push_back(static_cast(target.val())); + } else { + // Forward-compat tripwire; unreachable today (stim's three + // DemTarget categories are exhaustive -- pinned by + // StimDemTargetCategoriesAreExhaustive). + throw std::runtime_error( + "Stim DEM error instruction (index " + + std::to_string(instruction_index) + + ") contains an unsupported target kind; only D* (detector) and " + "L* (observable) targets are supported by the fallback parser"); } } detector_hits.push_back(std::move(dets)); @@ -55,6 +71,11 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { }); const std::size_t num_errors = rates.size(); + // Reject zero-column H at the boundary instead of letting decoders + // crash with block_size == 0. + if (num_errors == 0) + throw std::runtime_error( + "Stim DEM contains no error mechanisms after flattening"); detector_error_model result; result.detector_error_matrix = cudaqx::tensor({num_detectors, num_errors}); diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 2e6e2d7c..8c4ca4a8 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -339,6 +339,22 @@ makeBatchDecoderResult(const std::vector &results) { }; } +// Wrap a borrowed cudaqx buffer in a NumPy array and force a Python-side copy, +// so the returned object owns its data. +nb::object toPyArray(const cudaqx::tensor &t) { + size_t shape[2] = {t.shape()[0], t.shape()[1]}; + auto arr = nb::ndarray(const_cast(t.data()), 2, + shape, nb::none()); + return nb::cast(arr).attr("copy")(); +} + +nb::object toPyArray(const std::vector &v) { + size_t shape[1] = {v.size()}; + auto arr = nb::ndarray(const_cast(v.data()), 1, + shape, nb::none()); + return nb::cast(arr).attr("copy")(); +} + } // namespace void bindDecoder(nb::module_ &mod) { @@ -774,26 +790,16 @@ void bindDecoder(nb::module_ &mod) { if (PyDecoderRegistry::contains(name)) { auto dem = dem_from_stim_text(dem_text); - if (!options.contains("O")) { - const auto &O = dem.observables_flips_matrix; - size_t shape[2] = {O.shape()[0], O.shape()[1]}; - auto O_arr = nb::ndarray( - const_cast(O.data()), 2, shape, nb::none()); - options["O"] = nb::cast(O_arr).attr("copy")(); - } - if (!options.contains("error_rate_vec")) { - const auto &rates = dem.error_rates; - size_t rates_shape[1] = {rates.size()}; - auto rates_arr = nb::ndarray( - const_cast(rates.data()), 1, rates_shape, nb::none()); - options["error_rate_vec"] = nb::cast(rates_arr).attr("copy")(); - } + // Keep in sync with the C++ fallback in decoder_stim_dem.cpp. + auto defaults = dem_defaults_for_missing_keys( + [&](const std::string &key) { return options.contains(key); }, + dem); + if (defaults.O) + options["O"] = toPyArray(*defaults.O); + if (defaults.error_rate_vec) + options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); - const auto &H = dem.detector_error_matrix; - size_t H_shape[2] = {H.shape()[0], H.shape()[1]}; - auto H_arr = nb::ndarray( - const_cast(H.data()), 2, H_shape, nb::none()); - nb::object H_obj = nb::cast(H_arr).attr("copy")(); + nb::object H_obj = toPyArray(dem.detector_error_matrix); return PyDecoderRegistry::get_decoder( name, nb::cast>(H_obj), options); } diff --git a/libs/qec/python/cudaq_qec/__init__.py b/libs/qec/python/cudaq_qec/__init__.py index d2ae1add..ac037c23 100644 --- a/libs/qec/python/cudaq_qec/__init__.py +++ b/libs/qec/python/cudaq_qec/__init__.py @@ -84,6 +84,7 @@ def checked_decode_batch(self, *args, **kwargs): get_code = qecrt.get_code get_available_codes = qecrt.get_available_codes get_decoder = qecrt.get_decoder +get_decoder_from_stim_dem = qecrt.get_decoder_from_stim_dem DecoderResult = qecrt.DecoderResult BatchDecoderResult = qecrt.BatchDecoderResult DetectorErrorModel = qecrt.DetectorErrorModel diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 3e6e39e1..98dddd2d 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -778,13 +778,25 @@ def test_get_decoder_from_stim_dem(): def test_get_decoder_from_stim_dem_rejects_malformed_text(): - with pytest.raises(Exception): + with pytest.raises(RuntimeError): qec.get_decoder_from_stim_dem("single_error_lut", "not a valid DEM") def test_get_decoder_from_stim_dem_rejects_unknown_decoder(): + with pytest.raises(RuntimeError, match="__no_such_decoder__"): + qec.get_decoder_from_stim_dem("__no_such_decoder__", + "error(0.1) D0 L0\n") + + +def test_get_decoder_from_stim_dem_user_O_wins_over_dem_derived(): + # Wrong-shape user O trips PyMatching's validation; silent overwrite + # by the DEM-derived O would suppress the throw. + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + bad_O = np.zeros((1, 4), dtype=np.uint8) with pytest.raises(RuntimeError): - qec.get_decoder_from_stim_dem("__no_such_decoder__", "error(0.1) D0\n") + qec.get_decoder_from_stim_dem("pymatching", dem_text, O=bad_O) if __name__ == "__main__": diff --git a/libs/qec/unittests/CMakeLists.txt b/libs/qec/unittests/CMakeLists.txt index 0e878dac..97aeb07b 100644 --- a/libs/qec/unittests/CMakeLists.txt +++ b/libs/qec/unittests/CMakeLists.txt @@ -35,7 +35,9 @@ find_package(CUDAToolkit REQUIRED) add_compile_options(-Wno-attributes) add_executable(test_decoders test_decoders.cpp decoders/sample_decoder.cpp) -target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq) +# Direct libstim link for StimDemTargetCategoriesAreExhaustive; +# cudaq-qec hides stim symbols via --exclude-libs. +target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq libstim) add_dependencies(CUDAQXQECUnitTests test_decoders) gtest_discover_tests(test_decoders) diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 680da665..b19f401b 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -6,7 +6,9 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ +#include "stim.h" #include "cudaq/qec/decoder.h" +#include "cudaq/qec/detector_error_model.h" #include "cudaq/qec/pcm_utils.h" #include #include @@ -802,24 +804,16 @@ error(0.05) D0 D1 } } -TEST(StimDemDecoderFactory, ThrowsOnMalformedStimDem) { - EXPECT_THROW(cudaq::qec::get_decoder_from_stim_dem("single_error_lut", - "not a valid DEM"), - std::exception); -} - -TEST(StimDemDecoderFactory, ThrowsOnUnknownDecoderName) { - const std::string dem_text = "error(0.1) D0\n"; - EXPECT_THROW( - cudaq::qec::get_decoder_from_stim_dem("__no_such_decoder__", dem_text), - std::runtime_error); -} - TEST(StimDemDecoderFactory, RegisteredCreatorIsUsed) { - // static: the registry outlives this test, so the lambda must not capture - // a stack reference. + // static: the registry outlives this test, so the lambda must not + // capture a stack reference. static bool registered_creator_was_called = false; registered_creator_was_called = false; + // RAII guard: restore the registry slot on any exit path. + struct CreatorGuard { + const char *name; + ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } + } guard{"__stim_dem_test_decoder__"}; cudaq::qec::register_stim_dem_decoder_creator( "__stim_dem_test_decoder__", [](const std::string &dem_text, const cudaqx::heterogeneous_map &) @@ -835,7 +829,6 @@ TEST(StimDemDecoderFactory, RegisteredCreatorIsUsed) { "passthrough"); EXPECT_TRUE(registered_creator_was_called); ASSERT_NE(d, nullptr); - cudaq::qec::unregister_stim_dem_decoder_creator("__stim_dem_test_decoder__"); } TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) { @@ -860,9 +853,59 @@ TEST(StimDemDecoderFactory, ThrowsOnProbabilityOutOfRange) { std::runtime_error); } +TEST(StimDemDecoderFactory, ThrowsOnMalformedStimDem) { + EXPECT_THROW(cudaq::qec::get_decoder_from_stim_dem("single_error_lut", + "not a valid DEM"), + std::runtime_error); +} + +TEST(StimDemDecoderFactory, ThrowsOnUnknownDecoderName) { + const std::string dem_text = "error(0.1) D0 L0\n"; + EXPECT_THROW( + cudaq::qec::get_decoder_from_stim_dem("__no_such_decoder__", dem_text), + std::runtime_error); +} + +TEST(StimDemDecoderFactory, ThrowsOnEmptyErrorMechanisms) { + // A bare detector(...) line parses but yields zero error mechanisms. + const std::string dem_text = "detector(0, 0, 0)\n"; + EXPECT_THROW( + cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text), + std::runtime_error); +} + +TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) { + // Pins the invariant that keeps the defensive throw in + // dem_from_stim_text unreachable. Fires first if stim's encoding + // changes. + const std::vector samples = { + stim::DemTarget::separator(), + stim::DemTarget::relative_detector_id(0), + stim::DemTarget::relative_detector_id(42), + stim::DemTarget::observable_id(0), + stim::DemTarget::observable_id(7), + }; + for (const auto &t : samples) { + const int kinds = static_cast(t.is_separator()) + + static_cast(t.is_relative_detector_id()) + + static_cast(t.is_observable_id()); + EXPECT_EQ(kinds, 1) << "DemTarget " << t.str() << " matched " << kinds + << " predicates; expected exactly 1"; + } +} + TEST(StimDemDecoderFactory, RegisteredCreatorTakesPrecedenceOverFallback) { + // Real decoder name on purpose: pins creator-over-fallback for an + // existing decoder (a sentinel name would just retest + // RegisteredCreatorIsUsed). Mutates the real "single_error_lut" slot; + // safe only because gtest runs tests serially in a binary. If this + // suite is ever parallelized, register against a sentinel name instead. static bool creator_was_called = false; creator_was_called = false; + struct CreatorGuard { + const char *name; + ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } + } guard{"single_error_lut"}; cudaq::qec::register_stim_dem_decoder_creator( "single_error_lut", [](const std::string &, const cudaqx::heterogeneous_map &) @@ -877,7 +920,6 @@ TEST(StimDemDecoderFactory, RegisteredCreatorTakesPrecedenceOverFallback) { auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text); EXPECT_TRUE(creator_was_called); ASSERT_NE(d, nullptr); - cudaq::qec::unregister_stim_dem_decoder_creator("single_error_lut"); } TEST(StimDemDecoderFactory, UserOptionsAreNotOverwritten) { @@ -892,9 +934,47 @@ error(0.05) D0 D1 std::runtime_error); } +TEST(StimDemDecoderFactory, UserSuppliedObservablesAreNotOverwritten) { + // Symmetric with UserOptionsAreNotOverwritten but for "O", via an + // echo creator (decoder-validation-independent). + static std::vector observed_O_shape; + observed_O_shape.clear(); + struct CreatorGuard { + const char *name; + ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } + } guard{"__stim_dem_echo_O__"}; + cudaq::qec::register_stim_dem_decoder_creator( + "__stim_dem_echo_O__", + [](const std::string &, const cudaqx::heterogeneous_map &opts) + -> std::unique_ptr { + if (opts.contains("O")) { + auto O = opts.get>("O"); + observed_O_shape = O.shape(); + } + cudaqx::tensor H({2u, 2u}); + cudaqx::heterogeneous_map empty; + return cudaq::qec::decoder::get("single_error_lut", H, empty); + }); + + // Distinctive shape; a match proves the user's O reached the creator. + cudaqx::tensor user_O({7u, 11u}); + cudaqx::heterogeneous_map opts; + opts.insert("O", user_O); + auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_echo_O__", + "error(0.1) D0 L0\n", opts); + ASSERT_NE(d, nullptr); + ASSERT_EQ(observed_O_shape.size(), 2u); + EXPECT_EQ(observed_O_shape[0], 7u); + EXPECT_EQ(observed_O_shape[1], 11u); +} + TEST(StimDemDecoderFactory, RegisteredCreatorReceivesUserOptionsVerbatim) { static std::vector observed_rates; observed_rates.clear(); + struct CreatorGuard { + const char *name; + ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } + } guard{"__stim_dem_echo__"}; cudaq::qec::register_stim_dem_decoder_creator( "__stim_dem_echo__", [](const std::string &, const cudaqx::heterogeneous_map &opts) @@ -913,5 +993,4 @@ TEST(StimDemDecoderFactory, RegisteredCreatorReceivesUserOptionsVerbatim) { "error(0.5) D0\n", opts); ASSERT_NE(d, nullptr); EXPECT_EQ(observed_rates, user_rates); - cudaq::qec::unregister_stim_dem_decoder_creator("__stim_dem_echo__"); } From 541f12d52dfd3c2536461ce7c3f3bd2c3343e9ee Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Tue, 2 Jun 2026 16:55:27 -0400 Subject: [PATCH 04/14] address pr comments Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 34 ++--- .../include/cudaq/qec/detector_error_model.h | 5 +- libs/qec/lib/decoder_stim_dem.cpp | 46 ------- libs/qec/python/bindings/py_decoder.cpp | 9 +- libs/qec/unittests/test_decoders.cpp | 116 ------------------ 5 files changed, 16 insertions(+), 194 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 6232aa2a..6a0e4481 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -445,37 +445,21 @@ dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem); -/// @brief Creator function for a decoder constructed from a Stim DEM string. -using stim_dem_decoder_creator = std::function( - const std::string &, const cudaqx::heterogeneous_map &)>; - -/// @brief Register a Stim-DEM-string creator for the named decoder. -/// @param name Decoder name; same name used by `get_decoder_from_stim_dem`. -/// @param creator Builds a decoder from the raw DEM string + options. Takes -/// precedence over the H/O fallback; must not re-enter the registry (the -/// factory copies it out before invoking). -/// @see get_decoder_from_stim_dem -void register_stim_dem_decoder_creator(const std::string &name, - stim_dem_decoder_creator creator); - -/// @brief Unregister a previously registered Stim-DEM-string creator. No-op if -/// \p name has no registered creator. -/// @see register_stim_dem_decoder_creator -void unregister_stim_dem_decoder_creator(const std::string &name); - /// @brief Construct a decoder by name from a Stim detector error model text. /// -/// When a Stim-DEM creator is registered for \p name it is used directly. -/// Otherwise the DEM is parsed and forwarded to the existing H-based path -/// after injecting two derived entries into \p options if they are not -/// already present: +/// Thin wrapper over \c dem_from_stim_text: parses the DEM and forwards to +/// the existing H-based \c decoder::get after injecting two derived entries +/// into \p options if they are not already present: /// - `"O"` : `cudaqx::tensor` observables_flips_matrix /// - `"error_rate_vec"` : `std::vector` per-error probabilities /// User-supplied values for either key win over the DEM-derived ones. /// -/// @note Decoders that need full DEM metadata (e.g. Chromobius) must -/// register a Stim-DEM creator; the fallback only extracts H/O/rates. -/// @see register_stim_dem_decoder_creator +/// @note Lossy: detector annotations, decomposition separators, and +/// `error_ids` are dropped. Sufficient for matching-style / H-based +/// decoders (LUT, NV, sliding_window, TRT, PyMatching). Decoders that +/// need full DEM metadata (e.g. Chromobius detector color/basis) require +/// the planned \c detector_coords extension on +/// \c cudaq::qec::detector_error_model; tracked as a follow-up. std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index 6062083d..5acef6f9 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -75,8 +75,9 @@ struct detector_error_model { /// @note Lossy: only detector/observable flips and error probabilities /// are extracted. Annotations (`detector`, `logical_observable`), /// suggested-decomposition separators, and \p error_ids are dropped. -/// Decoders that need the full DEM (e.g. Chromobius) must consume the -/// raw string via `register_stim_dem_decoder_creator`. +/// Decoders that need full DEM metadata (e.g. Chromobius detector +/// color/basis) require the planned \p detector_coords extension on +/// \p detector_error_model; tracked as a follow-up. detector_error_model dem_from_stim_text(const std::string &dem_text); } // namespace cudaq::qec diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index e72eccdb..4aedaa37 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -1,7 +1,6 @@ /******************************************************************************* * Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. * * All rights reserved. * - * * * This source code and the accompanying materials are made available under * * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ @@ -9,32 +8,11 @@ #include "cudaq/qec/decoder.h" #include "cudaq/qec/detector_error_model.h" -#include #include #include -#include namespace cudaq::qec { -namespace { - -// std::mutex is enough: the factory copies the creator out before -// invoking, so creators cannot re-enter the registry. -struct stim_dem_registry { - std::mutex &mutex; - std::unordered_map ↦ -}; -stim_dem_registry get_stim_dem_registry() { - // Heap-allocated to outlive static destructors (plugin dlclose unregister - // path); matches the cudaqx extension_point pattern. See extension_point.h. - static std::mutex *mutex = new std::mutex(); - static auto *map = - new std::unordered_map(); - return {*mutex, *map}; -} - -} // namespace - dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem) { @@ -46,34 +24,10 @@ dem_default_values dem_defaults_for_missing_keys( return out; } -void register_stim_dem_decoder_creator(const std::string &name, - stim_dem_decoder_creator creator) { - auto reg = get_stim_dem_registry(); - std::lock_guard lock(reg.mutex); - reg.map[name] = std::move(creator); -} - -void unregister_stim_dem_decoder_creator(const std::string &name) { - auto reg = get_stim_dem_registry(); - std::lock_guard lock(reg.mutex); - reg.map.erase(name); -} - std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map &options) { - stim_dem_decoder_creator creator; - { - auto reg = get_stim_dem_registry(); - std::lock_guard lock(reg.mutex); - auto iter = reg.map.find(name); - if (iter != reg.map.end()) - creator = iter->second; - } - if (creator) - return creator(stim_dem_text, options); - if (!decoder::is_registered(name)) throw std::runtime_error( "get_decoder_from_stim_dem: decoder \"" + name + diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 8c4ca4a8..9b9425d2 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -808,13 +808,12 @@ void bindDecoder(nb::module_ &mod) { hetMapFromKwargs(options)); }, "Construct a decoder by name from a Stim detector error model string. " - "Observables and per-error rates from the DEM are injected into options " - "under keys \"O\" and \"error_rate_vec\" when no registered Stim-DEM " - "creator is found. User-supplied values for either key win over the " + "Thin wrapper over dem_from_stim_text: observables and per-error rates " + "from the DEM are injected into options under keys \"O\" and " + "\"error_rate_vec\". User-supplied values for either key win over the " "DEM-derived ones. Python decoders registered via @qec.decoder receive " "the parsed H and O as numpy.ndarray and error_rate_vec as a 1-D " - "numpy.ndarray of float64; to register a native DEM consumer, use the " - "C++ register_stim_dem_decoder_creator API."); + "numpy.ndarray of float64."); qecmod.def( "get_sorted_pcm_column_indices", diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index b19f401b..455d1e69 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -804,33 +804,6 @@ error(0.05) D0 D1 } } -TEST(StimDemDecoderFactory, RegisteredCreatorIsUsed) { - // static: the registry outlives this test, so the lambda must not - // capture a stack reference. - static bool registered_creator_was_called = false; - registered_creator_was_called = false; - // RAII guard: restore the registry slot on any exit path. - struct CreatorGuard { - const char *name; - ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } - } guard{"__stim_dem_test_decoder__"}; - cudaq::qec::register_stim_dem_decoder_creator( - "__stim_dem_test_decoder__", - [](const std::string &dem_text, const cudaqx::heterogeneous_map &) - -> std::unique_ptr { - registered_creator_was_called = true; - EXPECT_EQ(dem_text, "passthrough"); - cudaqx::tensor H({2u, 2u}); - cudaqx::heterogeneous_map empty; - return cudaq::qec::decoder::get("single_error_lut", H, empty); - }); - - auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_test_decoder__", - "passthrough"); - EXPECT_TRUE(registered_creator_was_called); - ASSERT_NE(d, nullptr); -} - TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) { const std::string dem_text = R"(error(0.1) D0 D0 error(0.1) L0 L0 @@ -894,34 +867,6 @@ TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) { } } -TEST(StimDemDecoderFactory, RegisteredCreatorTakesPrecedenceOverFallback) { - // Real decoder name on purpose: pins creator-over-fallback for an - // existing decoder (a sentinel name would just retest - // RegisteredCreatorIsUsed). Mutates the real "single_error_lut" slot; - // safe only because gtest runs tests serially in a binary. If this - // suite is ever parallelized, register against a sentinel name instead. - static bool creator_was_called = false; - creator_was_called = false; - struct CreatorGuard { - const char *name; - ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } - } guard{"single_error_lut"}; - cudaq::qec::register_stim_dem_decoder_creator( - "single_error_lut", - [](const std::string &, const cudaqx::heterogeneous_map &) - -> std::unique_ptr { - creator_was_called = true; - cudaqx::tensor H({2u, 2u}); - cudaqx::heterogeneous_map empty; - return cudaq::qec::decoder::get("single_error_lut", H, empty); - }); - - const std::string dem_text = "error(0.1) D0 L0\n"; - auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text); - EXPECT_TRUE(creator_was_called); - ASSERT_NE(d, nullptr); -} - TEST(StimDemDecoderFactory, UserOptionsAreNotOverwritten) { const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 @@ -933,64 +878,3 @@ error(0.05) D0 D1 cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text, opts), std::runtime_error); } - -TEST(StimDemDecoderFactory, UserSuppliedObservablesAreNotOverwritten) { - // Symmetric with UserOptionsAreNotOverwritten but for "O", via an - // echo creator (decoder-validation-independent). - static std::vector observed_O_shape; - observed_O_shape.clear(); - struct CreatorGuard { - const char *name; - ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } - } guard{"__stim_dem_echo_O__"}; - cudaq::qec::register_stim_dem_decoder_creator( - "__stim_dem_echo_O__", - [](const std::string &, const cudaqx::heterogeneous_map &opts) - -> std::unique_ptr { - if (opts.contains("O")) { - auto O = opts.get>("O"); - observed_O_shape = O.shape(); - } - cudaqx::tensor H({2u, 2u}); - cudaqx::heterogeneous_map empty; - return cudaq::qec::decoder::get("single_error_lut", H, empty); - }); - - // Distinctive shape; a match proves the user's O reached the creator. - cudaqx::tensor user_O({7u, 11u}); - cudaqx::heterogeneous_map opts; - opts.insert("O", user_O); - auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_echo_O__", - "error(0.1) D0 L0\n", opts); - ASSERT_NE(d, nullptr); - ASSERT_EQ(observed_O_shape.size(), 2u); - EXPECT_EQ(observed_O_shape[0], 7u); - EXPECT_EQ(observed_O_shape[1], 11u); -} - -TEST(StimDemDecoderFactory, RegisteredCreatorReceivesUserOptionsVerbatim) { - static std::vector observed_rates; - observed_rates.clear(); - struct CreatorGuard { - const char *name; - ~CreatorGuard() { cudaq::qec::unregister_stim_dem_decoder_creator(name); } - } guard{"__stim_dem_echo__"}; - cudaq::qec::register_stim_dem_decoder_creator( - "__stim_dem_echo__", - [](const std::string &, const cudaqx::heterogeneous_map &opts) - -> std::unique_ptr { - if (opts.contains("error_rate_vec")) - observed_rates = opts.get>("error_rate_vec"); - cudaqx::tensor H({2u, 2u}); - cudaqx::heterogeneous_map empty; - return cudaq::qec::decoder::get("single_error_lut", H, empty); - }); - - const std::vector user_rates = {0.42, 0.13, 0.07}; - cudaqx::heterogeneous_map opts; - opts.insert("error_rate_vec", user_rates); - auto d = cudaq::qec::get_decoder_from_stim_dem("__stim_dem_echo__", - "error(0.5) D0\n", opts); - ASSERT_NE(d, nullptr); - EXPECT_EQ(observed_rates, user_rates); -} From 87504ac9590682d9e73a8c435c325a2fe29e9e5e Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Wed, 3 Jun 2026 16:37:43 -0400 Subject: [PATCH 05/14] Expose Stim DEM parsing in Python Signed-off-by: vedika-saravanan --- libs/qec/python/bindings/py_decoder.cpp | 5 +++++ libs/qec/python/cudaq_qec/__init__.py | 1 + libs/qec/python/tests/test_decoder.py | 17 +++++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 9b9425d2..d15cf7d8 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -699,6 +699,11 @@ void bindDecoder(nb::module_ &mod) { )pbdoc", nb::arg("num_syndromes_per_round")); + qecmod.def( + "dem_from_stim_text", &dem_from_stim_text, + "Parse a Stim detector error model string into a DetectorErrorModel.", + nb::arg("dem_text")); + // Expose decorator function that handles inheritance qecmod.def("decoder", [&](const std::string &name) { return nb::cpp_function([name](nb::object decoder_class) -> nb::object { diff --git a/libs/qec/python/cudaq_qec/__init__.py b/libs/qec/python/cudaq_qec/__init__.py index ac037c23..0447cc6d 100644 --- a/libs/qec/python/cudaq_qec/__init__.py +++ b/libs/qec/python/cudaq_qec/__init__.py @@ -85,6 +85,7 @@ def checked_decode_batch(self, *args, **kwargs): get_available_codes = qecrt.get_available_codes get_decoder = qecrt.get_decoder get_decoder_from_stim_dem = qecrt.get_decoder_from_stim_dem +dem_from_stim_text = qecrt.dem_from_stim_text DecoderResult = qecrt.DecoderResult BatchDecoderResult = qecrt.BatchDecoderResult DetectorErrorModel = qecrt.DetectorErrorModel diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 98dddd2d..482458b2 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -777,6 +777,23 @@ def test_get_decoder_from_stim_dem(): assert list(result.result) == expected, f"syndrome {syndrome}" +def test_dem_from_stim_text_explicit_parse_then_get_decoder(): + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + + dem = qec.dem_from_stim_text(dem_text) + assert isinstance(dem, qec.DetectorErrorModel) + assert dem.num_detectors() == 2 + assert dem.num_error_mechanisms() == 3 + assert dem.num_observables() == 1 + assert dem.detector_error_matrix.shape == (2, 3) + + decoder = qec.get_decoder("single_error_lut", dem.detector_error_matrix) + assert decoder.get_syndrome_size() == 2 + assert decoder.get_block_size() == 3 + + def test_get_decoder_from_stim_dem_rejects_malformed_text(): with pytest.raises(RuntimeError): qec.get_decoder_from_stim_dem("single_error_lut", "not a valid DEM") From eb9d05145e508a3541aec4246ada81c439d06b40 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Fri, 5 Jun 2026 09:38:11 -0400 Subject: [PATCH 06/14] wip Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 127 +++++++++++++++--- libs/qec/lib/decoder.cpp | 10 +- libs/qec/lib/decoder_stim_dem.cpp | 32 ++--- libs/qec/lib/decoders/lut.cpp | 8 +- .../example/single_error_lut_example.cpp | 5 +- .../plugins/pymatching/pymatching.cpp | 4 +- .../plugins/trt_decoder/trt_decoder.cpp | 4 +- libs/qec/lib/decoders/sliding_window.h | 4 +- libs/qec/python/bindings/py_decoder.cpp | 66 +++++---- libs/qec/python/tests/test_decoder.py | 17 +++ .../qec/unittests/decoders/sample_decoder.cpp | 6 +- libs/qec/unittests/test_decoders.cpp | 31 +++++ 12 files changed, 228 insertions(+), 86 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 6a0e4481..7f98c7c8 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -12,11 +12,14 @@ #include "cuda-qx/core/heterogeneous_map.h" #include "cuda-qx/core/tensor.h" #include "sparse_binary_matrix.h" +#include "cudaq/qec/detector_error_model.h" #include #include #include #include #include +#include +#include #include namespace cudaq::qec { @@ -27,6 +30,22 @@ using float_t = CUDAQX_QEC_FLOAT_TYPE; using float_t = double; #endif +/// @brief Construction input for a decoder: either an explicit parity-check +/// matrix (\c cudaq::qec::sparse_binary_matrix) or a Stim detector error model string +/// (\c std::string_view). +/// +/// Parity-check-matrix-based decoders (LUT, sliding_window, TRT, PyMatching, +/// nv-qldpc, ...) accept either alternative: a DEM string is parsed into a +/// parity-check matrix via \c dem_from_stim_text. Decoders that require the raw +/// DEM (e.g. Chromobius, which needs detector color/basis annotations) require +/// the string alternative and reject a bare matrix. +/// +/// @note The string alternative is non-owning. The referenced buffer must stay +/// alive for the duration of the \c get_decoder / \c decoder::get call; +/// decoders parse it during construction and do not retain the view. +using decoder_init = + std::variant; + /// @brief Validates that all keys in a heterogeneous map are found in a list of /// acceptable types /// @param config The heterogeneous map to validate @@ -125,8 +144,7 @@ class async_decoder_result { /// arbitrary constructor parameters that can be unique to each specific /// decoder. class decoder - : public cudaqx::extension_point { private: struct rt_impl; @@ -175,11 +193,35 @@ class decoder virtual std::vector decode_batch(const std::vector> &syndrome); - /// @brief This `get` overload supports default values. + /// @brief Construct a registered decoder by name. + /// @param name The registered decoder name. + /// @param init Either a parity-check matrix or a Stim DEM string (see + /// \c decoder_init). The variant is forwarded to the decoder's creator, so + /// parity-check-matrix-based decoders and DEM-native decoders (Chromobius) + /// share a single entry point. + /// @param param_map Optional decoder-specific parameters. static std::unique_ptr - get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, + get(const std::string &name, const decoder_init &init, const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()); + static std::unique_ptr + get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, + const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { + return get(name, decoder_init{H}, param_map); + } + + static std::unique_ptr + get(const std::string &name, const cudaqx::tensor &H, + const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { + return get(name, cudaq::qec::sparse_binary_matrix(H), param_map); + } + + static std::unique_ptr + get(const std::string &name, std::string_view stim_dem_text, + const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { + return get(name, decoder_init{stim_dem_text}, param_map); + } + std::size_t get_block_size() { return block_size; } std::size_t get_syndrome_size() { return syndrome_size; } @@ -428,10 +470,26 @@ inline void convert_vec_hard_to_soft(const std::vector> &in, } std::unique_ptr -get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, +get_decoder(const std::string &name, const decoder_init &init, const cudaqx::heterogeneous_map options = {}); -struct detector_error_model; +inline std::unique_ptr +get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{H}, options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, const cudaqx::tensor &H, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, cudaq::qec::sparse_binary_matrix(H), options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, std::string_view stim_dem_text, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{stim_dem_text}, options); +} /// @brief DEM-derived defaults; pointers alias into the source `dem`. struct dem_default_values { @@ -440,26 +498,55 @@ struct dem_default_values { }; /// @brief Return DEM defaults for any key not already supplied by the user. -/// Shared by `get_decoder_from_stim_dem` and its Python binding. +/// Shared by `make_pcm_decoder` and the Python binding. dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem); -/// @brief Construct a decoder by name from a Stim detector error model text. +/// @brief Extract the Stim DEM text from a \c decoder_init, throwing if it holds +/// a parity-check matrix instead. Use this in the create() function of decoders +/// that require a raw DEM (e.g. Chromobius), which cannot be reconstructed from +/// a bare parity-check matrix. +std::string_view require_dem_text(const decoder_init &init); + +/// @brief Build a parity-check-matrix-based decoder from a \c decoder_init. +/// +/// If \p init holds a sparse matrix, it is used directly as the parity-check matrix. +/// If it holds a Stim DEM string, it is parsed via \c dem_from_stim_text and the +/// derived observables (`"O"`) and per-error rates (`"error_rate_vec"`) are +/// injected into \p params unless the user already supplied them (user values +/// win). This is the shared implementation behind the create() function of every +/// parity-check-matrix-based decoder, giving them DEM-string support for free. /// -/// Thin wrapper over \c dem_from_stim_text: parses the DEM and forwards to -/// the existing H-based \c decoder::get after injecting two derived entries -/// into \p options if they are not already present: -/// - `"O"` : `cudaqx::tensor` observables_flips_matrix -/// - `"error_rate_vec"` : `std::vector` per-error probabilities -/// User-supplied values for either key win over the DEM-derived ones. +/// @note The DEM parse is lossy: detector annotations, decomposition +/// separators, and `error_ids` are dropped. Sufficient for matching-style / +/// parity-check-matrix decoders (LUT, NV, sliding_window, TRT, PyMatching). +/// Decoders that need full DEM metadata (e.g. Chromobius detector color/basis) +/// must consume the string directly via \c require_dem_text. +template +std::unique_ptr +make_pcm_decoder(const decoder_init &init, + const cudaqx::heterogeneous_map ¶ms) { + if (const auto *H = std::get_if(&init)) + return std::make_unique(*H, params); + + const auto dem = + dem_from_stim_text(std::string(std::get(init))); + cudaqx::heterogeneous_map merged = params; + const auto defaults = dem_defaults_for_missing_keys( + [&](const std::string &key) { return merged.contains(key); }, dem); + if (defaults.O) + merged.insert("O", *defaults.O); + if (defaults.error_rate_vec) + merged.insert("error_rate_vec", *defaults.error_rate_vec); + return std::make_unique( + cudaq::qec::sparse_binary_matrix(dem.detector_error_matrix), merged); +} + +/// @brief Construct a decoder by name from a Stim detector error model string. /// -/// @note Lossy: detector annotations, decomposition separators, and -/// `error_ids` are dropped. Sufficient for matching-style / H-based -/// decoders (LUT, NV, sliding_window, TRT, PyMatching). Decoders that -/// need full DEM metadata (e.g. Chromobius detector color/basis) require -/// the planned \c detector_coords extension on -/// \c cudaq::qec::detector_error_model; tracked as a follow-up. +/// @deprecated Prefer \c get_decoder, which now accepts a Stim DEM string +/// directly via \c decoder_init. Retained as a thin convenience alias. std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 020c3980..fea6ea1e 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -20,7 +20,7 @@ INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::sparse_binary_matrix &) INSTANTIATE_REGISTRY(cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &, + const cudaq::qec::decoder_init &, const cudaqx::heterogeneous_map &) // Include decoder implementations AFTER registry instantiation @@ -131,7 +131,7 @@ decoder::decode_async(const std::vector &syndrome) { } std::unique_ptr -decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, +decoder::get(const std::string &name, const decoder_init &init, const cudaqx::heterogeneous_map ¶m_map) { auto [mutex, registry] = get_registry(); std::lock_guard lock(mutex); @@ -141,7 +141,7 @@ decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, "invalid decoder requested: " + name + ". Run with CUDAQ_LOG_LEVEL=info (environment variable) to see " "additional plugin diagnostics at startup."); - return iter->second(H, param_map); + return iter->second(init, param_map); } static uint32_t calculate_num_msyn_per_decode( @@ -480,9 +480,9 @@ void decoder::reset_decoder() { } std::unique_ptr get_decoder(const std::string &name, - const cudaq::qec::sparse_binary_matrix &H, + const decoder_init &init, const cudaqx::heterogeneous_map options) { - return decoder::get(name, H, options); + return decoder::get(name, init, options); } // Constructor function for auto-loading plugins diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index 4aedaa37..f1f5bafd 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -24,28 +24,24 @@ dem_default_values dem_defaults_for_missing_keys( return out; } +std::string_view require_dem_text(const decoder_init &init) { + if (const auto *dem_text = std::get_if(&init)) + return *dem_text; + throw std::runtime_error( + "This decoder requires a Stim detector error model string; a " + "parity-check matrix cannot be used to reconstruct the detector " + "annotations it needs."); +} + std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map &options) { - if (!decoder::is_registered(name)) - throw std::runtime_error( - "get_decoder_from_stim_dem: decoder \"" + name + - "\" is not registered. Run with CUDAQ_LOG_LEVEL=info to see plugin " - "diagnostics at startup."); - - auto dem = dem_from_stim_text(stim_dem_text); - - cudaqx::heterogeneous_map merged = options; - // Keep in sync with the Python binding in py_decoder.cpp. - auto defaults = dem_defaults_for_missing_keys( - [&](const std::string &key) { return merged.contains(key); }, dem); - if (defaults.O) - merged.insert("O", *defaults.O); - if (defaults.error_rate_vec) - merged.insert("error_rate_vec", *defaults.error_rate_vec); - - return decoder::get(name, dem.detector_error_matrix, merged); + // Retained for backward compatibility: get_decoder now accepts a Stim DEM + // string directly via decoder_init. The string_view aliases stim_dem_text, + // which outlives this call. + return get_decoder(name, decoder_init{std::string_view{stim_dem_text}}, + options); } } // namespace cudaq::qec diff --git a/libs/qec/lib/decoders/lut.cpp b/libs/qec/lib/decoders/lut.cpp index 50c2870d..c548fa90 100644 --- a/libs/qec/lib/decoders/lut.cpp +++ b/libs/qec/lib/decoders/lut.cpp @@ -228,9 +228,9 @@ class multi_error_lut : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( multi_error_lut, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; @@ -246,9 +246,9 @@ class single_error_lut : public multi_error_lut { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( single_error_lut, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp b/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp index 201bfcc3..bf366b53 100644 --- a/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp +++ b/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp @@ -77,9 +77,10 @@ class single_error_lut_example : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( single_error_lut_example, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, + params); }) }; diff --git a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp index a38fcf1e..3d8eadfc 100644 --- a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp +++ b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp @@ -247,9 +247,9 @@ class pymatching : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( pymatching, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp index aa6419da..8ed56830 100644 --- a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp +++ b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp @@ -432,9 +432,9 @@ class trt_decoder : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( trt_decoder, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) private: diff --git a/libs/qec/lib/decoders/sliding_window.h b/libs/qec/lib/decoders/sliding_window.h index e240da9f..826a991d 100644 --- a/libs/qec/lib/decoders/sliding_window.h +++ b/libs/qec/lib/decoders/sliding_window.h @@ -142,9 +142,9 @@ class sliding_window : public decoder { // Plugin registration macros CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( sliding_window, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index d15cf7d8..ba2b46c4 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -787,38 +787,48 @@ void bindDecoder(nb::module_ &mod) { allocation. )pbdoc"); - qecmod.def( - "get_decoder_from_stim_dem", + // Shared implementation for constructing a decoder from a Stim DEM string, + // used by both the get_decoder(str) overload and the (deprecated) + // get_decoder_from_stim_dem entry point. + auto get_decoder_from_dem_text = [](const std::string &name, const std::string &dem_text, nb::kwargs options) - -> std::variant> { - if (PyDecoderRegistry::contains(name)) { - auto dem = dem_from_stim_text(dem_text); - - // Keep in sync with the C++ fallback in decoder_stim_dem.cpp. - auto defaults = dem_defaults_for_missing_keys( - [&](const std::string &key) { return options.contains(key); }, - dem); - if (defaults.O) - options["O"] = toPyArray(*defaults.O); - if (defaults.error_rate_vec) - options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); - - nb::object H_obj = toPyArray(dem.detector_error_matrix); - return PyDecoderRegistry::get_decoder( - name, nb::cast>(H_obj), options); - } + -> std::variant> { + if (PyDecoderRegistry::contains(name)) { + auto dem = dem_from_stim_text(dem_text); + + // Keep in sync with make_pcm_decoder in decoder.h. + auto defaults = dem_defaults_for_missing_keys( + [&](const std::string &key) { return options.contains(key); }, dem); + if (defaults.O) + options["O"] = toPyArray(*defaults.O); + if (defaults.error_rate_vec) + options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); + + nb::object H_obj = toPyArray(dem.detector_error_matrix); + return PyDecoderRegistry::get_decoder( + name, nb::cast>(H_obj), options); + } - return get_decoder_from_stim_dem(name, dem_text, - hetMapFromKwargs(options)); - }, + return get_decoder(name, decoder_init{std::string_view{dem_text}}, + hetMapFromKwargs(options)); + }; + + // Unified entry point: get_decoder also accepts a Stim DEM string. nanobind + // resolves this overload when the second argument is a str rather than a + // numpy array. + qecmod.def( + "get_decoder", get_decoder_from_dem_text, "Construct a decoder by name from a Stim detector error model string. " - "Thin wrapper over dem_from_stim_text: observables and per-error rates " - "from the DEM are injected into options under keys \"O\" and " - "\"error_rate_vec\". User-supplied values for either key win over the " - "DEM-derived ones. Python decoders registered via @qec.decoder receive " - "the parsed H and O as numpy.ndarray and error_rate_vec as a 1-D " - "numpy.ndarray of float64."); + "Parity-check-matrix decoders parse the DEM into H (observables and " + "per-error rates are injected into options under keys \"O\" and " + "\"error_rate_vec\"; user-supplied values win). Decoders that need the " + "raw DEM (e.g. Chromobius) consume the string directly."); + + qecmod.def( + "get_decoder_from_stim_dem", get_decoder_from_dem_text, + "Deprecated: prefer get_decoder, which now accepts a Stim detector error " + "model string directly. Retained as a thin alias."); qecmod.def( "get_sorted_pcm_column_indices", diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 482458b2..d1d6491b 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -777,6 +777,23 @@ def test_get_decoder_from_stim_dem(): assert list(result.result) == expected, f"syndrome {syndrome}" +def test_get_decoder_accepts_stim_dem_string(): + # get_decoder dispatches on the second argument: a str is treated as a Stim + # DEM, an ndarray as a parity-check matrix. + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + + decoder = qec.get_decoder("single_error_lut", dem_text) + assert decoder is not None + assert decoder.get_syndrome_size() == 2 + assert decoder.get_block_size() == 3 + + result = decoder.decode([1.0, 1.0]) + assert result.converged is True + assert list(result.result) == [0.0, 0.0, 1.0] + + def test_dem_from_stim_text_explicit_parse_then_get_decoder(): dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" diff --git a/libs/qec/unittests/decoders/sample_decoder.cpp b/libs/qec/unittests/decoders/sample_decoder.cpp index 8b3b509a..da953411 100644 --- a/libs/qec/unittests/decoders/sample_decoder.cpp +++ b/libs/qec/unittests/decoders/sample_decoder.cpp @@ -14,7 +14,7 @@ using namespace cudaqx; namespace cudaq::qec { /// @brief This is a sample (dummy) decoder that demonstrates how to build a -/// bare bones custom decoder based on the `cudaqx::qec::decoder` interface. +/// bare bones custom decoder based on the `cudaq::qec::decoder` interface. class sample_decoder : public decoder { public: sample_decoder(const cudaq::qec::sparse_binary_matrix &H, @@ -35,9 +35,9 @@ class sample_decoder : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( sample_decoder, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 455d1e69..dcd5aa01 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -804,6 +804,37 @@ error(0.05) D0 D1 } } +TEST(StimDemDecoderFactory, UnifiedGetDecoderAcceptsStimDemString) { + // get_decoder and decoder::get accept a Stim DEM string directly via + // decoder_init, matching the (deprecated) get_decoder_from_stim_dem path. + const std::string dem_text = R"(error(0.1) D0 L0 +error(0.1) D1 L0 +error(0.05) D0 D1 +)"; + + auto check = [&](std::unique_ptr d) { + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); + auto result = d->decode(std::vector{1.0, 1.0}); + EXPECT_TRUE(result.converged); + ASSERT_EQ(result.result.size(), 3u); + EXPECT_FLOAT_EQ(result.result[2], 1.0); + }; + check(cudaq::qec::get_decoder("single_error_lut", dem_text)); + check(cudaq::qec::decoder::get("single_error_lut", dem_text)); +} + +TEST(StimDemDecoderFactory, UnifiedGetDecoderStillAcceptsParityCheckMatrix) { + // Dense tensor inputs still convert to sparse PCM storage unchanged. + cudaqx::tensor H({2, 3}); + H.copy(std::vector{1, 0, 1, 0, 1, 1}.data(), {2, 3}); + auto d = cudaq::qec::get_decoder("single_error_lut", H); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); +} + TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) { const std::string dem_text = R"(error(0.1) D0 D0 error(0.1) L0 L0 From 85e9d2b469a192c646bdbc7befb7c1b836eb28db Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Fri, 5 Jun 2026 13:55:17 -0400 Subject: [PATCH 07/14] fix formatting Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 56 ++++++++++--------- .../include/cudaq/qec/detector_error_model.h | 20 +++---- libs/qec/lib/decoder.cpp | 3 +- libs/qec/python/bindings/py_decoder.cpp | 6 +- 4 files changed, 44 insertions(+), 41 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 7f98c7c8..c4dcaff0 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -31,17 +31,17 @@ using float_t = double; #endif /// @brief Construction input for a decoder: either an explicit parity-check -/// matrix (\c cudaq::qec::sparse_binary_matrix) or a Stim detector error model string -/// (\c std::string_view). +/// matrix (`cudaq::qec::sparse_binary_matrix`) or a Stim detector error model +/// string (`std::string_view`). /// /// Parity-check-matrix-based decoders (LUT, sliding_window, TRT, PyMatching, /// nv-qldpc, ...) accept either alternative: a DEM string is parsed into a -/// parity-check matrix via \c dem_from_stim_text. Decoders that require the raw +/// parity-check matrix via `dem_from_stim_text`. Decoders that require the raw /// DEM (e.g. Chromobius, which needs detector color/basis annotations) require /// the string alternative and reject a bare matrix. /// /// @note The string alternative is non-owning. The referenced buffer must stay -/// alive for the duration of the \c get_decoder / \c decoder::get call; +/// alive for the duration of the `get_decoder` / `decoder::get` call; /// decoders parse it during construction and do not retain the view. using decoder_init = std::variant; @@ -164,16 +164,16 @@ class decoder /// @brief Decode a single syndrome /// @param syndrome A vector of syndrome measurements where the floating point /// value is the probability that the syndrome measurement is a |1>. The - /// length of the syndrome vector should be equal to \p syndrome_size. - /// @returns Vector of length \p block_size with soft probabilities of errors + /// length of the syndrome vector should be equal to `syndrome_size`. + /// @returns Vector of length `block_size` with soft probabilities of errors /// in each index. virtual decoder_result decode(const std::vector &syndrome) = 0; /// @brief Decode a single syndrome /// @param syndrome An order-1 tensor of syndrome measurements where a 1 bit /// represents that the syndrome measurement is a |1>. The - /// length of the syndrome vector should be equal to \p syndrome_size. - /// @returns Vector of length \p block_size of errors in each index. + /// length of the syndrome vector should be equal to `syndrome_size`. + /// @returns Vector of length `block_size` of errors in each index. virtual decoder_result decode(const cudaqx::tensor &syndrome); /// @brief Decode a single syndrome @@ -196,7 +196,7 @@ class decoder /// @brief Construct a registered decoder by name. /// @param name The registered decoder name. /// @param init Either a parity-check matrix or a Stim DEM string (see - /// \c decoder_init). The variant is forwarded to the decoder's creator, so + /// `decoder_init`). The variant is forwarded to the decoder's creator, so /// parity-check-matrix-based decoders and DEM-native decoders (Chromobius) /// share a single entry point. /// @param param_map Optional decoder-specific parameters. @@ -206,19 +206,22 @@ class decoder static std::unique_ptr get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, - const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { return get(name, decoder_init{H}, param_map); } static std::unique_ptr get(const std::string &name, const cudaqx::tensor &H, - const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { return get(name, cudaq::qec::sparse_binary_matrix(H), param_map); } static std::unique_ptr get(const std::string &name, std::string_view stim_dem_text, - const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { return get(name, decoder_init{stim_dem_text}, param_map); } @@ -503,26 +506,27 @@ dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem); -/// @brief Extract the Stim DEM text from a \c decoder_init, throwing if it holds -/// a parity-check matrix instead. Use this in the create() function of decoders -/// that require a raw DEM (e.g. Chromobius), which cannot be reconstructed from -/// a bare parity-check matrix. +/// @brief Extract the Stim DEM text from a `decoder_init`, throwing if it +/// holds a parity-check matrix instead. Use this in the create() function of +/// decoders that require a raw DEM (e.g. Chromobius), which cannot be +/// reconstructed from a bare parity-check matrix. std::string_view require_dem_text(const decoder_init &init); -/// @brief Build a parity-check-matrix-based decoder from a \c decoder_init. +/// @brief Build a parity-check-matrix-based decoder from a `decoder_init`. /// -/// If \p init holds a sparse matrix, it is used directly as the parity-check matrix. -/// If it holds a Stim DEM string, it is parsed via \c dem_from_stim_text and the -/// derived observables (`"O"`) and per-error rates (`"error_rate_vec"`) are -/// injected into \p params unless the user already supplied them (user values -/// win). This is the shared implementation behind the create() function of every -/// parity-check-matrix-based decoder, giving them DEM-string support for free. +/// If `init` holds a sparse matrix, it is used directly as the parity-check +/// matrix. If it holds a Stim DEM string, it is parsed via +/// `dem_from_stim_text` and the derived observables (`"O"`) and per-error +/// rates (`"error_rate_vec"`) are injected into `params` unless the user +/// already supplied them (user values win). This is the shared implementation +/// behind the create() function of every parity-check-matrix-based decoder, +/// giving them DEM-string support for free. /// /// @note The DEM parse is lossy: detector annotations, decomposition /// separators, and `error_ids` are dropped. Sufficient for matching-style / /// parity-check-matrix decoders (LUT, NV, sliding_window, TRT, PyMatching). /// Decoders that need full DEM metadata (e.g. Chromobius detector color/basis) -/// must consume the string directly via \c require_dem_text. +/// must consume the string directly via `require_dem_text`. template std::unique_ptr make_pcm_decoder(const decoder_init &init, @@ -545,8 +549,8 @@ make_pcm_decoder(const decoder_init &init, /// @brief Construct a decoder by name from a Stim detector error model string. /// -/// @deprecated Prefer \c get_decoder, which now accepts a Stim DEM string -/// directly via \c decoder_init. Retained as a thin convenience alias. +/// @deprecated Prefer `get_decoder`, which now accepts a Stim DEM string +/// directly via `decoder_init`. Retained as a thin convenience alias. std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index 5acef6f9..a7d15382 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -19,9 +19,9 @@ namespace cudaq::qec { /// decoder to help make predictions about observables flips. /// /// Shared size parameters among the matrix types. -/// - \p detector_error_matrix: num_detectors x num_error_mechanisms [d, e] -/// - \p error_rates: num_error_mechanisms -/// - \p observables_flips_matrix: num_observables x num_error_mechanisms [k, e] +/// - `detector_error_matrix`: num_detectors x num_error_mechanisms [d, e] +/// - `error_rates`: num_error_mechanisms +/// - `observables_flips_matrix`: num_observables x num_error_mechanisms [k, e] /// /// @note The C++ API for this class may change in the future. The Python API is /// more likely to be backwards compatible. @@ -33,7 +33,7 @@ struct detector_error_model { cudaqx::tensor detector_error_matrix; /// The list of weights has length equal to the number of columns of - /// \p detector_error_matrix, which assigns a likelihood to each error + /// `detector_error_matrix`, which assigns a likelihood to each error /// mechanism. std::vector error_rates; @@ -67,17 +67,17 @@ struct detector_error_model { }; /// @brief Parse a Stim detector error model text into a -/// \p cudaq::qec::detector_error_model. Each `error` instruction in the DEM -/// becomes a single column in \p detector_error_matrix and -/// \p observables_flips_matrix; suggested decomposition separators are +/// `cudaq::qec::detector_error_model`. Each `error` instruction in the DEM +/// becomes a single column in `detector_error_matrix` and +/// `observables_flips_matrix`; suggested decomposition separators are /// folded into the same column. /// /// @note Lossy: only detector/observable flips and error probabilities /// are extracted. Annotations (`detector`, `logical_observable`), -/// suggested-decomposition separators, and \p error_ids are dropped. +/// suggested-decomposition separators, and `error_ids` are dropped. /// Decoders that need full DEM metadata (e.g. Chromobius detector -/// color/basis) require the planned \p detector_coords extension on -/// \p detector_error_model; tracked as a follow-up. +/// color/basis) require the planned `detector_coords` extension on +/// `detector_error_model`; tracked as a follow-up. detector_error_model dem_from_stim_text(const std::string &dem_text); } // namespace cudaq::qec diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index fea6ea1e..3b059679 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -19,8 +19,7 @@ INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::sparse_binary_matrix &) -INSTANTIATE_REGISTRY(cudaq::qec::decoder, - const cudaq::qec::decoder_init &, +INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::decoder_init &, const cudaqx::heterogeneous_map &) // Include decoder implementations AFTER registry instantiation diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index ba2b46c4..0a173791 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -790,9 +790,9 @@ void bindDecoder(nb::module_ &mod) { // Shared implementation for constructing a decoder from a Stim DEM string, // used by both the get_decoder(str) overload and the (deprecated) // get_decoder_from_stim_dem entry point. - auto get_decoder_from_dem_text = - [](const std::string &name, const std::string &dem_text, - nb::kwargs options) + auto get_decoder_from_dem_text = [](const std::string &name, + const std::string &dem_text, + nb::kwargs options) -> std::variant> { if (PyDecoderRegistry::contains(name)) { auto dem = dem_from_stim_text(dem_text); From 200804a5639248be178b3763e774707af8e41a49 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Fri, 5 Jun 2026 15:09:27 -0400 Subject: [PATCH 08/14] fix ci failure Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 43 +++++++++++++++---- .../include/cudaq/qec/detector_error_model.h | 4 +- libs/qec/lib/decoder_stim_dem.cpp | 10 ++--- libs/qec/python/bindings/py_decoder.cpp | 6 +-- 4 files changed, 42 insertions(+), 21 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index c4dcaff0..fb4f8455 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -32,7 +32,7 @@ using float_t = double; /// @brief Construction input for a decoder: either an explicit parity-check /// matrix (`cudaq::qec::sparse_binary_matrix`) or a Stim detector error model -/// string (`std::string_view`). +/// string (`std::string`). /// /// Parity-check-matrix-based decoders (LUT, sliding_window, TRT, PyMatching, /// nv-qldpc, ...) accept either alternative: a DEM string is parsed into a @@ -40,11 +40,11 @@ using float_t = double; /// DEM (e.g. Chromobius, which needs detector color/basis annotations) require /// the string alternative and reject a bare matrix. /// -/// @note The string alternative is non-owning. The referenced buffer must stay -/// alive for the duration of the `get_decoder` / `decoder::get` call; -/// decoders parse it during construction and do not retain the view. +/// @note The factory input owns the DEM text. Public overloads accept +/// `std::string`, `const char *`, and `std::string_view` and copy into this +/// variant before dispatching to decoder plugins. using decoder_init = - std::variant; + std::variant; /// @brief Validates that all keys in a heterogeneous map are found in a list of /// acceptable types @@ -219,12 +219,26 @@ class decoder } static std::unique_ptr - get(const std::string &name, std::string_view stim_dem_text, + get(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()) { return get(name, decoder_init{stim_dem_text}, param_map); } + static std::unique_ptr + get(const std::string &name, const char *stim_dem_text, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, decoder_init{std::string{stim_dem_text}}, param_map); + } + + static std::unique_ptr + get(const std::string &name, std::string_view stim_dem_text, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, decoder_init{std::string{stim_dem_text}}, param_map); + } + std::size_t get_block_size() { return block_size; } std::size_t get_syndrome_size() { return syndrome_size; } @@ -489,11 +503,23 @@ get_decoder(const std::string &name, const cudaqx::tensor &H, } inline std::unique_ptr -get_decoder(const std::string &name, std::string_view stim_dem_text, +get_decoder(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map options = {}) { return get_decoder(name, decoder_init{stim_dem_text}, options); } +inline std::unique_ptr +get_decoder(const std::string &name, const char *stim_dem_text, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, std::string_view stim_dem_text, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options); +} + /// @brief DEM-derived defaults; pointers alias into the source `dem`. struct dem_default_values { const cudaqx::tensor *O = nullptr; @@ -534,8 +560,7 @@ make_pcm_decoder(const decoder_init &init, if (const auto *H = std::get_if(&init)) return std::make_unique(*H, params); - const auto dem = - dem_from_stim_text(std::string(std::get(init))); + const auto dem = dem_from_stim_text(std::get(init)); cudaqx::heterogeneous_map merged = params; const auto defaults = dem_defaults_for_missing_keys( [&](const std::string &key) { return merged.contains(key); }, dem); diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index a7d15382..94699bd5 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -76,8 +76,8 @@ struct detector_error_model { /// are extracted. Annotations (`detector`, `logical_observable`), /// suggested-decomposition separators, and `error_ids` are dropped. /// Decoders that need full DEM metadata (e.g. Chromobius detector -/// color/basis) require the planned `detector_coords` extension on -/// `detector_error_model`; tracked as a follow-up. +/// color/basis) should consume the raw Stim DEM string via `decoder_init` +/// and `require_dem_text` instead of this lossy representation. detector_error_model dem_from_stim_text(const std::string &dem_text); } // namespace cudaq::qec diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index f1f5bafd..16a09310 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -25,8 +25,8 @@ dem_default_values dem_defaults_for_missing_keys( } std::string_view require_dem_text(const decoder_init &init) { - if (const auto *dem_text = std::get_if(&init)) - return *dem_text; + if (const auto *dem_text = std::get_if(&init)) + return std::string_view{*dem_text}; throw std::runtime_error( "This decoder requires a Stim detector error model string; a " "parity-check matrix cannot be used to reconstruct the detector " @@ -38,10 +38,8 @@ get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map &options) { // Retained for backward compatibility: get_decoder now accepts a Stim DEM - // string directly via decoder_init. The string_view aliases stim_dem_text, - // which outlives this call. - return get_decoder(name, decoder_init{std::string_view{stim_dem_text}}, - options); + // string directly via decoder_init. + return get_decoder(name, decoder_init{stim_dem_text}, options); } } // namespace cudaq::qec diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 0a173791..c13a1467 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -806,12 +806,10 @@ void bindDecoder(nb::module_ &mod) { options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); nb::object H_obj = toPyArray(dem.detector_error_matrix); - return PyDecoderRegistry::get_decoder( - name, nb::cast>(H_obj), options); + return PyDecoderRegistry::get_decoder(name, H_obj, options); } - return get_decoder(name, decoder_init{std::string_view{dem_text}}, - hetMapFromKwargs(options)); + return get_decoder(name, decoder_init{dem_text}, hetMapFromKwargs(options)); }; // Unified entry point: get_decoder also accepts a Stim DEM string. nanobind From fb6dcfce39cfac2018f6e4383a913827500dd432 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 8 Jun 2026 11:31:31 -0400 Subject: [PATCH 09/14] update constructing a decoder from a Stim DEM string Signed-off-by: vedika-saravanan --- libs/qec/lib/decoder.cpp | 2 - libs/qec/python/bindings/py_decoder.cpp | 74 ++++++++++++------------- libs/qec/unittests/test_decoders.cpp | 61 -------------------- 3 files changed, 36 insertions(+), 101 deletions(-) diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 3b059679..6b8f441d 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -17,8 +17,6 @@ #include #include -INSTANTIATE_REGISTRY(cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &) INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::decoder_init &, const cudaqx::heterogeneous_map &) diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index c13a1467..4185c7d7 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -735,10 +735,41 @@ void bindDecoder(nb::module_ &mod) { }); }); + // Shared implementation for constructing a decoder from a Stim DEM string, + // used by both get_decoder overload paths and the (deprecated) + // get_decoder_from_stim_dem entry point. + auto get_decoder_from_dem_text = [](const std::string &name, + const std::string &dem_text, + nb::kwargs options) + -> std::variant> { + if (PyDecoderRegistry::contains(name)) { + auto dem = dem_from_stim_text(dem_text); + + // Keep in sync with make_pcm_decoder in decoder.h. + auto defaults = dem_defaults_for_missing_keys( + [&](const std::string &key) { return options.contains(key); }, dem); + if (defaults.O) + options["O"] = toPyArray(*defaults.O); + if (defaults.error_rate_vec) + options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); + + nb::object H_obj = toPyArray(dem.detector_error_matrix); + return PyDecoderRegistry::get_decoder(name, H_obj, options); + } + + return get_decoder(name, decoder_init{dem_text}, hetMapFromKwargs(options)); + }; + qecmod.def( "get_decoder", - [](const std::string &name, nb::object H, nb::kwargs options) + [get_decoder_from_dem_text](const std::string &name, nb::object H, + nb::kwargs options) -> std::variant> { + if (nb::isinstance(H)) { + return get_decoder_from_dem_text(name, nb::cast(H), + options); + } + if (PyDecoderRegistry::contains(name)) { return PyDecoderRegistry::get_decoder(name, H, options); } @@ -779,50 +810,17 @@ void bindDecoder(nb::module_ &mod) { builds ``sparse_binary_matrix`` directly (no dense tensor for ``H``), allowing native C++ decoders like ``pymatching`` to be constructed from very large parity-check matrices. + - A Stim detector error model string: native C++ decoders receive the + raw DEM text via ``decoder_init``; Python-registered decoders receive + the DEM-derived PCM plus ``O`` and ``error_rate_vec`` defaults. For Python-registered decoders (``cudaq.qec.decoder`` decorator), ``H`` is passed through to ``__init__`` unchanged (NumPy array or sparse dict). Call ``Decoder.__init__(self, H)`` so nanobind can store the PCM in CSC - form when ``H`` is a dict without building a dense ``rows × cols`` + form when ``H`` is a dict without building a dense ``rows x cols`` allocation. )pbdoc"); - // Shared implementation for constructing a decoder from a Stim DEM string, - // used by both the get_decoder(str) overload and the (deprecated) - // get_decoder_from_stim_dem entry point. - auto get_decoder_from_dem_text = [](const std::string &name, - const std::string &dem_text, - nb::kwargs options) - -> std::variant> { - if (PyDecoderRegistry::contains(name)) { - auto dem = dem_from_stim_text(dem_text); - - // Keep in sync with make_pcm_decoder in decoder.h. - auto defaults = dem_defaults_for_missing_keys( - [&](const std::string &key) { return options.contains(key); }, dem); - if (defaults.O) - options["O"] = toPyArray(*defaults.O); - if (defaults.error_rate_vec) - options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); - - nb::object H_obj = toPyArray(dem.detector_error_matrix); - return PyDecoderRegistry::get_decoder(name, H_obj, options); - } - - return get_decoder(name, decoder_init{dem_text}, hetMapFromKwargs(options)); - }; - - // Unified entry point: get_decoder also accepts a Stim DEM string. nanobind - // resolves this overload when the second argument is a str rather than a - // numpy array. - qecmod.def( - "get_decoder", get_decoder_from_dem_text, - "Construct a decoder by name from a Stim detector error model string. " - "Parity-check-matrix decoders parse the DEM into H (observables and " - "per-error rates are injected into options under keys \"O\" and " - "\"error_rate_vec\"; user-supplied values win). Decoders that need the " - "raw DEM (e.g. Chromobius) consume the string directly."); - qecmod.def( "get_decoder_from_stim_dem", get_decoder_from_dem_text, "Deprecated: prefer get_decoder, which now accepts a Stim detector error " diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index dcd5aa01..000f92c9 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -710,67 +710,6 @@ TEST(DecoderTest, GetBlockSizeAndSyndromeSize) { EXPECT_EQ(decoder2->get_syndrome_size(), new_syndrome_size); } -TEST(DecoderRegistryTest, SingleParameterRegistryDirect) { - // Test the single-parameter registry instantiation (line 18 in decoder.cpp) - // This directly tests the registry for decoder constructors that only take - // tensor by accessing the single-parameter extension_point registry - // directly - - std::size_t block_size = 8; - std::size_t syndrome_size = 4; - cudaqx::tensor H({syndrome_size, block_size}); - - // Initialize with some test data to ensure it's a valid matrix - for (std::size_t i = 0; i < syndrome_size; ++i) { - for (std::size_t j = 0; j < block_size; ++j) { - H.at({i, j}) = (i + j) % 2; - } - } - - auto H_sparse = cudaq::qec::sparse_binary_matrix(H); - - // Test that the single-parameter registry exists and can be accessed - // This directly tests line 18: INSTANTIATE_REGISTRY(cudaq::qec::decoder, - // const cudaqx::tensor &) - try { - // Create a decoder using the single-parameter extension_point directly - // This bypasses decoder::get and directly uses the single-parameter - // registry - auto single_param_decoder = cudaqx::extension_point< - cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &>::get("sample_decoder", - H_sparse); - - ASSERT_NE(single_param_decoder, nullptr); - - // Verify the decoder works correctly - EXPECT_EQ(single_param_decoder->get_block_size(), block_size); - EXPECT_EQ(single_param_decoder->get_syndrome_size(), syndrome_size); - - // Test with a syndrome decode to ensure functionality - std::vector syndrome(syndrome_size, 0.0f); - auto result = single_param_decoder->decode(syndrome); - EXPECT_EQ(result.result.size(), block_size); - - } catch (const std::runtime_error &e) { - // This is expected if "sample_decoder" is not registered in the - // single-parameter registry The test still passes because it verifies that - // line 18 creates a functional registry - EXPECT_TRUE(std::string(e.what()).find("Cannot find extension with name") != - std::string::npos); - } - - // Test that we can check if extensions are registered in the single-parameter - // registry - auto registered_single = cudaqx::extension_point< - cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &>::get_registered(); - - // The registry should exist (even if empty), proving line 18 instantiation - // works This test passes if no exceptions are thrown, proving the - // single-parameter registry is instantiated -} - TEST(StimDemDecoderFactory, ConstructsLutDecoderFromStimDemText) { const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 From f6d7fdf0daf0b51e86805c09c2af9b733ca1838b Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 8 Jun 2026 12:33:19 -0400 Subject: [PATCH 10/14] remove verbose doc added during testing Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 55 ++++--------------- .../include/cudaq/qec/detector_error_model.h | 14 +---- libs/qec/lib/CMakeLists.txt | 1 - libs/qec/lib/decoder_stim_dem.cpp | 2 - libs/qec/lib/detector_error_model.cpp | 7 +-- libs/qec/python/bindings/py_decoder.cpp | 18 ++---- libs/qec/python/tests/test_decoder.py | 7 --- libs/qec/unittests/CMakeLists.txt | 3 - libs/qec/unittests/test_decoders.cpp | 9 +-- 9 files changed, 21 insertions(+), 95 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 753f701d..34310304 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -30,21 +30,10 @@ using float_t = CUDAQX_QEC_FLOAT_TYPE; using float_t = double; #endif -/// @brief Construction input for a decoder: either an explicit parity-check -/// matrix (`cudaq::qec::sparse_binary_matrix`) or a Stim detector error model -/// string (`std::string`). -/// -/// Parity-check-matrix-based decoders (LUT, sliding_window, TRT, PyMatching, -/// nv-qldpc, ...) accept either alternative: a DEM string is parsed into a -/// parity-check matrix via `dem_from_stim_text`. Decoders that require the raw -/// DEM (e.g. Chromobius, which needs detector color/basis annotations) require -/// the string alternative and reject a bare matrix. -/// -/// @note The factory input owns the DEM text. Public overloads accept -/// `std::string`, `const char *`, and `std::string_view` and copy into this -/// variant before dispatching to decoder plugins. -using decoder_init = - std::variant; +/// Decoder construction input: either a parity-check matrix or raw Stim DEM +/// text. PCM-based decoders can accept both; DEM-native decoders can require +/// the string alternative via `require_dem_text`. +using decoder_init = std::variant; /// @brief Validates that all keys in a heterogeneous map are found in a list of /// acceptable types @@ -195,10 +184,7 @@ class decoder /// @brief Construct a registered decoder by name. /// @param name The registered decoder name. - /// @param init Either a parity-check matrix or a Stim DEM string (see - /// `decoder_init`). The variant is forwarded to the decoder's creator, so - /// parity-check-matrix-based decoders and DEM-native decoders (Chromobius) - /// share a single entry point. + /// @param init A parity-check matrix or raw Stim DEM string. /// @param param_map Optional decoder-specific parameters. static std::unique_ptr get(const std::string &name, const decoder_init &init, @@ -530,39 +516,22 @@ get_decoder(const std::string &name, std::string_view stim_dem_text, return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options); } -/// @brief DEM-derived defaults; pointers alias into the source `dem`. +/// DEM-derived defaults; pointers alias into the source `dem`. struct dem_default_values { const cudaqx::tensor *O = nullptr; const std::vector *error_rate_vec = nullptr; }; -/// @brief Return DEM defaults for any key not already supplied by the user. -/// Shared by `make_pcm_decoder` and the Python binding. +/// Return DEM defaults for keys not already supplied by the user. dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem); -/// @brief Extract the Stim DEM text from a `decoder_init`, throwing if it -/// holds a parity-check matrix instead. Use this in the create() function of -/// decoders that require a raw DEM (e.g. Chromobius), which cannot be -/// reconstructed from a bare parity-check matrix. +/// Extract raw Stim DEM text for DEM-native decoders. std::string_view require_dem_text(const decoder_init &init); -/// @brief Build a parity-check-matrix-based decoder from a `decoder_init`. -/// -/// If `init` holds a sparse matrix, it is used directly as the parity-check -/// matrix. If it holds a Stim DEM string, it is parsed via -/// `dem_from_stim_text` and the derived observables (`"O"`) and per-error -/// rates (`"error_rate_vec"`) are injected into `params` unless the user -/// already supplied them (user values win). This is the shared implementation -/// behind the create() function of every parity-check-matrix-based decoder, -/// giving them DEM-string support for free. -/// -/// @note The DEM parse is lossy: detector annotations, decomposition -/// separators, and `error_ids` are dropped. Sufficient for matching-style / -/// parity-check-matrix decoders (LUT, NV, sliding_window, TRT, PyMatching). -/// Decoders that need full DEM metadata (e.g. Chromobius detector color/basis) -/// must consume the string directly via `require_dem_text`. +/// Build a PCM-based decoder. If `init` holds DEM text, it is parsed into a PCM +/// and DEM-derived `"O"` / `"error_rate_vec"` defaults are added when absent. template std::unique_ptr make_pcm_decoder(const decoder_init &init, @@ -582,10 +551,8 @@ make_pcm_decoder(const decoder_init &init, cudaq::qec::sparse_binary_matrix(dem.detector_error_matrix), merged); } -/// @brief Construct a decoder by name from a Stim detector error model string. -/// /// @deprecated Prefer `get_decoder`, which now accepts a Stim DEM string -/// directly via `decoder_init`. Retained as a thin convenience alias. +/// directly. std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index 94699bd5..f6d82d50 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -66,18 +66,8 @@ struct detector_error_model { void canonicalize_for_rounds(uint32_t num_syndromes_per_round); }; -/// @brief Parse a Stim detector error model text into a -/// `cudaq::qec::detector_error_model`. Each `error` instruction in the DEM -/// becomes a single column in `detector_error_matrix` and -/// `observables_flips_matrix`; suggested decomposition separators are -/// folded into the same column. -/// -/// @note Lossy: only detector/observable flips and error probabilities -/// are extracted. Annotations (`detector`, `logical_observable`), -/// suggested-decomposition separators, and `error_ids` are dropped. -/// Decoders that need full DEM metadata (e.g. Chromobius detector -/// color/basis) should consume the raw Stim DEM string via `decoder_init` -/// and `require_dem_text` instead of this lossy representation. +/// Parse Stim DEM text into detector/observable flip matrices and error rates. +/// This is lossy; DEM-native decoders should consume raw DEM text instead. detector_error_model dem_from_stim_text(const std::string &dem_text); } // namespace cudaq::qec diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index 2a60d52e..92991048 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -62,7 +62,6 @@ add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES}) add_subdirectory(decoders/plugins/example) add_subdirectory(decoders/plugins/pymatching) -# libstim comes from the parent build (CUDA-Q). if(NOT TARGET libstim) message(FATAL_ERROR "libstim target not available; required by cudaq-qec for Stim DEM parsing.") diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index 16a09310..567f6594 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -37,8 +37,6 @@ std::unique_ptr get_decoder_from_stim_dem(const std::string &name, const std::string &stim_dem_text, const cudaqx::heterogeneous_map &options) { - // Retained for backward compatibility: get_decoder now accepts a Stim DEM - // string directly via decoder_init. return get_decoder(name, decoder_init{stim_dem_text}, options); } diff --git a/libs/qec/lib/detector_error_model.cpp b/libs/qec/lib/detector_error_model.cpp index 0376cf78..c9c9c09d 100644 --- a/libs/qec/lib/detector_error_model.cpp +++ b/libs/qec/lib/detector_error_model.cpp @@ -34,7 +34,7 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { std::size_t instruction_index = 0; dem.iter_flatten_error_instructions([&](const stim::DemInstruction &inst) { - if (inst.arg_data.size() == 0) + if (inst.arg_data.empty()) throw std::runtime_error( "Stim DEM error instruction missing probability argument (index " + std::to_string(instruction_index) + ")"); @@ -54,9 +54,6 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { } else if (target.is_observable_id()) { obs.push_back(static_cast(target.val())); } else { - // Forward-compat tripwire; unreachable today (stim's three - // DemTarget categories are exhaustive -- pinned by - // StimDemTargetCategoriesAreExhaustive). throw std::runtime_error( "Stim DEM error instruction (index " + std::to_string(instruction_index) + @@ -71,8 +68,6 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { }); const std::size_t num_errors = rates.size(); - // Reject zero-column H at the boundary instead of letting decoders - // crash with block_size == 0. if (num_errors == 0) throw std::runtime_error( "Stim DEM contains no error mechanisms after flattening"); diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 4185c7d7..5f24bded 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -339,16 +339,14 @@ makeBatchDecoderResult(const std::vector &results) { }; } -// Wrap a borrowed cudaqx buffer in a NumPy array and force a Python-side copy, -// so the returned object owns its data. -nb::object toPyArray(const cudaqx::tensor &t) { +nb::object copyToPyArray(const cudaqx::tensor &t) { size_t shape[2] = {t.shape()[0], t.shape()[1]}; auto arr = nb::ndarray(const_cast(t.data()), 2, shape, nb::none()); return nb::cast(arr).attr("copy")(); } -nb::object toPyArray(const std::vector &v) { +nb::object copyToPyArray(const std::vector &v) { size_t shape[1] = {v.size()}; auto arr = nb::ndarray(const_cast(v.data()), 1, shape, nb::none()); @@ -735,9 +733,6 @@ void bindDecoder(nb::module_ &mod) { }); }); - // Shared implementation for constructing a decoder from a Stim DEM string, - // used by both get_decoder overload paths and the (deprecated) - // get_decoder_from_stim_dem entry point. auto get_decoder_from_dem_text = [](const std::string &name, const std::string &dem_text, nb::kwargs options) @@ -745,15 +740,14 @@ void bindDecoder(nb::module_ &mod) { if (PyDecoderRegistry::contains(name)) { auto dem = dem_from_stim_text(dem_text); - // Keep in sync with make_pcm_decoder in decoder.h. auto defaults = dem_defaults_for_missing_keys( [&](const std::string &key) { return options.contains(key); }, dem); if (defaults.O) - options["O"] = toPyArray(*defaults.O); + options["O"] = copyToPyArray(*defaults.O); if (defaults.error_rate_vec) - options["error_rate_vec"] = toPyArray(*defaults.error_rate_vec); + options["error_rate_vec"] = copyToPyArray(*defaults.error_rate_vec); - nb::object H_obj = toPyArray(dem.detector_error_matrix); + nb::object H_obj = copyToPyArray(dem.detector_error_matrix); return PyDecoderRegistry::get_decoder(name, H_obj, options); } @@ -824,7 +818,7 @@ void bindDecoder(nb::module_ &mod) { qecmod.def( "get_decoder_from_stim_dem", get_decoder_from_dem_text, "Deprecated: prefer get_decoder, which now accepts a Stim detector error " - "model string directly. Retained as a thin alias."); + "model string directly."); qecmod.def( "get_sorted_pcm_column_indices", diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 0b51b8d3..460c4ca8 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -763,9 +763,6 @@ def test_generate_random_pcm_signed_weight_rejects_negative(): def test_get_decoder_from_stim_dem(): - # 2 detectors, 1 observable, 3 errors. Matches the C++ - # StimDemDecoderFactory.ConstructsLutDecoderFromStimDemText DEM so the - # truth-data assertions stay in sync across language bindings. dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" "error(0.05) D0 D1\n") @@ -788,8 +785,6 @@ def test_get_decoder_from_stim_dem(): def test_get_decoder_accepts_stim_dem_string(): - # get_decoder dispatches on the second argument: a str is treated as a Stim - # DEM, an ndarray as a parity-check matrix. dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" "error(0.05) D0 D1\n") @@ -833,8 +828,6 @@ def test_get_decoder_from_stim_dem_rejects_unknown_decoder(): def test_get_decoder_from_stim_dem_user_O_wins_over_dem_derived(): - # Wrong-shape user O trips PyMatching's validation; silent overwrite - # by the DEM-derived O would suppress the throw. dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" "error(0.05) D0 D1\n") diff --git a/libs/qec/unittests/CMakeLists.txt b/libs/qec/unittests/CMakeLists.txt index 97aeb07b..923ae9fa 100644 --- a/libs/qec/unittests/CMakeLists.txt +++ b/libs/qec/unittests/CMakeLists.txt @@ -35,8 +35,6 @@ find_package(CUDAToolkit REQUIRED) add_compile_options(-Wno-attributes) add_executable(test_decoders test_decoders.cpp decoders/sample_decoder.cpp) -# Direct libstim link for StimDemTargetCategoriesAreExhaustive; -# cudaq-qec hides stim symbols via --exclude-libs. target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq libstim) add_dependencies(CUDAQXQECUnitTests test_decoders) gtest_discover_tests(test_decoders) @@ -511,4 +509,3 @@ add_subdirectory(decoders/pymatching) if(CUDAQX_QEC_ENABLE_HOLOLINK_TOOLS) add_subdirectory(utils) endif() - diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index d74d4867..26ccd506 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -779,8 +779,6 @@ error(0.05) D0 D1 } TEST(StimDemDecoderFactory, UnifiedGetDecoderAcceptsStimDemString) { - // get_decoder and decoder::get accept a Stim DEM string directly via - // decoder_init, matching the (deprecated) get_decoder_from_stim_dem path. const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 error(0.05) D0 D1 @@ -800,7 +798,6 @@ error(0.05) D0 D1 } TEST(StimDemDecoderFactory, UnifiedGetDecoderStillAcceptsParityCheckMatrix) { - // Dense tensor inputs still convert to sparse PCM storage unchanged. cudaqx::tensor H({2, 3}); H.copy(std::vector{1, 0, 1, 0, 1, 1}.data(), {2, 3}); auto d = cudaq::qec::get_decoder("single_error_lut", H); @@ -845,7 +842,6 @@ TEST(StimDemDecoderFactory, ThrowsOnUnknownDecoderName) { } TEST(StimDemDecoderFactory, ThrowsOnEmptyErrorMechanisms) { - // A bare detector(...) line parses but yields zero error mechanisms. const std::string dem_text = "detector(0, 0, 0)\n"; EXPECT_THROW( cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text), @@ -853,9 +849,6 @@ TEST(StimDemDecoderFactory, ThrowsOnEmptyErrorMechanisms) { } TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) { - // Pins the invariant that keeps the defensive throw in - // dem_from_stim_text unreachable. Fires first if stim's encoding - // changes. const std::vector samples = { stim::DemTarget::separator(), stim::DemTarget::relative_detector_id(0), @@ -878,7 +871,7 @@ error(0.1) D1 L0 error(0.05) D0 D1 )"; cudaqx::heterogeneous_map opts; - opts.insert("error_rate_vec", std::vector{0.5}); // wrong size + opts.insert("error_rate_vec", std::vector{0.5}); EXPECT_THROW( cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text, opts), std::runtime_error); From 125354087574b4f10c6af241e085c6d34962eb8e Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 8 Jun 2026 12:45:48 -0400 Subject: [PATCH 11/14] fetches Stim directly with FetchContent when libstim is not already available Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 6 -- libs/qec/lib/CMakeLists.txt | 18 +++++- libs/qec/lib/decoder_stim_dem.cpp | 9 +-- libs/qec/python/bindings/py_decoder.cpp | 5 -- libs/qec/python/cudaq_qec/__init__.py | 1 - libs/qec/python/tests/test_decoder.py | 40 ++++++------- libs/qec/unittests/test_decoders.cpp | 74 +++++++++++++------------ 7 files changed, 70 insertions(+), 83 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 34310304..e2c952df 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -551,10 +551,4 @@ make_pcm_decoder(const decoder_init &init, cudaq::qec::sparse_binary_matrix(dem.detector_error_matrix), merged); } -/// @deprecated Prefer `get_decoder`, which now accepts a Stim DEM string -/// directly. -std::unique_ptr -get_decoder_from_stim_dem(const std::string &name, - const std::string &stim_dem_text, - const cudaqx::heterogeneous_map &options = {}); } // namespace cudaq::qec diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index 92991048..1a930e71 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -8,6 +8,8 @@ set(LIBRARY_NAME cudaq-qec) +include(FetchContent) + add_compile_options(-Wno-attributes) find_package(CUDAToolkit REQUIRED) @@ -59,18 +61,28 @@ list(APPEND QEC_SOURCES # FIXME?: This must be a shared library. Trying to build a static one will fail. add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES}) -add_subdirectory(decoders/plugins/example) -add_subdirectory(decoders/plugins/pymatching) +if(NOT TARGET libstim) + FetchContent_Declare( + stim + GIT_REPOSITORY https://github.com/quantumlib/Stim.git + GIT_TAG v1.15.0 + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(stim) +endif() if(NOT TARGET libstim) message(FATAL_ERROR - "libstim target not available; required by cudaq-qec for Stim DEM parsing.") + "Stim FetchContent did not provide the libstim target.") endif() target_link_libraries(${LIBRARY_NAME} PRIVATE libstim) target_link_options(${LIBRARY_NAME} PRIVATE $<$,$>:-Wl,--exclude-libs,libstim.a> ) +add_subdirectory(decoders/plugins/example) +add_subdirectory(decoders/plugins/pymatching) + # The TRT decoder plugin honors the tri-state `CUDAQ_QEC_BUILD_TRT_DECODER` # cache variable (AUTO/ON/OFF) declared in the parent CMakeLists.txt. Skip # descending entirely when the user explicitly opted out; otherwise let the diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp index 567f6594..4f260807 100644 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ b/libs/qec/lib/decoder_stim_dem.cpp @@ -17,7 +17,7 @@ dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem) { dem_default_values out; - if (!contains_user_key("O")) + if (!contains_user_key("O") && dem.num_observables() > 0) out.O = &dem.observables_flips_matrix; if (!contains_user_key("error_rate_vec")) out.error_rate_vec = &dem.error_rates; @@ -33,11 +33,4 @@ std::string_view require_dem_text(const decoder_init &init) { "annotations it needs."); } -std::unique_ptr -get_decoder_from_stim_dem(const std::string &name, - const std::string &stim_dem_text, - const cudaqx::heterogeneous_map &options) { - return get_decoder(name, decoder_init{stim_dem_text}, options); -} - } // namespace cudaq::qec diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 5f24bded..29ca63c8 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -815,11 +815,6 @@ void bindDecoder(nb::module_ &mod) { allocation. )pbdoc"); - qecmod.def( - "get_decoder_from_stim_dem", get_decoder_from_dem_text, - "Deprecated: prefer get_decoder, which now accepts a Stim detector error " - "model string directly."); - qecmod.def( "get_sorted_pcm_column_indices", [](const nb::ndarray &H, diff --git a/libs/qec/python/cudaq_qec/__init__.py b/libs/qec/python/cudaq_qec/__init__.py index 0447cc6d..692a3af8 100644 --- a/libs/qec/python/cudaq_qec/__init__.py +++ b/libs/qec/python/cudaq_qec/__init__.py @@ -84,7 +84,6 @@ def checked_decode_batch(self, *args, **kwargs): get_code = qecrt.get_code get_available_codes = qecrt.get_available_codes get_decoder = qecrt.get_decoder -get_decoder_from_stim_dem = qecrt.get_decoder_from_stim_dem dem_from_stim_text = qecrt.dem_from_stim_text DecoderResult = qecrt.DecoderResult BatchDecoderResult = qecrt.BatchDecoderResult diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 460c4ca8..da5555cd 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -762,12 +762,12 @@ def test_generate_random_pcm_signed_weight_rejects_negative(): seed=1) -def test_get_decoder_from_stim_dem(): +def test_get_decoder_accepts_stim_dem_string(): dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" "error(0.05) D0 D1\n") - decoder = qec.get_decoder_from_stim_dem("single_error_lut", dem_text) + decoder = qec.get_decoder("single_error_lut", dem_text) assert decoder is not None assert decoder.get_syndrome_size() == 2 assert decoder.get_block_size() == 3 @@ -784,21 +784,6 @@ def test_get_decoder_from_stim_dem(): assert list(result.result) == expected, f"syndrome {syndrome}" -def test_get_decoder_accepts_stim_dem_string(): - dem_text = ("error(0.1) D0 L0\n" - "error(0.1) D1 L0\n" - "error(0.05) D0 D1\n") - - decoder = qec.get_decoder("single_error_lut", dem_text) - assert decoder is not None - assert decoder.get_syndrome_size() == 2 - assert decoder.get_block_size() == 3 - - result = decoder.decode([1.0, 1.0]) - assert result.converged is True - assert list(result.result) == [0.0, 0.0, 1.0] - - def test_dem_from_stim_text_explicit_parse_then_get_decoder(): dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" @@ -816,24 +801,31 @@ def test_dem_from_stim_text_explicit_parse_then_get_decoder(): assert decoder.get_block_size() == 3 -def test_get_decoder_from_stim_dem_rejects_malformed_text(): +def test_get_decoder_rejects_malformed_stim_dem_text(): with pytest.raises(RuntimeError): - qec.get_decoder_from_stim_dem("single_error_lut", "not a valid DEM") + qec.get_decoder("single_error_lut", "not a valid DEM") -def test_get_decoder_from_stim_dem_rejects_unknown_decoder(): +def test_get_decoder_rejects_unknown_decoder_for_stim_dem_text(): with pytest.raises(RuntimeError, match="__no_such_decoder__"): - qec.get_decoder_from_stim_dem("__no_such_decoder__", - "error(0.1) D0 L0\n") + qec.get_decoder("__no_such_decoder__", "error(0.1) D0 L0\n") -def test_get_decoder_from_stim_dem_user_O_wins_over_dem_derived(): +def test_get_decoder_user_O_wins_over_dem_derived(): dem_text = ("error(0.1) D0 L0\n" "error(0.1) D1 L0\n" "error(0.05) D0 D1\n") bad_O = np.zeros((1, 4), dtype=np.uint8) with pytest.raises(RuntimeError): - qec.get_decoder_from_stim_dem("pymatching", dem_text, O=bad_O) + qec.get_decoder("pymatching", dem_text, O=bad_O) + + +def test_get_decoder_stim_dem_without_observables_returns_errors(): + decoder = qec.get_decoder("pymatching", "error(0.1) D0\n") + + result = decoder.decode([1.0]) + assert result.converged is True + assert list(result.result) == [1.0] if __name__ == "__main__": diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 26ccd506..400e2d68 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -745,13 +745,13 @@ TEST(DecoderTest, GetBlockSizeAndSyndromeSize) { EXPECT_EQ(decoder2->get_syndrome_size(), new_syndrome_size); } -TEST(StimDemDecoderFactory, ConstructsLutDecoderFromStimDemText) { +TEST(StimDemGetDecoder, ConstructsLutDecoderFromStimDemText) { const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 error(0.05) D0 D1 )"; - auto d = cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text); + auto d = cudaq::qec::get_decoder("single_error_lut", dem_text); ASSERT_NE(d, nullptr); EXPECT_EQ(d->get_syndrome_size(), 2u); EXPECT_EQ(d->get_block_size(), 3u); @@ -778,26 +778,23 @@ error(0.05) D0 D1 } } -TEST(StimDemDecoderFactory, UnifiedGetDecoderAcceptsStimDemString) { +TEST(StimDemGetDecoder, StaticDecoderGetAcceptsStimDemString) { const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 error(0.05) D0 D1 )"; - auto check = [&](std::unique_ptr d) { - ASSERT_NE(d, nullptr); - EXPECT_EQ(d->get_syndrome_size(), 2u); - EXPECT_EQ(d->get_block_size(), 3u); - auto result = d->decode(std::vector{1.0, 1.0}); - EXPECT_TRUE(result.converged); - ASSERT_EQ(result.result.size(), 3u); - EXPECT_FLOAT_EQ(result.result[2], 1.0); - }; - check(cudaq::qec::get_decoder("single_error_lut", dem_text)); - check(cudaq::qec::decoder::get("single_error_lut", dem_text)); + auto d = cudaq::qec::decoder::get("single_error_lut", dem_text); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); + auto result = d->decode(std::vector{1.0, 1.0}); + EXPECT_TRUE(result.converged); + ASSERT_EQ(result.result.size(), 3u); + EXPECT_FLOAT_EQ(result.result[2], 1.0); } -TEST(StimDemDecoderFactory, UnifiedGetDecoderStillAcceptsParityCheckMatrix) { +TEST(StimDemGetDecoder, StillAcceptsParityCheckMatrix) { cudaqx::tensor H({2, 3}); H.copy(std::vector{1, 0, 1, 0, 1, 1}.data(), {2, 3}); auto d = cudaq::qec::get_decoder("single_error_lut", H); @@ -806,7 +803,7 @@ TEST(StimDemDecoderFactory, UnifiedGetDecoderStillAcceptsParityCheckMatrix) { EXPECT_EQ(d->get_block_size(), 3u); } -TEST(StimDemDecoderFactory, RepeatedDetectorOrObservableTargetsXorFold) { +TEST(StimDemGetDecoder, RepeatedDetectorOrObservableTargetsXorFold) { const std::string dem_text = R"(error(0.1) D0 D0 error(0.1) L0 L0 )"; @@ -821,34 +818,40 @@ error(0.1) L0 L0 << "duplicate L0 in error 1 should XOR-cancel to 0"; } -TEST(StimDemDecoderFactory, ThrowsOnProbabilityOutOfRange) { +TEST(StimDemGetDecoder, DemWithoutObservablesDoesNotAddODefault) { + auto dem = cudaq::qec::dem_from_stim_text("error(0.1) D0\n"); + auto defaults = cudaq::qec::dem_defaults_for_missing_keys( + [](const std::string &) { return false; }, dem); + + EXPECT_EQ(defaults.O, nullptr); + ASSERT_NE(defaults.error_rate_vec, nullptr); + EXPECT_EQ(defaults.error_rate_vec->size(), 1u); +} + +TEST(StimDemGetDecoder, ThrowsOnProbabilityOutOfRange) { const std::string dem_text = "error(1.5) D0\n"; - EXPECT_THROW( - cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text), - std::runtime_error); + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text), + std::runtime_error); } -TEST(StimDemDecoderFactory, ThrowsOnMalformedStimDem) { - EXPECT_THROW(cudaq::qec::get_decoder_from_stim_dem("single_error_lut", - "not a valid DEM"), +TEST(StimDemGetDecoder, ThrowsOnMalformedStimDem) { + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", "not a valid DEM"), std::runtime_error); } -TEST(StimDemDecoderFactory, ThrowsOnUnknownDecoderName) { +TEST(StimDemGetDecoder, ThrowsOnUnknownDecoderName) { const std::string dem_text = "error(0.1) D0 L0\n"; - EXPECT_THROW( - cudaq::qec::get_decoder_from_stim_dem("__no_such_decoder__", dem_text), - std::runtime_error); + EXPECT_THROW(cudaq::qec::get_decoder("__no_such_decoder__", dem_text), + std::runtime_error); } -TEST(StimDemDecoderFactory, ThrowsOnEmptyErrorMechanisms) { +TEST(StimDemGetDecoder, ThrowsOnEmptyErrorMechanisms) { const std::string dem_text = "detector(0, 0, 0)\n"; - EXPECT_THROW( - cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text), - std::runtime_error); + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text), + std::runtime_error); } -TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) { +TEST(StimDemGetDecoder, StimDemTargetCategoriesAreExhaustive) { const std::vector samples = { stim::DemTarget::separator(), stim::DemTarget::relative_detector_id(0), @@ -865,14 +868,13 @@ TEST(StimDemDecoderFactory, StimDemTargetCategoriesAreExhaustive) { } } -TEST(StimDemDecoderFactory, UserOptionsAreNotOverwritten) { +TEST(StimDemGetDecoder, UserOptionsAreNotOverwritten) { const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 error(0.05) D0 D1 )"; cudaqx::heterogeneous_map opts; opts.insert("error_rate_vec", std::vector{0.5}); - EXPECT_THROW( - cudaq::qec::get_decoder_from_stim_dem("single_error_lut", dem_text, opts), - std::runtime_error); + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text, opts), + std::runtime_error); } From 05b5d5a5962dcd5e31450099b96e85b16386463b Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 8 Jun 2026 14:13:44 -0400 Subject: [PATCH 12/14] remove require_dem_text helper function Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 12 +++---- .../include/cudaq/qec/detector_error_model.h | 3 ++ libs/qec/lib/CMakeLists.txt | 1 - libs/qec/lib/decoder.cpp | 15 ++++++++ libs/qec/lib/decoder_stim_dem.cpp | 36 ------------------- libs/qec/lib/detector_error_model.cpp | 6 ++++ libs/qec/python/bindings/py_decoder.cpp | 11 +++--- libs/qec/unittests/test_decoders.cpp | 2 +- 8 files changed, 37 insertions(+), 49 deletions(-) delete mode 100644 libs/qec/lib/decoder_stim_dem.cpp diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index e2c952df..661d4c76 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -13,12 +13,14 @@ #include "cuda-qx/core/tensor.h" #include "sparse_binary_matrix.h" #include "cudaq/qec/detector_error_model.h" +#include #include #include #include #include #include #include +#include #include #include @@ -31,8 +33,7 @@ using float_t = double; #endif /// Decoder construction input: either a parity-check matrix or raw Stim DEM -/// text. PCM-based decoders can accept both; DEM-native decoders can require -/// the string alternative via `require_dem_text`. +/// text. using decoder_init = std::variant; /// @brief Validates that all keys in a heterogeneous map are found in a list of @@ -516,6 +517,7 @@ get_decoder(const std::string &name, std::string_view stim_dem_text, return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options); } +namespace details { /// DEM-derived defaults; pointers alias into the source `dem`. struct dem_default_values { const cudaqx::tensor *O = nullptr; @@ -526,9 +528,7 @@ struct dem_default_values { dem_default_values dem_defaults_for_missing_keys( const std::function &contains_user_key, const detector_error_model &dem); - -/// Extract raw Stim DEM text for DEM-native decoders. -std::string_view require_dem_text(const decoder_init &init); +} // namespace details /// Build a PCM-based decoder. If `init` holds DEM text, it is parsed into a PCM /// and DEM-derived `"O"` / `"error_rate_vec"` defaults are added when absent. @@ -541,7 +541,7 @@ make_pcm_decoder(const decoder_init &init, const auto dem = dem_from_stim_text(std::get(init)); cudaqx::heterogeneous_map merged = params; - const auto defaults = dem_defaults_for_missing_keys( + const auto defaults = details::dem_defaults_for_missing_keys( [&](const std::string &key) { return merged.contains(key); }, dem); if (defaults.O) merged.insert("O", *defaults.O); diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index f6d82d50..2fb8d6c6 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -8,8 +8,11 @@ #pragma once #include "cuda-qx/core/tensor.h" +#include +#include #include #include +#include namespace cudaq::qec { diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index 1a930e71..eea07bfc 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -41,7 +41,6 @@ endif() set(QEC_SOURCES code.cpp decoder.cpp - decoder_stim_dem.cpp detector_error_model.cpp experiments.cpp pcm_utils.cpp diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 69ccf0f7..71fa4afb 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -141,6 +141,21 @@ decoder::get(const std::string &name, const decoder_init &init, return iter->second(init, param_map); } +namespace details { + +dem_default_values dem_defaults_for_missing_keys( + const std::function &contains_user_key, + const detector_error_model &dem) { + dem_default_values out; + if (!contains_user_key("O") && dem.num_observables() > 0) + out.O = &dem.observables_flips_matrix; + if (!contains_user_key("error_rate_vec")) + out.error_rate_vec = &dem.error_rates; + return out; +} + +} // namespace details + static uint32_t calculate_num_msyn_per_decode( const std::vector> &D_sparse) { uint32_t max_col = 0; diff --git a/libs/qec/lib/decoder_stim_dem.cpp b/libs/qec/lib/decoder_stim_dem.cpp deleted file mode 100644 index 4f260807..00000000 --- a/libs/qec/lib/decoder_stim_dem.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. * - * All rights reserved. * - * This source code and the accompanying materials are made available under * - * the terms of the Apache License 2.0 which accompanies this distribution. * - ******************************************************************************/ - -#include "cudaq/qec/decoder.h" -#include "cudaq/qec/detector_error_model.h" - -#include -#include - -namespace cudaq::qec { - -dem_default_values dem_defaults_for_missing_keys( - const std::function &contains_user_key, - const detector_error_model &dem) { - dem_default_values out; - if (!contains_user_key("O") && dem.num_observables() > 0) - out.O = &dem.observables_flips_matrix; - if (!contains_user_key("error_rate_vec")) - out.error_rate_vec = &dem.error_rates; - return out; -} - -std::string_view require_dem_text(const decoder_init &init) { - if (const auto *dem_text = std::get_if(&init)) - return std::string_view{*dem_text}; - throw std::runtime_error( - "This decoder requires a Stim detector error model string; a " - "parity-check matrix cannot be used to reconstruct the detector " - "annotations it needs."); -} - -} // namespace cudaq::qec diff --git a/libs/qec/lib/detector_error_model.cpp b/libs/qec/lib/detector_error_model.cpp index c9c9c09d..cb9e1a1d 100644 --- a/libs/qec/lib/detector_error_model.cpp +++ b/libs/qec/lib/detector_error_model.cpp @@ -12,6 +12,12 @@ #include "stim.h" +#include +#include +#include +#include +#include + namespace cudaq::qec { detector_error_model dem_from_stim_text(const std::string &dem_text) { diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index cad8ff99..e728d0bc 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -788,7 +788,7 @@ void bindDecoder(nb::module_ &mod) { if (PyDecoderRegistry::contains(name)) { auto dem = dem_from_stim_text(dem_text); - auto defaults = dem_defaults_for_missing_keys( + auto defaults = details::dem_defaults_for_missing_keys( [&](const std::string &key) { return options.contains(key); }, dem); if (defaults.O) options["O"] = copyToPyArray(*defaults.O); @@ -849,10 +849,11 @@ void bindDecoder(nb::module_ &mod) { raw DEM text via ``decoder_init``; Python-registered decoders receive the DEM-derived PCM plus ``O`` and ``error_rate_vec`` defaults. - For Python-registered decoders (``cudaq.qec.decoder`` decorator), ``H`` - is passed through to ``__init__`` unchanged (NumPy array or sparse dict). - Call ``Decoder.__init__(self, H)`` so nanobind can store the PCM in CSC - form when ``H`` is a dict without building a dense ``rows x cols`` + For Python-registered decoders (``cudaq.qec.decoder`` decorator), + NumPy array and sparse dict inputs are passed through to ``__init__`` + unchanged. DEM string inputs are parsed first as described above. Call + ``Decoder.__init__(self, H)`` so nanobind can store the PCM in CSC form + when ``H`` is a dict without building a dense ``rows x cols`` allocation. )pbdoc"); diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 400e2d68..6c63ff6e 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -820,7 +820,7 @@ error(0.1) L0 L0 TEST(StimDemGetDecoder, DemWithoutObservablesDoesNotAddODefault) { auto dem = cudaq::qec::dem_from_stim_text("error(0.1) D0\n"); - auto defaults = cudaq::qec::dem_defaults_for_missing_keys( + auto defaults = cudaq::qec::details::dem_defaults_for_missing_keys( [](const std::string &) { return false; }, dem); EXPECT_EQ(defaults.O, nullptr); From 92d3dc977b46035cfd228ebf71e1b6cb524e9951 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 8 Jun 2026 14:42:31 -0400 Subject: [PATCH 13/14] update doc string Signed-off-by: vedika-saravanan --- libs/qec/include/cudaq/qec/decoder.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 661d4c76..1a47b9b0 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -518,6 +518,7 @@ get_decoder(const std::string &name, std::string_view stim_dem_text, } namespace details { +// Declared here because `make_pcm_decoder` is a header-defined template. /// DEM-derived defaults; pointers alias into the source `dem`. struct dem_default_values { const cudaqx::tensor *O = nullptr; @@ -530,8 +531,8 @@ dem_default_values dem_defaults_for_missing_keys( const detector_error_model &dem); } // namespace details -/// Build a PCM-based decoder. If `init` holds DEM text, it is parsed into a PCM -/// and DEM-derived `"O"` / `"error_rate_vec"` defaults are added when absent. +/// If `init` holds DEM text, parse it and inject `"O"` / `"error_rate_vec"` +/// defaults when absent. template std::unique_ptr make_pcm_decoder(const decoder_init &init, From 7b30f089f03a08b27678816de16c7ad35b245e03 Mon Sep 17 00:00:00 2001 From: vedika-saravanan Date: Mon, 8 Jun 2026 17:04:43 -0400 Subject: [PATCH 14/14] address pr comments Signed-off-by: vedika-saravanan --- libs/qec/lib/CMakeLists.txt | 3 +++ libs/qec/unittests/test_decoders.cpp | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index eea07bfc..0c34b801 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -74,6 +74,9 @@ if(NOT TARGET libstim) message(FATAL_ERROR "Stim FetchContent did not provide the libstim target.") endif() +set_target_properties(${LIBRARY_NAME} PROPERTIES + VISIBILITY_INLINES_HIDDEN ON +) target_link_libraries(${LIBRARY_NAME} PRIVATE libstim) target_link_options(${LIBRARY_NAME} PRIVATE $<$,$>:-Wl,--exclude-libs,libstim.a> diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 6c63ff6e..672c4e4b 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -707,7 +707,7 @@ TEST(DecoderResultTest, EqualityOperatorConvergedAndResult) { EXPECT_FALSE(result1 != result2); } -TEST(DecoderTest, GetBlockSizeAndSyndromeSize) { +TEST(DecoderTest, GetWithoutOptionsSetsBlockAndSyndromeSize) { std::size_t block_size = 15; std::size_t syndrome_size = 8;