diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 3da1eda5..1a47b9b0 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -12,8 +12,16 @@ #include "cuda-qx/core/heterogeneous_map.h" #include "cuda-qx/core/tensor.h" #include "sparse_binary_matrix.h" +#include "cudaq/qec/detector_error_model.h" +#include +#include #include +#include #include +#include +#include +#include +#include #include namespace cudaq::qec { @@ -24,6 +32,10 @@ using float_t = CUDAQX_QEC_FLOAT_TYPE; using float_t = double; #endif +/// Decoder construction input: either a parity-check matrix or raw Stim DEM +/// text. +using decoder_init = std::variant; + /// @brief Validates that all keys in a heterogeneous map are found in a list of /// acceptable types /// @param config The heterogeneous map to validate @@ -122,8 +134,7 @@ class async_decoder_result { /// arbitrary constructor parameters that can be unique to each specific /// decoder. class decoder - : public cudaqx::extension_point { private: struct rt_impl; @@ -143,16 +154,16 @@ class decoder /// @brief Decode a single syndrome /// @param syndrome A vector of syndrome measurements where the floating point /// value is the probability that the syndrome measurement is a |1>. The - /// length of the syndrome vector should be equal to \p syndrome_size. - /// @returns Vector of length \p block_size with soft probabilities of errors + /// length of the syndrome vector should be equal to `syndrome_size`. + /// @returns Vector of length `block_size` with soft probabilities of errors /// in each index. virtual decoder_result decode(const std::vector &syndrome) = 0; /// @brief Decode a single syndrome /// @param syndrome An order-1 tensor of syndrome measurements where a 1 bit /// represents that the syndrome measurement is a |1>. The - /// length of the syndrome vector should be equal to \p syndrome_size. - /// @returns Vector of length \p block_size of errors in each index. + /// length of the syndrome vector should be equal to `syndrome_size`. + /// @returns Vector of length `block_size` of errors in each index. virtual decoder_result decode(const cudaqx::tensor &syndrome); /// @brief Decode a single syndrome @@ -172,11 +183,49 @@ class decoder virtual std::vector decode_batch(const std::vector> &syndrome); - /// @brief This `get` overload supports default values. + /// @brief Construct a registered decoder by name. + /// @param name The registered decoder name. + /// @param init A parity-check matrix or raw Stim DEM string. + /// @param param_map Optional decoder-specific parameters. static std::unique_ptr - get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, + get(const std::string &name, const decoder_init &init, const cudaqx::heterogeneous_map ¶m_map = cudaqx::heterogeneous_map()); + static std::unique_ptr + get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, decoder_init{H}, param_map); + } + + static std::unique_ptr + get(const std::string &name, const cudaqx::tensor &H, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, cudaq::qec::sparse_binary_matrix(H), param_map); + } + + static std::unique_ptr + get(const std::string &name, const std::string &stim_dem_text, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, decoder_init{stim_dem_text}, param_map); + } + + static std::unique_ptr + get(const std::string &name, const char *stim_dem_text, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, decoder_init{std::string{stim_dem_text}}, param_map); + } + + static std::unique_ptr + get(const std::string &name, std::string_view stim_dem_text, + const cudaqx::heterogeneous_map ¶m_map = + cudaqx::heterogeneous_map()) { + return get(name, decoder_init{std::string{stim_dem_text}}, param_map); + } + std::size_t get_block_size() { return block_size; } std::size_t get_syndrome_size() { return syndrome_size; } @@ -435,6 +484,72 @@ inline void convert_vec_hard_to_soft(const std::vector> &in, } std::unique_ptr -get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, +get_decoder(const std::string &name, const decoder_init &init, const cudaqx::heterogeneous_map options = {}); + +inline std::unique_ptr +get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{H}, options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, const cudaqx::tensor &H, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, cudaq::qec::sparse_binary_matrix(H), options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, const std::string &stim_dem_text, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{stim_dem_text}, options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, const char *stim_dem_text, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options); +} + +inline std::unique_ptr +get_decoder(const std::string &name, std::string_view stim_dem_text, + const cudaqx::heterogeneous_map options = {}) { + return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options); +} + +namespace details { +// Declared here because `make_pcm_decoder` is a header-defined template. +/// DEM-derived defaults; pointers alias into the source `dem`. +struct dem_default_values { + const cudaqx::tensor *O = nullptr; + const std::vector *error_rate_vec = nullptr; +}; + +/// Return DEM defaults for keys not already supplied by the user. +dem_default_values dem_defaults_for_missing_keys( + const std::function &contains_user_key, + const detector_error_model &dem); +} // namespace details + +/// If `init` holds DEM text, parse it and inject `"O"` / `"error_rate_vec"` +/// defaults when absent. +template +std::unique_ptr +make_pcm_decoder(const decoder_init &init, + const cudaqx::heterogeneous_map ¶ms) { + if (const auto *H = std::get_if(&init)) + return std::make_unique(*H, params); + + const auto dem = dem_from_stim_text(std::get(init)); + cudaqx::heterogeneous_map merged = params; + const auto defaults = details::dem_defaults_for_missing_keys( + [&](const std::string &key) { return merged.contains(key); }, dem); + if (defaults.O) + merged.insert("O", *defaults.O); + if (defaults.error_rate_vec) + merged.insert("error_rate_vec", *defaults.error_rate_vec); + return std::make_unique( + cudaq::qec::sparse_binary_matrix(dem.detector_error_matrix), merged); +} + } // namespace cudaq::qec diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index 747f01bb..2fb8d6c6 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -8,7 +8,11 @@ #pragma once #include "cuda-qx/core/tensor.h" +#include +#include #include +#include +#include namespace cudaq::qec { @@ -18,9 +22,9 @@ namespace cudaq::qec { /// decoder to help make predictions about observables flips. /// /// Shared size parameters among the matrix types. -/// - \p detector_error_matrix: num_detectors x num_error_mechanisms [d, e] -/// - \p error_rates: num_error_mechanisms -/// - \p observables_flips_matrix: num_observables x num_error_mechanisms [k, e] +/// - `detector_error_matrix`: num_detectors x num_error_mechanisms [d, e] +/// - `error_rates`: num_error_mechanisms +/// - `observables_flips_matrix`: num_observables x num_error_mechanisms [k, e] /// /// @note The C++ API for this class may change in the future. The Python API is /// more likely to be backwards compatible. @@ -32,7 +36,7 @@ struct detector_error_model { cudaqx::tensor detector_error_matrix; /// The list of weights has length equal to the number of columns of - /// \p detector_error_matrix, which assigns a likelihood to each error + /// `detector_error_matrix`, which assigns a likelihood to each error /// mechanism. std::vector error_rates; @@ -65,4 +69,8 @@ struct detector_error_model { void canonicalize_for_rounds(uint32_t num_syndromes_per_round); }; +/// Parse Stim DEM text into detector/observable flip matrices and error rates. +/// This is lossy; DEM-native decoders should consume raw DEM text instead. +detector_error_model dem_from_stim_text(const std::string &dem_text); + } // namespace cudaq::qec diff --git a/libs/qec/lib/CMakeLists.txt b/libs/qec/lib/CMakeLists.txt index ef4acf18..0c34b801 100644 --- a/libs/qec/lib/CMakeLists.txt +++ b/libs/qec/lib/CMakeLists.txt @@ -8,6 +8,8 @@ set(LIBRARY_NAME cudaq-qec) +include(FetchContent) + add_compile_options(-Wno-attributes) find_package(CUDAToolkit REQUIRED) @@ -44,7 +46,7 @@ set(QEC_SOURCES pcm_utils.cpp plugin_loader.cpp sparse_binary_matrix.cpp - stabilizer_utils.cpp + stabilizer_utils.cpp decoders/lut.cpp decoders/sliding_window.cpp version.cpp @@ -58,6 +60,28 @@ list(APPEND QEC_SOURCES # FIXME?: This must be a shared library. Trying to build a static one will fail. add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES}) +if(NOT TARGET libstim) + FetchContent_Declare( + stim + GIT_REPOSITORY https://github.com/quantumlib/Stim.git + GIT_TAG v1.15.0 + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(stim) +endif() + +if(NOT TARGET libstim) + message(FATAL_ERROR + "Stim FetchContent did not provide the libstim target.") +endif() +set_target_properties(${LIBRARY_NAME} PROPERTIES + VISIBILITY_INLINES_HIDDEN ON +) +target_link_libraries(${LIBRARY_NAME} PRIVATE libstim) +target_link_options(${LIBRARY_NAME} PRIVATE + $<$,$>:-Wl,--exclude-libs,libstim.a> +) + add_subdirectory(decoders/plugins/example) add_subdirectory(decoders/plugins/pymatching) diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index b631522b..71fa4afb 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -17,10 +17,7 @@ #include #include -INSTANTIATE_REGISTRY(cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &) -INSTANTIATE_REGISTRY(cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &, +INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::decoder_init &, const cudaqx::heterogeneous_map &) // Include decoder implementations AFTER registry instantiation @@ -131,7 +128,7 @@ decoder::decode_async(const std::vector &syndrome) { } std::unique_ptr -decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, +decoder::get(const std::string &name, const decoder_init &init, const cudaqx::heterogeneous_map ¶m_map) { auto [mutex, registry] = get_registry(); std::lock_guard lock(mutex); @@ -141,9 +138,24 @@ decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H, "invalid decoder requested: " + name + ". Run with CUDAQ_LOG_LEVEL=info (environment variable) to see " "additional plugin diagnostics at startup."); - return iter->second(H, param_map); + return iter->second(init, param_map); } +namespace details { + +dem_default_values dem_defaults_for_missing_keys( + const std::function &contains_user_key, + const detector_error_model &dem) { + dem_default_values out; + if (!contains_user_key("O") && dem.num_observables() > 0) + out.O = &dem.observables_flips_matrix; + if (!contains_user_key("error_rate_vec")) + out.error_rate_vec = &dem.error_rates; + return out; +} + +} // namespace details + static uint32_t calculate_num_msyn_per_decode( const std::vector> &D_sparse) { uint32_t max_col = 0; @@ -480,9 +492,9 @@ void decoder::reset_decoder() { } std::unique_ptr get_decoder(const std::string &name, - const cudaq::qec::sparse_binary_matrix &H, + const decoder_init &init, const cudaqx::heterogeneous_map options) { - return decoder::get(name, H, options); + return decoder::get(name, init, options); } // Constructor function for auto-loading plugins diff --git a/libs/qec/lib/decoders/lut.cpp b/libs/qec/lib/decoders/lut.cpp index ab5e8a89..0bacfd91 100644 --- a/libs/qec/lib/decoders/lut.cpp +++ b/libs/qec/lib/decoders/lut.cpp @@ -228,9 +228,9 @@ class multi_error_lut : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( multi_error_lut, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; @@ -246,9 +246,9 @@ class single_error_lut : public multi_error_lut { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( single_error_lut, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp b/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp index 26154f61..3b4dbdc1 100644 --- a/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp +++ b/libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp @@ -77,9 +77,10 @@ class single_error_lut_example : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( single_error_lut_example, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, + params); }) }; diff --git a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp index 03d9627f..7d0421d7 100644 --- a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp +++ b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp @@ -247,9 +247,9 @@ class pymatching : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( pymatching, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp index 438ebd00..614fcd98 100644 --- a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp +++ b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp @@ -432,9 +432,9 @@ class trt_decoder : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( trt_decoder, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) private: diff --git a/libs/qec/lib/decoders/sliding_window.h b/libs/qec/lib/decoders/sliding_window.h index e240da9f..826a991d 100644 --- a/libs/qec/lib/decoders/sliding_window.h +++ b/libs/qec/lib/decoders/sliding_window.h @@ -142,9 +142,9 @@ class sliding_window : public decoder { // Plugin registration macros CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( sliding_window, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/lib/detector_error_model.cpp b/libs/qec/lib/detector_error_model.cpp index 14612ea4..cb9e1a1d 100644 --- a/libs/qec/lib/detector_error_model.cpp +++ b/libs/qec/lib/detector_error_model.cpp @@ -10,8 +10,98 @@ #include "cudaq/qec/pcm_utils.h" #include "cudaq/runtime/logger/logger.h" +#include "stim.h" + +#include +#include +#include +#include +#include + namespace cudaq::qec { +detector_error_model dem_from_stim_text(const std::string &dem_text) { + auto dem = [&dem_text]() { + try { + return stim::DetectorErrorModel(dem_text); + } catch (const std::exception &e) { + throw std::runtime_error(std::string("Stim DEM parse failed: ") + + e.what()); + } + }(); + const std::size_t num_detectors = + static_cast(dem.count_detectors()); + const std::size_t num_observables = + static_cast(dem.count_observables()); + + std::vector> detector_hits; + std::vector> observable_hits; + std::vector rates; + std::size_t instruction_index = 0; + + dem.iter_flatten_error_instructions([&](const stim::DemInstruction &inst) { + if (inst.arg_data.empty()) + throw std::runtime_error( + "Stim DEM error instruction missing probability argument (index " + + std::to_string(instruction_index) + ")"); + const double prob = inst.arg_data[0]; + if (!(prob >= 0.0 && prob <= 1.0)) + throw std::runtime_error("Stim DEM error probability " + + std::to_string(prob) + + " out of range [0, 1] at instruction index " + + std::to_string(instruction_index)); + std::vector dets; + std::vector obs; + for (const auto &target : inst.target_data) { + if (target.is_separator()) + continue; + if (target.is_relative_detector_id()) { + dets.push_back(static_cast(target.val())); + } else if (target.is_observable_id()) { + obs.push_back(static_cast(target.val())); + } else { + throw std::runtime_error( + "Stim DEM error instruction (index " + + std::to_string(instruction_index) + + ") contains an unsupported target kind; only D* (detector) and " + "L* (observable) targets are supported by the fallback parser"); + } + } + detector_hits.push_back(std::move(dets)); + observable_hits.push_back(std::move(obs)); + rates.push_back(prob); + ++instruction_index; + }); + + const std::size_t num_errors = rates.size(); + if (num_errors == 0) + throw std::runtime_error( + "Stim DEM contains no error mechanisms after flattening"); + detector_error_model result; + result.detector_error_matrix = + cudaqx::tensor({num_detectors, num_errors}); + result.observables_flips_matrix = + cudaqx::tensor({num_observables, num_errors}); + result.error_rates = std::move(rates); + + for (std::size_t err = 0; err < num_errors; ++err) { + for (auto det : detector_hits[err]) { + if (det >= num_detectors) + throw std::runtime_error( + "Stim DEM detector id out of range while extracting H"); + result.detector_error_matrix.at({det, err}) ^= 1; + } + for (auto ob : observable_hits[err]) { + if (ob >= num_observables) + throw std::runtime_error( + "Stim DEM observable id out of range while extracting O"); + result.observables_flips_matrix.at({ob, err}) ^= 1; + } + } + + return result; +} + std::size_t detector_error_model::num_detectors() const { auto shape = detector_error_matrix.shape(); if (shape.size() == 2) diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 31c94a0e..e728d0bc 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -387,6 +387,20 @@ makeBatchDecoderResult(const std::vector &results) { }; } +nb::object copyToPyArray(const cudaqx::tensor &t) { + size_t shape[2] = {t.shape()[0], t.shape()[1]}; + auto arr = nb::ndarray(const_cast(t.data()), 2, + shape, nb::none()); + return nb::cast(arr).attr("copy")(); +} + +nb::object copyToPyArray(const std::vector &v) { + size_t shape[1] = {v.size()}; + auto arr = nb::ndarray(const_cast(v.data()), 1, + shape, nb::none()); + return nb::cast(arr).attr("copy")(); +} + } // namespace void bindDecoder(nb::module_ &mod) { @@ -731,6 +745,11 @@ void bindDecoder(nb::module_ &mod) { )pbdoc", nb::arg("num_syndromes_per_round")); + qecmod.def( + "dem_from_stim_text", &dem_from_stim_text, + "Parse a Stim detector error model string into a DetectorErrorModel.", + nb::arg("dem_text")); + // Expose decorator function that handles inheritance qecmod.def("decoder", [&](const std::string &name) { return nb::cpp_function([name](nb::object decoder_class) -> nb::object { @@ -762,10 +781,37 @@ void bindDecoder(nb::module_ &mod) { }); }); + auto get_decoder_from_dem_text = [](const std::string &name, + const std::string &dem_text, + nb::kwargs options) + -> std::variant> { + if (PyDecoderRegistry::contains(name)) { + auto dem = dem_from_stim_text(dem_text); + + auto defaults = details::dem_defaults_for_missing_keys( + [&](const std::string &key) { return options.contains(key); }, dem); + if (defaults.O) + options["O"] = copyToPyArray(*defaults.O); + if (defaults.error_rate_vec) + options["error_rate_vec"] = copyToPyArray(*defaults.error_rate_vec); + + nb::object H_obj = copyToPyArray(dem.detector_error_matrix); + return PyDecoderRegistry::get_decoder(name, H_obj, options); + } + + return get_decoder(name, decoder_init{dem_text}, hetMapFromKwargs(options)); + }; + qecmod.def( "get_decoder", - [](const std::string &name, nb::object H, nb::kwargs options) + [get_decoder_from_dem_text](const std::string &name, nb::object H, + nb::kwargs options) -> std::variant> { + if (nb::isinstance(H)) { + return get_decoder_from_dem_text(name, nb::cast(H), + options); + } + if (PyDecoderRegistry::contains(name)) { return PyDecoderRegistry::get_decoder(name, H, options); } @@ -799,11 +845,15 @@ void bindDecoder(nb::module_ &mod) { builds ``sparse_binary_matrix`` directly (no dense tensor for ``H``), allowing native C++ decoders like ``pymatching`` to be constructed from very large parity-check matrices. - - For Python-registered decoders (``cudaq.qec.decoder`` decorator), ``H`` - is passed through to ``__init__`` unchanged (NumPy array or sparse dict). - Call ``Decoder.__init__(self, H)`` so nanobind can store the PCM in CSC - form when ``H`` is a dict without building a dense ``rows × cols`` + - A Stim detector error model string: native C++ decoders receive the + raw DEM text via ``decoder_init``; Python-registered decoders receive + the DEM-derived PCM plus ``O`` and ``error_rate_vec`` defaults. + + For Python-registered decoders (``cudaq.qec.decoder`` decorator), + NumPy array and sparse dict inputs are passed through to ``__init__`` + unchanged. DEM string inputs are parsed first as described above. Call + ``Decoder.__init__(self, H)`` so nanobind can store the PCM in CSC form + when ``H`` is a dict without building a dense ``rows x cols`` allocation. )pbdoc"); diff --git a/libs/qec/python/cudaq_qec/__init__.py b/libs/qec/python/cudaq_qec/__init__.py index d2ae1add..692a3af8 100644 --- a/libs/qec/python/cudaq_qec/__init__.py +++ b/libs/qec/python/cudaq_qec/__init__.py @@ -84,6 +84,7 @@ def checked_decode_batch(self, *args, **kwargs): get_code = qecrt.get_code get_available_codes = qecrt.get_available_codes get_decoder = qecrt.get_decoder +dem_from_stim_text = qecrt.dem_from_stim_text DecoderResult = qecrt.DecoderResult BatchDecoderResult = qecrt.BatchDecoderResult DetectorErrorModel = qecrt.DetectorErrorModel diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index c95dd615..0b9e65f0 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -769,5 +769,71 @@ def test_generate_random_pcm_signed_weight_rejects_negative(): seed=1) +def test_get_decoder_accepts_stim_dem_string(): + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + + decoder = qec.get_decoder("single_error_lut", dem_text) + assert decoder is not None + assert decoder.get_syndrome_size() == 2 + assert decoder.get_block_size() == 3 + + cases = [ + ([0.0, 0.0], [0.0, 0.0, 0.0]), + ([1.0, 0.0], [1.0, 0.0, 0.0]), + ([0.0, 1.0], [0.0, 1.0, 0.0]), + ([1.0, 1.0], [0.0, 0.0, 1.0]), + ] + for syndrome, expected in cases: + result = decoder.decode(syndrome) + assert result.converged is True, f"syndrome {syndrome}" + assert list(result.result) == expected, f"syndrome {syndrome}" + + +def test_dem_from_stim_text_explicit_parse_then_get_decoder(): + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + + dem = qec.dem_from_stim_text(dem_text) + assert isinstance(dem, qec.DetectorErrorModel) + assert dem.num_detectors() == 2 + assert dem.num_error_mechanisms() == 3 + assert dem.num_observables() == 1 + assert dem.detector_error_matrix.shape == (2, 3) + + decoder = qec.get_decoder("single_error_lut", dem.detector_error_matrix) + assert decoder.get_syndrome_size() == 2 + assert decoder.get_block_size() == 3 + + +def test_get_decoder_rejects_malformed_stim_dem_text(): + with pytest.raises(RuntimeError): + qec.get_decoder("single_error_lut", "not a valid DEM") + + +def test_get_decoder_rejects_unknown_decoder_for_stim_dem_text(): + with pytest.raises(RuntimeError, match="__no_such_decoder__"): + qec.get_decoder("__no_such_decoder__", "error(0.1) D0 L0\n") + + +def test_get_decoder_user_O_wins_over_dem_derived(): + dem_text = ("error(0.1) D0 L0\n" + "error(0.1) D1 L0\n" + "error(0.05) D0 D1\n") + bad_O = np.zeros((1, 4), dtype=np.uint8) + with pytest.raises(RuntimeError): + qec.get_decoder("pymatching", dem_text, O=bad_O) + + +def test_get_decoder_stim_dem_without_observables_returns_errors(): + decoder = qec.get_decoder("pymatching", "error(0.1) D0\n") + + result = decoder.decode([1.0]) + assert result.converged is True + assert list(result.result) == [1.0] + + if __name__ == "__main__": pytest.main() diff --git a/libs/qec/unittests/CMakeLists.txt b/libs/qec/unittests/CMakeLists.txt index 0e878dac..923ae9fa 100644 --- a/libs/qec/unittests/CMakeLists.txt +++ b/libs/qec/unittests/CMakeLists.txt @@ -35,7 +35,7 @@ find_package(CUDAToolkit REQUIRED) add_compile_options(-Wno-attributes) add_executable(test_decoders test_decoders.cpp decoders/sample_decoder.cpp) -target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq) +target_link_libraries(test_decoders PRIVATE GTest::gtest_main cudaq-qec cudaq-qec-realtime-decoding cudaq::cudaq libstim) add_dependencies(CUDAQXQECUnitTests test_decoders) gtest_discover_tests(test_decoders) @@ -509,4 +509,3 @@ add_subdirectory(decoders/pymatching) if(CUDAQX_QEC_ENABLE_HOLOLINK_TOOLS) add_subdirectory(utils) endif() - diff --git a/libs/qec/unittests/decoders/sample_decoder.cpp b/libs/qec/unittests/decoders/sample_decoder.cpp index 8b3b509a..da953411 100644 --- a/libs/qec/unittests/decoders/sample_decoder.cpp +++ b/libs/qec/unittests/decoders/sample_decoder.cpp @@ -14,7 +14,7 @@ using namespace cudaqx; namespace cudaq::qec { /// @brief This is a sample (dummy) decoder that demonstrates how to build a -/// bare bones custom decoder based on the `cudaqx::qec::decoder` interface. +/// bare bones custom decoder based on the `cudaq::qec::decoder` interface. class sample_decoder : public decoder { public: sample_decoder(const cudaq::qec::sparse_binary_matrix &H, @@ -35,9 +35,9 @@ class sample_decoder : public decoder { CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( sample_decoder, static std::unique_ptr create( - const cudaq::qec::sparse_binary_matrix &H, + const cudaq::qec::decoder_init &init, const cudaqx::heterogeneous_map ¶ms) { - return std::make_unique(H, params); + return cudaq::qec::make_pcm_decoder(init, params); }) }; diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 85611a5b..672c4e4b 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -6,7 +6,9 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ +#include "stim.h" #include "cudaq/qec/decoder.h" +#include "cudaq/qec/detector_error_model.h" #include "cudaq/qec/pcm_utils.h" #include #include @@ -705,7 +707,7 @@ TEST(DecoderResultTest, EqualityOperatorConvergedAndResult) { EXPECT_FALSE(result1 != result2); } -TEST(DecoderTest, GetBlockSizeAndSyndromeSize) { +TEST(DecoderTest, GetWithoutOptionsSetsBlockAndSyndromeSize) { std::size_t block_size = 15; std::size_t syndrome_size = 8; @@ -743,63 +745,136 @@ TEST(DecoderTest, GetBlockSizeAndSyndromeSize) { EXPECT_EQ(decoder2->get_syndrome_size(), new_syndrome_size); } -TEST(DecoderRegistryTest, SingleParameterRegistryDirect) { - // Test the single-parameter registry instantiation (line 18 in decoder.cpp) - // This directly tests the registry for decoder constructors that only take - // tensor by accessing the single-parameter extension_point registry - // directly +TEST(StimDemGetDecoder, ConstructsLutDecoderFromStimDemText) { + const std::string dem_text = R"(error(0.1) D0 L0 +error(0.1) D1 L0 +error(0.05) D0 D1 +)"; - std::size_t block_size = 8; - std::size_t syndrome_size = 4; - cudaqx::tensor H({syndrome_size, block_size}); + auto d = cudaq::qec::get_decoder("single_error_lut", dem_text); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); - // Initialize with some test data to ensure it's a valid matrix - for (std::size_t i = 0; i < syndrome_size; ++i) { - for (std::size_t j = 0; j < block_size; ++j) { - H.at({i, j}) = (i + j) % 2; - } + struct Case { + std::vector syndrome; + std::vector expected; + }; + const std::vector cases = { + {{0.0, 0.0}, {0.0, 0.0, 0.0}}, + {{1.0, 0.0}, {1.0, 0.0, 0.0}}, + {{0.0, 1.0}, {0.0, 1.0, 0.0}}, + {{1.0, 1.0}, {0.0, 0.0, 1.0}}, + }; + for (const auto &c : cases) { + auto result = d->decode(c.syndrome); + EXPECT_TRUE(result.converged) + << "syndrome {" << c.syndrome[0] << ", " << c.syndrome[1] << "}"; + ASSERT_EQ(result.result.size(), 3u); + for (std::size_t i = 0; i < 3u; ++i) + EXPECT_FLOAT_EQ(result.result[i], c.expected[i]) + << "error " << i << " for syndrome {" << c.syndrome[0] << ", " + << c.syndrome[1] << "}"; } +} - auto H_sparse = cudaq::qec::sparse_binary_matrix(H); +TEST(StimDemGetDecoder, StaticDecoderGetAcceptsStimDemString) { + const std::string dem_text = R"(error(0.1) D0 L0 +error(0.1) D1 L0 +error(0.05) D0 D1 +)"; + + auto d = cudaq::qec::decoder::get("single_error_lut", dem_text); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); + auto result = d->decode(std::vector{1.0, 1.0}); + EXPECT_TRUE(result.converged); + ASSERT_EQ(result.result.size(), 3u); + EXPECT_FLOAT_EQ(result.result[2], 1.0); +} - // Test that the single-parameter registry exists and can be accessed - // This directly tests line 18: INSTANTIATE_REGISTRY(cudaq::qec::decoder, - // const cudaqx::tensor &) - try { - // Create a decoder using the single-parameter extension_point directly - // This bypasses decoder::get and directly uses the single-parameter - // registry - auto single_param_decoder = cudaqx::extension_point< - cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &>::get("sample_decoder", - H_sparse); +TEST(StimDemGetDecoder, StillAcceptsParityCheckMatrix) { + cudaqx::tensor H({2, 3}); + H.copy(std::vector{1, 0, 1, 0, 1, 1}.data(), {2, 3}); + auto d = cudaq::qec::get_decoder("single_error_lut", H); + ASSERT_NE(d, nullptr); + EXPECT_EQ(d->get_syndrome_size(), 2u); + EXPECT_EQ(d->get_block_size(), 3u); +} - ASSERT_NE(single_param_decoder, nullptr); +TEST(StimDemGetDecoder, RepeatedDetectorOrObservableTargetsXorFold) { + const std::string dem_text = R"(error(0.1) D0 D0 +error(0.1) L0 L0 +)"; + + auto dem = cudaq::qec::dem_from_stim_text(dem_text); + ASSERT_EQ(dem.num_detectors(), 1u); + ASSERT_EQ(dem.num_observables(), 1u); + ASSERT_EQ(dem.num_error_mechanisms(), 2u); + EXPECT_EQ(dem.detector_error_matrix.at({0u, 0u}), 0u) + << "duplicate D0 in error 0 should XOR-cancel to 0"; + EXPECT_EQ(dem.observables_flips_matrix.at({0u, 1u}), 0u) + << "duplicate L0 in error 1 should XOR-cancel to 0"; +} - // Verify the decoder works correctly - EXPECT_EQ(single_param_decoder->get_block_size(), block_size); - EXPECT_EQ(single_param_decoder->get_syndrome_size(), syndrome_size); +TEST(StimDemGetDecoder, DemWithoutObservablesDoesNotAddODefault) { + auto dem = cudaq::qec::dem_from_stim_text("error(0.1) D0\n"); + auto defaults = cudaq::qec::details::dem_defaults_for_missing_keys( + [](const std::string &) { return false; }, dem); - // Test with a syndrome decode to ensure functionality - std::vector syndrome(syndrome_size, 0.0f); - auto result = single_param_decoder->decode(syndrome); - EXPECT_EQ(result.result.size(), block_size); + EXPECT_EQ(defaults.O, nullptr); + ASSERT_NE(defaults.error_rate_vec, nullptr); + EXPECT_EQ(defaults.error_rate_vec->size(), 1u); +} - } catch (const std::runtime_error &e) { - // This is expected if "sample_decoder" is not registered in the - // single-parameter registry The test still passes because it verifies that - // line 18 creates a functional registry - EXPECT_TRUE(std::string(e.what()).find("Cannot find extension with name") != - std::string::npos); - } +TEST(StimDemGetDecoder, ThrowsOnProbabilityOutOfRange) { + const std::string dem_text = "error(1.5) D0\n"; + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text), + std::runtime_error); +} - // Test that we can check if extensions are registered in the single-parameter - // registry - auto registered_single = cudaqx::extension_point< - cudaq::qec::decoder, - const cudaq::qec::sparse_binary_matrix &>::get_registered(); +TEST(StimDemGetDecoder, ThrowsOnMalformedStimDem) { + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", "not a valid DEM"), + std::runtime_error); +} + +TEST(StimDemGetDecoder, ThrowsOnUnknownDecoderName) { + const std::string dem_text = "error(0.1) D0 L0\n"; + EXPECT_THROW(cudaq::qec::get_decoder("__no_such_decoder__", dem_text), + std::runtime_error); +} + +TEST(StimDemGetDecoder, ThrowsOnEmptyErrorMechanisms) { + const std::string dem_text = "detector(0, 0, 0)\n"; + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text), + std::runtime_error); +} + +TEST(StimDemGetDecoder, StimDemTargetCategoriesAreExhaustive) { + const std::vector samples = { + stim::DemTarget::separator(), + stim::DemTarget::relative_detector_id(0), + stim::DemTarget::relative_detector_id(42), + stim::DemTarget::observable_id(0), + stim::DemTarget::observable_id(7), + }; + for (const auto &t : samples) { + const int kinds = static_cast(t.is_separator()) + + static_cast(t.is_relative_detector_id()) + + static_cast(t.is_observable_id()); + EXPECT_EQ(kinds, 1) << "DemTarget " << t.str() << " matched " << kinds + << " predicates; expected exactly 1"; + } +} - // The registry should exist (even if empty), proving line 18 instantiation - // works This test passes if no exceptions are thrown, proving the - // single-parameter registry is instantiated +TEST(StimDemGetDecoder, UserOptionsAreNotOverwritten) { + const std::string dem_text = R"(error(0.1) D0 L0 +error(0.1) D1 L0 +error(0.05) D0 D1 +)"; + cudaqx::heterogeneous_map opts; + opts.insert("error_rate_vec", std::vector{0.5}); + EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text, opts), + std::runtime_error); }