diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 1a47b9b0..0a2f4c0c 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -144,6 +144,22 @@ class decoder std::unique_ptr pimpl; public: + /// @brief Indicates whether decode() returns a full error frame (length + /// block_size) or an already-projected observable frame (length + /// num_observables). Decoders that accept an "O" observable matrix in their + /// constructor params should call set_result_type(decode_to_obs); all others + /// default to decode_to_errs. + /// + /// Note: even in decode_to_obs mode, set_O_sparse() must still be called so + /// that enqueue_syndrome() knows num_observables and can size the corrections + /// buffer correctly. + enum decode_result_type { + decode_to_errs, ///< result.size() == block_size; enqueue_syndrome projects + ///< via O_sparse + decode_to_obs, ///< result.size() == num_observables; enqueue_syndrome uses + ///< result directly; set_O_sparse() still required + }; + decoder() = delete; /// @brief Constructor @@ -234,6 +250,12 @@ class decoder // Note: all of the current realtime decoding API is designed to be used with // hard syndromes. + /// @brief Returns the type of result produced by decode(). + /// Defaults to decode_to_errs. Decoders that project to observables + /// internally (i.e., constructed with an "O" param) should call + /// set_result_type(decode_to_obs) in their constructor. + decode_result_type get_result_type() const { return result_type_; } + /// @brief Get the number of measurement syndromes per decode call. This /// depends on D_sparse, so you must have called set_D_sparse() first. uint32_t get_num_msyn_per_decode() const; @@ -315,6 +337,11 @@ class decoder virtual std::string get_version() const; protected: + /// @brief Sets the result type. Call in the constructor when an "O" + /// observable matrix is detected in the decoder params. Must be called + /// before the first enqueue_syndrome(). + void set_result_type(decode_result_type type) { result_type_ = type; } + /// @brief For a classical `[n,k]` code, this is `n`. std::size_t block_size = 0; @@ -329,6 +356,9 @@ class decoder /// @brief The decoder's D matrix in sparse format std::vector> D_sparse; + +private: + decode_result_type result_type_ = decode_result_type::decode_to_errs; }; /// @brief Convert a single soft probability to a hard 0/1 decision. diff --git a/libs/qec/include/cudaq/qec/realtime/decoding_config.h b/libs/qec/include/cudaq/qec/realtime/decoding_config.h index ab50a9c8..d8687f8b 100644 --- a/libs/qec/include/cudaq/qec/realtime/decoding_config.h +++ b/libs/qec/include/cudaq/qec/realtime/decoding_config.h @@ -107,12 +107,22 @@ struct pymatching_config { from_heterogeneous_map(const cudaqx::heterogeneous_map &map); }; +// The realtime trt_decoder config currently models only PyMatching as a global +// decoder. Other global decoder plugins may be constructed through lower-level +// APIs when their parameters are supplied directly, but they are not serialized +// by this config variant yet. +using global_decoder_config = std::variant; + struct trt_decoder_config { std::optional onnx_load_path; std::optional engine_load_path; std::optional engine_save_path; std::optional precision; std::optional memory_workspace; + std::optional batch_size; + std::optional use_cuda_graph; + std::optional global_decoder; + global_decoder_config global_decoder_params; bool operator==(const trt_decoder_config &) const = default; diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 71fa4afb..6e95a776 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -292,6 +292,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, std::vector log_msyn; std::vector log_detectors; std::vector log_errors; + std::vector log_observables; std::vector log_observable_corrections; // The four time points are used to measure the duration of each of 3 steps. std::chrono::time_point log_t0, log_t1, @@ -305,6 +306,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, if (should_log) { log_t0 = std::chrono::high_resolution_clock::now(); log_errors.reserve(syndrome_length); + log_observables.reserve(O_sparse.size()); log_observable_corrections.resize(O_sparse.size()); } @@ -362,36 +364,76 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, return false; } } + // Process the results. + // TODO - should this interrogate the decoded_result.converged flag? + const auto result_type = get_result_type(); + const auto num_observables = get_num_observables(); + const char *result_type_str = nullptr; + const char *result_type_name = nullptr; + std::size_t expected_result_size = 0; + switch (result_type) { + case decode_result_type::decode_to_errs: + result_type_str = "errs"; + result_type_name = "decode_to_errs"; + expected_result_size = block_size; + break; + case decode_result_type::decode_to_obs: + result_type_str = "obs"; + result_type_name = "decode_to_obs"; + expected_result_size = num_observables; + break; + } + if (!result_type_name) + throw std::runtime_error( + fmt::format("Unsupported decoder result type ({})", + static_cast(result_type))); if ((!pimpl->is_sliding_window && - decoded_result.result.size() != block_size) || + decoded_result.result.size() != expected_result_size) || (pimpl->is_sliding_window && !decoded_result.result.empty() && - decoded_result.result.size() != block_size)) { - throw std::runtime_error( - fmt::format("Decoder result size ({}) does not match block_size ({})", - decoded_result.result.size(), block_size)); + decoded_result.result.size() != expected_result_size)) { + throw std::runtime_error(fmt::format( + "Decoder result size ({}) does not match expected size ({}) for " + "result type {}", + decoded_result.result.size(), expected_result_size, + result_type_name)); } - if (should_log) { + // Flip an observable correction and mirror it into the per-call log so the + // logged flips stay faithful to the applied corrections. + auto flip_correction = [&](std::size_t i) { + pimpl->corrections[i] ^= 1; + if (should_log) + log_observable_corrections[i] ^= 1; + }; + + if (should_log) log_t2 = std::chrono::high_resolution_clock::now(); - for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) - if (decoded_result.result[e]) - log_errors.push_back(e); - } - // Process the results. - // TODO - should this interrogate the decoded_result.converged flag? - auto num_observables = O_sparse.size(); - // For each observable - for (std::size_t i = 0; i < num_observables; i++) { - // For each error that flips this observable - for (auto col : O_sparse[i]) { - // If the decoder predicted that this error occurred - if (decoded_result.result[col]) { - // Flip the correction for this observable - pimpl->corrections[i] ^= 1; + + switch (result_type) { + case decode_result_type::decode_to_obs: + // Observable-frame path: decoder already projected to observables via its + // internal "O" matrix; use the result directly. + for (std::size_t i = 0; i < num_observables; i++) + if (decoded_result.result[i]) { if (should_log) - log_observable_corrections[i] ^= 1; + log_observables.push_back(i); + flip_correction(i); } - } + break; + case decode_result_type::decode_to_errs: + // Error-frame path: decoder returns a block-sized error vector; project + // to observables via O_sparse. + if (should_log) + for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) + if (decoded_result.result[e]) + log_errors.push_back(e); + // For each observable, flip its correction once for each predicted error + // that flips it (net parity over O_sparse[i]). + for (std::size_t i = 0; i < num_observables; i++) + for (auto col : O_sparse[i]) + if (decoded_result.result[col]) + flip_correction(i); + break; } if (should_log) { log_t3 = std::chrono::high_resolution_clock::now(); @@ -401,13 +443,15 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, pimpl->log_counter++; auto s = fmt::format( "[DecoderStats][{}] Counter:{} DecoderId:{} InputMsyn:{} " - "InputDetectors:{} Converged:{} Errors:{} " + "InputDetectors:{} Converged:{} ResultType:{} Errors:{} " + "Observables:{} " "ObservableCorrectionsThisCall:{} ObservableCorrectionsTotal:{} " "Dur1:{:.1f}us Dur2:{:.1f}us Dur3:{:.1f}us", static_cast(this), pimpl->log_counter, pimpl->decoder_id, fmt::join(log_msyn, ","), fmt::join(log_detectors, ","), decoded_result.converged ? 1 : 0, - fmt::join(log_errors, ","), + result_type_str, fmt::join(log_errors, ","), + fmt::join(log_observables, ","), fmt::join(log_observable_corrections, ","), fmt::join(std::vector(pimpl->corrections.begin(), pimpl->corrections.end()), diff --git a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp index 614fcd98..a3162993 100644 --- a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp +++ b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp @@ -782,6 +782,7 @@ trt_decoder::trt_decoder(const cudaq::qec::sparse_binary_matrix &H, } decode_to_observables_ = true; num_observables_ = O.shape()[0]; + set_result_type(decode_result_type::decode_to_obs); // The TRT model output must encode [pre_L (num_observables_ entries), // residual_dets (rest)]. Validate sizing where we can. diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index eca0269a..c8737d7c 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -190,6 +190,31 @@ single_error_lut_config single_error_lut_config::from_heterogeneous_map( return config; } +cudaqx::heterogeneous_map global_decoder_config_to_heterogeneous_map( + const global_decoder_config &global_decoder_params) { + if (std::holds_alternative(global_decoder_params)) { + return cudaqx::heterogeneous_map(); + } + + if (std::holds_alternative(global_decoder_params)) { + return std::get(global_decoder_params) + .to_heterogeneous_map(); + } + + throw std::runtime_error("Unsupported global decoder parameters."); +} + +global_decoder_config global_decoder_config_from_heterogeneous_map( + const cudaqx::heterogeneous_map &map, + const std::optional &global_decoder) { + if (global_decoder.has_value() && global_decoder.value() != "pymatching") { + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + } + + return pymatching_config::from_heterogeneous_map(map); +} + // ------ pymatching_config ------ cudaqx::heterogeneous_map pymatching_config::to_heterogeneous_map() const { cudaqx::heterogeneous_map config_map; @@ -217,6 +242,27 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { INSERT_ARG(engine_save_path); INSERT_ARG(precision); INSERT_ARG(memory_workspace); + INSERT_ARG(batch_size); + INSERT_ARG(use_cuda_graph); + INSERT_ARG(global_decoder); + if (!std::holds_alternative(global_decoder_params)) { + if (!global_decoder.has_value()) { + throw std::runtime_error( + "global_decoder_params present but global_decoder is not set."); + } + if (global_decoder.value() != "pymatching") { + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + } + config_map.insert( + "global_decoder_params", + global_decoder_config_to_heterogeneous_map(global_decoder_params)); + } + // Note: when global_decoder_params is monostate we intentionally emit + // nothing, even if global_decoder is set. Inventing an empty params map here + // would round-trip back as a default pymatching_config, mutating the config. + // Any runtime need for an empty params map is handled in + // prepare_decoder_params (realtime_decoding.cpp), not in serialization. return config_map; } @@ -229,6 +275,32 @@ trt_decoder_config trt_decoder_config::from_heterogeneous_map( GET_ARG(engine_save_path); GET_ARG(precision); GET_ARG(memory_workspace); + GET_ARG(batch_size); + GET_ARG(use_cuda_graph); + GET_ARG(global_decoder); + if (map.contains("global_decoder_params")) { + if (!config.global_decoder.has_value()) + throw std::runtime_error( + "global_decoder_params present but global_decoder is not set."); + if (config.global_decoder.value() != "pymatching") + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + try { + config.global_decoder_params = + map.get("global_decoder_params"); + } catch (...) { + try { + config.global_decoder_params = + map.get("global_decoder_params"); + } catch (...) { + auto nested_map = + map.get("global_decoder_params"); + config.global_decoder_params = + global_decoder_config_from_heterogeneous_map(nested_map, + config.global_decoder); + } + } + } return config; } @@ -363,6 +435,30 @@ struct MappingTraits { cudaq::qec::decoding::config::single_error_lut_config &config) {} }; +template <> +struct MappingTraits { + static void + mapping(IO &io, cudaq::qec::decoding::config::global_decoder_config &config) { + using namespace cudaq::qec::decoding::config; + + if (io.outputting()) { + if (std::holds_alternative(config)) { + return; + } + + auto ¶ms = std::get(config); + io.mapOptional("merge_strategy", params.merge_strategy); + io.mapOptional("error_rate_vec", params.error_rate_vec); + return; + } + + pymatching_config params; + io.mapOptional("merge_strategy", params.merge_strategy); + io.mapOptional("error_rate_vec", params.error_rate_vec); + config = std::move(params); + } +}; + template <> struct MappingTraits { static void mapping(IO &io, @@ -381,6 +477,26 @@ struct MappingTraits { io.mapOptional("engine_save_path", config.engine_save_path); io.mapOptional("precision", config.precision); io.mapOptional("memory_workspace", config.memory_workspace); + io.mapOptional("batch_size", config.batch_size); + io.mapOptional("use_cuda_graph", config.use_cuda_graph); + io.mapOptional("global_decoder", config.global_decoder); + // Emit global_decoder_params only when it actually holds params. Mapping it + // unconditionally on output writes an empty `global_decoder_params: {}` for + // the monostate case, which deserializes back into a default + // pymatching_config -- mutating a monostate config across a YAML + // round-trip. On input we always map it: an absent key leaves the variant + // at its monostate default (mapOptional skips the nested mapping), and a + // present key is parsed into pymatching_config. + if (!io.outputting() || + !std::holds_alternative(config.global_decoder_params)) + io.mapOptional("global_decoder_params", config.global_decoder_params); + + if (!std::holds_alternative(config.global_decoder_params) && + config.global_decoder.has_value() && + config.global_decoder.value() != "pymatching") { + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + } } }; diff --git a/libs/qec/lib/realtime/realtime_decoding.cpp b/libs/qec/lib/realtime/realtime_decoding.cpp index 0a3edb7a..2e1465dc 100644 --- a/libs/qec/lib/realtime/realtime_decoding.cpp +++ b/libs/qec/lib/realtime/realtime_decoding.cpp @@ -12,6 +12,7 @@ #include "cudaq/qec/pcm_utils.h" #include "cudaq/qec/realtime/decoding_config.h" #include "cudaq/runtime/logger/logger.h" +#include #include #include #include @@ -179,6 +180,45 @@ static std::vector pack_syndrome_bits(const uint8_t *syndromes, namespace cudaq::qec::decoding::host { +cudaqx::heterogeneous_map prepare_decoder_params( + const cudaq::qec::decoding::config::decoder_config &decoder_config) { + auto params = decoder_config.decoder_custom_args_to_heterogeneous_map(); + if (decoder_config.type != "trt_decoder") + return params; + + // The trt_decoder plugin attaches a pymatching global decoder only when both + // "global_decoder" and "global_decoder_params" are present. Serialization no + // longer emits an empty params map for the monostate (no-params) case, so + // synthesize one here -- before the O_sparse early return -- so that a global + // decoder running on residual detectors without an O matrix still attaches. + const bool has_pymatching_global = + params.contains("global_decoder") && + params.get("global_decoder") == "pymatching"; + if (has_pymatching_global && !params.contains("global_decoder_params")) + params.insert("global_decoder_params", cudaqx::heterogeneous_map()); + + if (decoder_config.O_sparse.empty()) + return params; + + const auto num_observables = std::count(decoder_config.O_sparse.begin(), + decoder_config.O_sparse.end(), -1); + if (num_observables == 0) + return params; + + auto O = cudaq::qec::pcm_from_sparse_vec( + decoder_config.O_sparse, num_observables, decoder_config.block_size); + params.insert("O", O); + + if (has_pymatching_global) { + auto global_decoder_params = + params.get("global_decoder_params"); + global_decoder_params.insert("O", O); + params.insert("global_decoder_params", global_decoder_params); + } + + return params; +} + cudaq::qec::realtime::qec_realtime_session *get_realtime_session() { return g_realtime_session.get(); } @@ -248,8 +288,7 @@ int configure_decoders( decoder_config.syndrome_size, decoder_config.block_size); auto new_decoder = cudaq::qec::get_decoder( - decoder_config.type, pcm, - decoder_config.decoder_custom_args_to_heterogeneous_map()); + decoder_config.type, pcm, prepare_decoder_params(decoder_config)); new_decoder->set_decoder_id(decoder_config.id); // Count the number of -1's in the O_sparse vector. That is the number of // rows (observables) in the observable matrix. diff --git a/libs/qec/lib/realtime/realtime_decoding.h b/libs/qec/lib/realtime/realtime_decoding.h index 233574f2..00d86662 100644 --- a/libs/qec/lib/realtime/realtime_decoding.h +++ b/libs/qec/lib/realtime/realtime_decoding.h @@ -30,6 +30,10 @@ __attribute__((visibility("default"))) void enqueue_syndromes(std::size_t decoder_id, uint8_t *syndromes, std::uint64_t syndrome_length, std::uint64_t tag); +__attribute__((visibility("default"))) cudaqx::heterogeneous_map +prepare_decoder_params( + const cudaq::qec::decoding::config::decoder_config &decoder_config); + __attribute__((visibility("default"))) void get_corrections(std::size_t decoder_id, uint8_t *corrections, std::uint64_t correction_length, bool reset); diff --git a/libs/qec/python/bindings/py_decoding_config.cpp b/libs/qec/python/bindings/py_decoding_config.cpp index b23aabd3..3bc3c454 100644 --- a/libs/qec/python/bindings/py_decoding_config.cpp +++ b/libs/qec/python/bindings/py_decoding_config.cpp @@ -29,6 +29,9 @@ void bindDecodingConfig(nb::module_ &mod) { auto mod_cfg = qecmod.def_submodule("config", "Realtime decoding configuration"); + // Allow Python None to clear std::optional fields. + const auto setter_accepts_none = nb::for_setter(nb::arg("value").none()); + // srelay_bp_config nb::class_(mod_cfg, "srelay_bp_config", "Relay-BP decoder configuration.") @@ -120,11 +123,38 @@ void bindDecodingConfig(nb::module_ &mod) { trt_decoder_config::from_heterogeneous_map(map)); }, nb::arg("map")) - .def_rw("onnx_load_path", &trt_decoder_config::onnx_load_path) - .def_rw("engine_load_path", &trt_decoder_config::engine_load_path) - .def_rw("engine_save_path", &trt_decoder_config::engine_save_path) - .def_rw("precision", &trt_decoder_config::precision) - .def_rw("memory_workspace", &trt_decoder_config::memory_workspace) + .def_rw("onnx_load_path", &trt_decoder_config::onnx_load_path, + setter_accepts_none) + .def_rw("engine_load_path", &trt_decoder_config::engine_load_path, + setter_accepts_none) + .def_rw("engine_save_path", &trt_decoder_config::engine_save_path, + setter_accepts_none) + .def_rw("precision", &trt_decoder_config::precision, setter_accepts_none) + .def_rw("memory_workspace", &trt_decoder_config::memory_workspace, + setter_accepts_none) + .def_rw("batch_size", &trt_decoder_config::batch_size, + setter_accepts_none) + .def_rw("use_cuda_graph", &trt_decoder_config::use_cuda_graph, + setter_accepts_none) + .def_rw("global_decoder", &trt_decoder_config::global_decoder, + setter_accepts_none) + .def_prop_rw( + "global_decoder_params", + [](const trt_decoder_config &self) + -> std::optional { + if (std::holds_alternative( + self.global_decoder_params)) { + return std::get(self.global_decoder_params); + } + return std::nullopt; + }, + [](trt_decoder_config &self, std::optional value) { + if (value.has_value()) { + self.global_decoder_params = value.value(); + } else { + self.global_decoder_params = std::monostate(); + } + }) .def("to_heterogeneous_map", &trt_decoder_config::to_heterogeneous_map, nb::rv_policy::move) .def_static("from_heterogeneous_map", diff --git a/libs/qec/python/bindings/type_casters.h b/libs/qec/python/bindings/type_casters.h index ea43df08..7de6d454 100644 --- a/libs/qec/python/bindings/type_casters.h +++ b/libs/qec/python/bindings/type_casters.h @@ -194,6 +194,17 @@ struct type_caster { cudaq::qec::decoding::config::single_error_lut_config>( &val)) { result[key.c_str()] = nb::cast(single_cfg->to_heterogeneous_map()); + } else if (auto *global_cfg = std::any_cast< + cudaq::qec::decoding::config::global_decoder_config>( + &val)) { + if (std::holds_alternative(*global_cfg)) { + result[key.c_str()] = nb::none(); + } else { + result[key.c_str()] = nb::cast( + std::get( + *global_cfg) + .to_heterogeneous_map()); + } } else if (auto *pymatching_cfg = std::any_cast< cudaq::qec::decoding::config::pymatching_config>(&val)) { result[key.c_str()] = diff --git a/libs/qec/unittests/decoders/sample_decoder.cpp b/libs/qec/unittests/decoders/sample_decoder.cpp index da953411..0d357b5d 100644 --- a/libs/qec/unittests/decoders/sample_decoder.cpp +++ b/libs/qec/unittests/decoders/sample_decoder.cpp @@ -16,18 +16,25 @@ namespace cudaq::qec { /// @brief This is a sample (dummy) decoder that demonstrates how to build a /// bare bones custom decoder based on the `cudaq::qec::decoder` interface. class sample_decoder : public decoder { +private: + bool decode_to_obs = false; + public: sample_decoder(const cudaq::qec::sparse_binary_matrix &H, const cudaqx::heterogeneous_map ¶ms) : decoder(H) { // Decoder-specific constructor arguments can be placed in `params`. + decode_to_obs = params.get("decode_to_obs", decode_to_obs); + if (decode_to_obs) + set_result_type(decode_result_type::decode_to_obs); } virtual decoder_result decode(const std::vector &syndrome) { // This is a simple decoder that simply results decoder_result result; result.converged = true; - result.result = std::vector(block_size, 0.0f); + result.result = + decode_to_obs ? syndrome : std::vector(block_size, 0.0f); return result; } diff --git a/libs/qec/unittests/realtime/CMakeLists.txt b/libs/qec/unittests/realtime/CMakeLists.txt index 68972ecc..8230421f 100644 --- a/libs/qec/unittests/realtime/CMakeLists.txt +++ b/libs/qec/unittests/realtime/CMakeLists.txt @@ -119,6 +119,7 @@ if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY AND TENSORRT_ONNX_PARSER_LIBRARY target_link_libraries(test_trt_decoder_composite PRIVATE CUDA::cudart cudaq-qec + cudaq-qec-realtime-decoding cudaq-qec-trt-decoder cudaq::cudaq ) diff --git a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp index 79457b6f..5a94035a 100644 --- a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp +++ b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp @@ -25,9 +25,16 @@ * --data-dir DIR [--max-samples=N] [--onnx-path=FILE] * [--engine-save-path=FILE] [--batch-size=N] [--warmup=N] * [--no-cuda-graph] [--no-raw-diagnostics] + * + * test_trt_decoder_composite --data-dir DIR --config-yaml FILE + * [--decoder-id=N] [--max-samples=N] [--warmup=N] + * [--no-raw-diagnostics] ******************************************************************************/ #include "predecoder_pipeline_common.h" +#include "../../lib/realtime/realtime_decoding.h" +#include "cudaq/qec/pcm_utils.h" +#include "cudaq/qec/realtime/decoding_config.h" #include #include @@ -36,6 +43,8 @@ #include #include #include +#include +#include #include #include #include @@ -48,8 +57,10 @@ using hrclock = std::chrono::high_resolution_clock; struct DemoConfig { std::string data_dir; + std::string config_yaml_path; std::string onnx_path; std::string engine_save_path; + int64_t decoder_id = -1; int max_samples = 0; int warmup_count = 20; size_t batch_size = 1; @@ -99,6 +110,9 @@ void print_usage(const char *argv0) { "detectors/observables/H/O\n" << " --max-samples=N Limit samples decoded (0 = all)\n" << " --warmup=N Samples excluded from latency stats\n" + << " --config-yaml=FILE Build composite decoder from YAML " + "config\n" + << " --decoder-id=N Decoder ID to select from YAML config\n" << " --onnx-path=FILE Override full ONNX path\n" << " --engine-save-path=FILE Where the built TRT engine is saved\n" << " --batch-size=N TRT dynamic batch profile size (default " @@ -127,6 +141,14 @@ DemoConfig parse_demo_config(int argc, char *argv[]) { cfg.warmup_count = std::stoi(value_after_equals(arg, "--warmup=")); } else if (arg == "--warmup" && i + 1 < argc) { cfg.warmup_count = std::stoi(argv[++i]); + } else if (starts_with(arg, "--config-yaml=")) { + cfg.config_yaml_path = value_after_equals(arg, "--config-yaml="); + } else if (arg == "--config-yaml" && i + 1 < argc) { + cfg.config_yaml_path = argv[++i]; + } else if (starts_with(arg, "--decoder-id=")) { + cfg.decoder_id = std::stoll(value_after_equals(arg, "--decoder-id=")); + } else if (arg == "--decoder-id" && i + 1 < argc) { + cfg.decoder_id = std::stoll(argv[++i]); } else if (starts_with(arg, "--onnx-path=")) { cfg.onnx_path = value_after_equals(arg, "--onnx-path="); } else if (arg == "--onnx-path" && i + 1 < argc) { @@ -152,6 +174,59 @@ DemoConfig parse_demo_config(int argc, char *argv[]) { return cfg; } +std::string read_text_file(const std::string &path) { + std::ifstream in(path); + if (!in.good()) + throw std::runtime_error("Failed to open " + path); + std::ostringstream buffer; + buffer << in.rdbuf(); + return buffer.str(); +} + +size_t sparse_vec_nnz(const std::vector &sparse) { + return static_cast(std::count_if(sparse.begin(), sparse.end(), + [](int64_t v) { return v >= 0; })); +} + +size_t sparse_vec_rows(const std::vector &sparse) { + return static_cast(std::count(sparse.begin(), sparse.end(), -1)); +} + +template +void copy_param_if_present(const cudaqx::heterogeneous_map &src, + cudaqx::heterogeneous_map &dst, + const std::string &key) { + if (src.contains(key)) + dst.insert(key, src.get(key)); +} + +bool build_raw_trt_params(const cudaqx::heterogeneous_map &trt_params, + cudaqx::heterogeneous_map &raw_params) { + bool has_model_source = false; + if (trt_params.contains("engine_load_path")) { + raw_params.insert("engine_load_path", + trt_params.get("engine_load_path")); + has_model_source = true; + } else if (trt_params.contains("engine_save_path") && + file_exists(trt_params.get("engine_save_path"))) { + raw_params.insert("engine_load_path", + trt_params.get("engine_save_path")); + has_model_source = true; + } else if (trt_params.contains("onnx_load_path")) { + raw_params.insert("onnx_load_path", + trt_params.get("onnx_load_path")); + copy_param_if_present(trt_params, raw_params, + "engine_save_path"); + has_model_source = true; + } + + copy_param_if_present(trt_params, raw_params, "batch_size"); + copy_param_if_present(trt_params, raw_params, "use_cuda_graph"); + copy_param_if_present(trt_params, raw_params, "memory_workspace"); + copy_param_if_present(trt_params, raw_params, "precision"); + return has_model_source; +} + std::vector sample_to_syndrome(const TestData &data, int sample_idx) { std::vector syndrome(data.num_detectors); @@ -205,6 +280,23 @@ struct RawDiagnostics { int64_t total_pymatch_frame = 0; }; +struct DecoderSetup { + std::unique_ptr decoder; + cudaqx::tensor H; + cudaqx::heterogeneous_map trt_params; + std::string label; + std::string init_mode; + std::string config_yaml_path; + std::string onnx_path; + std::string engine_save_path; + size_t H_rows = 0; + size_t H_cols = 0; + size_t H_nnz = 0; + size_t O_rows = 0; + size_t O_cols = 0; + size_t O_nnz = 0; +}; + CompositeStats run_composite_decoder(cudaq::qec::decoder &decoder, const TestData &test_data, int n_samples, size_t num_observables) { @@ -261,6 +353,122 @@ CompositeStats run_composite_decoder(cudaq::qec::decoder &decoder, return stats; } +const cudaq::qec::decoding::config::decoder_config &select_yaml_decoder( + const cudaq::qec::decoding::config::multi_decoder_config &config, + int64_t decoder_id) { + if (decoder_id >= 0) { + auto it = std::find_if(config.decoders.begin(), config.decoders.end(), + [&](const auto &decoder_config) { + return decoder_config.id == decoder_id; + }); + if (it == config.decoders.end()) + throw std::runtime_error("Decoder ID " + std::to_string(decoder_id) + + " not found in YAML config."); + return *it; + } + + if (config.decoders.size() != 1) + throw std::runtime_error("YAML config contains " + + std::to_string(config.decoders.size()) + + " decoders; pass --decoder-id to select one."); + return config.decoders.front(); +} + +DecoderSetup create_decoder_from_yaml(const DemoConfig &demo_cfg) { + using cudaq::qec::decoding::config::multi_decoder_config; + + auto config = multi_decoder_config::from_yaml_str( + read_text_file(demo_cfg.config_yaml_path)); + const auto &decoder_config = select_yaml_decoder(config, demo_cfg.decoder_id); + if (decoder_config.type != "trt_decoder") { + throw std::runtime_error("YAML decoder type must be trt_decoder, got '" + + decoder_config.type + "'."); + } + + DecoderSetup setup; + setup.label = "yaml decoder " + std::to_string(decoder_config.id); + setup.init_mode = "YAML config"; + setup.config_yaml_path = demo_cfg.config_yaml_path; + setup.H_rows = static_cast(decoder_config.syndrome_size); + setup.H_cols = static_cast(decoder_config.block_size); + setup.H_nnz = sparse_vec_nnz(decoder_config.H_sparse); + setup.O_rows = sparse_vec_rows(decoder_config.O_sparse); + setup.O_cols = static_cast(decoder_config.block_size); + setup.O_nnz = sparse_vec_nnz(decoder_config.O_sparse); + + setup.H = cudaq::qec::pcm_from_sparse_vec(decoder_config.H_sparse, + decoder_config.syndrome_size, + decoder_config.block_size); + setup.trt_params = + cudaq::qec::decoding::host::prepare_decoder_params(decoder_config); + if (setup.trt_params.contains("onnx_load_path")) + setup.onnx_path = setup.trt_params.get("onnx_load_path"); + if (setup.trt_params.contains("engine_save_path")) + setup.engine_save_path = + setup.trt_params.get("engine_save_path"); + if (setup.trt_params.contains("engine_load_path") && + setup.engine_save_path.empty()) + setup.engine_save_path = + setup.trt_params.get("engine_load_path"); + + setup.decoder = + cudaq::qec::decoder::get(decoder_config.type, setup.H, setup.trt_params); + return setup; +} + +DecoderSetup create_decoder_from_cli(const PipelineConfig &config, + const DemoConfig &demo_cfg, + const StimData &stim) { + std::string onnx_path = + demo_cfg.onnx_path.empty() ? config.onnx_path() : demo_cfg.onnx_path; + std::string engine_save_path = demo_cfg.engine_save_path.empty() + ? replace_extension(onnx_path, ".engine") + : demo_cfg.engine_save_path; + + if (!file_exists(onnx_path)) + throw std::runtime_error("ONNX file not found: " + onnx_path); + + auto H = stim.H.to_dense(); + auto O = stim.O.to_dense(); + + cudaqx::heterogeneous_map pm_params; + pm_params.insert("merge_strategy", std::string("smallest_weight")); + pm_params.insert("O", O); + if (!stim.priors.empty()) { + if (stim.priors.size() != stim.H.ncols) { + throw std::runtime_error( + "priors.bin has " + std::to_string(stim.priors.size()) + + " entries, but H has " + std::to_string(stim.H.ncols) + " columns."); + } + pm_params.insert("error_rate_vec", stim.priors); + } + + DecoderSetup setup; + setup.H = H; + setup.label = config.label; + setup.init_mode = "manual CLI args"; + setup.onnx_path = onnx_path; + setup.engine_save_path = engine_save_path; + setup.H_rows = stim.H.nrows; + setup.H_cols = stim.H.ncols; + setup.H_nnz = stim.H.nnz; + setup.O_rows = stim.O.nrows; + setup.O_cols = stim.O.ncols; + setup.O_nnz = stim.O.nnz; + + setup.trt_params.insert("onnx_load_path", onnx_path); + setup.trt_params.insert("engine_save_path", engine_save_path); + setup.trt_params.insert("batch_size", demo_cfg.batch_size); + setup.trt_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); + setup.trt_params.insert("global_decoder", std::string("pymatching")); + setup.trt_params.insert("global_decoder_params", pm_params); + setup.trt_params.insert("O", O); + + setup.decoder = + cudaq::qec::decoder::get("trt_decoder", setup.H, setup.trt_params); + return setup; +} + RawDiagnostics run_raw_diagnostics(cudaq::qec::decoder &raw_decoder, const TestData &test_data, const std::vector &final_pred, @@ -337,17 +545,6 @@ int main(int argc, char *argv[]) { return 1; } - std::string onnx_path = - demo_cfg.onnx_path.empty() ? config.onnx_path() : demo_cfg.onnx_path; - std::string engine_save_path = demo_cfg.engine_save_path.empty() - ? replace_extension(onnx_path, ".engine") - : demo_cfg.engine_save_path; - - if (!file_exists(onnx_path)) { - std::cerr << "ERROR: ONNX file not found: " << onnx_path << "\n"; - return 1; - } - TestData test_data = load_test_data(demo_cfg.data_dir); if (!test_data.loaded()) { std::cerr << "ERROR: failed to load detector/observable test data from " @@ -355,74 +552,82 @@ int main(int argc, char *argv[]) { return 1; } - StimData stim = load_stim_data(demo_cfg.data_dir); - if (!stim.H.loaded()) { - std::cerr << "ERROR: H_csr.bin is required in " << demo_cfg.data_dir + const bool use_yaml_config = !demo_cfg.config_yaml_path.empty(); + std::optional stim; + if (!use_yaml_config) { + stim = load_stim_data(demo_cfg.data_dir); + if (!stim->H.loaded()) { + std::cerr << "ERROR: H_csr.bin is required in " << demo_cfg.data_dir + << "\n"; + return 1; + } + if (!stim->O.loaded()) { + std::cerr << "ERROR: O_csr.bin is required in " << demo_cfg.data_dir + << "\n"; + return 1; + } + if (stim->O.nrows == 0) { + std::cerr << "ERROR: O_csr.bin contains zero observables.\n"; + return 1; + } + } + + DecoderSetup setup; + try { + std::cout << "--- Initializing Composite TensorRT Decoder (" + << (use_yaml_config ? demo_cfg.config_yaml_path : config.label) + << ") ---\n"; + setup = use_yaml_config ? create_decoder_from_yaml(demo_cfg) + : create_decoder_from_cli(config, demo_cfg, *stim); + } catch (const std::exception &e) { + std::cerr << "ERROR: failed to create composite trt_decoder: " << e.what() << "\n"; return 1; } - if (!stim.O.loaded()) { - std::cerr << "ERROR: O_csr.bin is required in " << demo_cfg.data_dir - << "\n"; + + if (setup.O_rows == 0) { + std::cerr << "ERROR: observable matrix contains zero observables.\n"; return 1; } - if (stim.O.nrows == 0) { - std::cerr << "ERROR: O_csr.bin contains zero observables.\n"; + if (test_data.num_detectors != setup.H_rows) { + std::cerr << "ERROR: detectors.bin has " << test_data.num_detectors + << " detectors, but decoder H has " << setup.H_rows << " rows.\n"; return 1; } - if (test_data.num_observables < stim.O.nrows) { + if (test_data.num_observables < setup.O_rows) { std::cerr << "ERROR: observables.bin has " << test_data.num_observables - << " observable column(s), but O_csr.bin has " << stim.O.nrows + << " observable column(s), but decoder O has " << setup.O_rows << " row(s).\n"; return 1; } - - auto H = stim.H.to_dense(); - auto O = stim.O.to_dense(); - - cudaqx::heterogeneous_map pm_params; - pm_params.insert("merge_strategy", std::string("smallest_weight")); - pm_params.insert("O", O); - if (!stim.priors.empty()) { - if (stim.priors.size() != stim.H.ncols) { - std::cerr << "ERROR: priors.bin has " << stim.priors.size() - << " entries, but H has " << stim.H.ncols << " columns.\n"; - return 1; - } - pm_params.insert("error_rate_vec", stim.priors); + if (setup.decoder->get_result_type() != + cudaq::qec::decoder::decode_result_type::decode_to_obs) { + std::cerr << "ERROR: composite trt_decoder must report decode_to_obs " + "when constructed with O.\n"; + return 1; } - cudaqx::heterogeneous_map trt_params; - trt_params.insert("onnx_load_path", onnx_path); - trt_params.insert("engine_save_path", engine_save_path); - trt_params.insert("batch_size", demo_cfg.batch_size); - trt_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); - trt_params.insert("global_decoder", std::string("pymatching")); - trt_params.insert("global_decoder_params", pm_params); - trt_params.insert("O", O); - - std::cout << "--- Initializing Composite TensorRT Decoder (" << config.label - << ") ---\n"; - std::cout << "[Setup] ONNX: " << onnx_path << "\n"; - std::cout << "[Setup] Engine save: " << engine_save_path << "\n"; + std::cout << "[Setup] Init mode: " << setup.init_mode << "\n"; + if (!setup.config_yaml_path.empty()) + std::cout << "[Setup] YAML: " << setup.config_yaml_path << "\n"; + if (!setup.onnx_path.empty()) + std::cout << "[Setup] ONNX: " << setup.onnx_path << "\n"; + if (!setup.engine_save_path.empty()) + std::cout << "[Setup] Engine: " << setup.engine_save_path << "\n"; std::cout << "[Setup] Data dir: " << demo_cfg.data_dir << "\n"; - std::cout << "[Setup] H: " << stim.H.nrows << " x " << stim.H.ncols - << " (" << stim.H.nnz << " nnz)\n"; - std::cout << "[Setup] O: " << stim.O.nrows << " x " << stim.O.ncols - << " (" << stim.O.nnz << " nnz)\n"; + std::cout << "[Setup] H: " << setup.H_rows << " x " << setup.H_cols + << " (" << setup.H_nnz << " nnz)\n"; + std::cout << "[Setup] O: " << setup.O_rows << " x " << setup.O_cols + << " (" << setup.O_nnz << " nnz)\n"; std::cout << "[Setup] Samples: " << test_data.num_samples << ", detectors/sample=" << test_data.num_detectors << ", observables/sample=" << test_data.num_observables << "\n"; - std::cout << "[Setup] PyMatching: merge_strategy=smallest_weight" - << (stim.priors.empty() ? ", no priors\n" : ", priors loaded\n"); - - std::unique_ptr composite_decoder; - try { - composite_decoder = cudaq::qec::decoder::get("trt_decoder", H, trt_params); - } catch (const std::exception &e) { - std::cerr << "ERROR: failed to create composite trt_decoder: " << e.what() - << "\n"; - return 1; + std::cout << "[Setup] PyMatching: "; + if (use_yaml_config) { + std::cout << "from YAML global_decoder_params\n"; + } else { + std::cout << "merge_strategy=smallest_weight" + << (stim->priors.empty() ? ", no priors\n" : ", priors loaded\n"); } const int available_samples = static_cast(test_data.num_samples); @@ -437,30 +642,31 @@ int main(int argc, char *argv[]) { std::cout << "[Run] Decoding " << n_samples << " sample(s) through composite TRT+PyMatching decoder...\n"; - CompositeStats stats = run_composite_decoder(*composite_decoder, test_data, - n_samples, stim.O.nrows); + CompositeStats stats = + run_composite_decoder(*setup.decoder, test_data, n_samples, setup.O_rows); RawDiagnostics raw_stats; if (demo_cfg.raw_diagnostics) { cudaqx::heterogeneous_map raw_params; - if (file_exists(engine_save_path)) { - raw_params.insert("engine_load_path", engine_save_path); + if (!build_raw_trt_params(setup.trt_params, raw_params)) { + std::cerr << "[WARN] Raw TRT diagnostics skipped: no raw TRT model " + "source is available.\n"; } else { - std::cerr << "[WARN] Engine file was not found after composite init; " - "raw diagnostics will rebuild from ONNX.\n"; - raw_params.insert("onnx_load_path", onnx_path); - raw_params.insert("engine_save_path", engine_save_path); - } - raw_params.insert("batch_size", demo_cfg.batch_size); - raw_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); - - try { - auto raw_decoder = cudaq::qec::decoder::get("trt_decoder", H, raw_params); - raw_stats = - run_raw_diagnostics(*raw_decoder, test_data, stats.first_obs_pred, - n_samples, stim.O.nrows, stim.H.nrows); - } catch (const std::exception &e) { - std::cerr << "[WARN] Raw TRT diagnostics skipped: " << e.what() << "\n"; + if (setup.trt_params.contains("engine_save_path") && + !file_exists(setup.trt_params.get("engine_save_path"))) { + std::cerr << "[WARN] Engine file was not found after composite init; " + "raw diagnostics will rebuild from ONNX.\n"; + } + + try { + auto raw_decoder = + cudaq::qec::decoder::get("trt_decoder", setup.H, raw_params); + raw_stats = + run_raw_diagnostics(*raw_decoder, test_data, stats.first_obs_pred, + n_samples, setup.O_rows, setup.H_rows); + } catch (const std::exception &e) { + std::cerr << "[WARN] Raw TRT diagnostics skipped: " << e.what() << "\n"; + } } } @@ -495,7 +701,7 @@ int main(int argc, char *argv[]) { std::cout << std::fixed; std::cout << "\n================================================================\n"; - std::cout << " Composite TRT Decoder Benchmark: " << config.label << "\n"; + std::cout << " Composite TRT Decoder Benchmark: " << setup.label << "\n"; std::cout << "================================================================\n"; std::cout << " Submitted: " << n_samples << "\n"; @@ -549,7 +755,7 @@ int main(int argc, char *argv[]) { static_cast(raw_stats.total_residual_nonzero) / static_cast(raw_stats.decoded); double input_density = avg_input_nz / test_data.num_detectors; - double residual_density = avg_residual_nz / stim.H.nrows; + double residual_density = avg_residual_nz / setup.H_rows; double reduction = input_density > 0.0 ? (1.0 - residual_density / input_density) : 0.0; @@ -575,7 +781,7 @@ int main(int argc, char *argv[]) { << input_density << ")\n"; std::cout << std::setprecision(1); std::cout << " Residual density: " << avg_residual_nz << " / " - << stim.H.nrows << " (" << std::setprecision(4) + << setup.H_rows << " (" << std::setprecision(4) << residual_density << ")\n"; std::cout << std::setprecision(1); std::cout << " Reduction: " << reduction * 100.0 << "%\n"; diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 672c4e4b..2ef4cbaa 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -165,7 +165,7 @@ TEST(DecoderPlugins, SingleErrorLutExample_DecodesSingletonColumnSyndromes) { constexpr std::size_t block_size = 3; constexpr std::size_t syndrome_size = 2; // | 1 1 0 | - // | 0 1 1 | — single-bit columns are weight-1 syndrome patterns. + // | 0 1 1 | - single-bit columns are weight-1 syndrome patterns. std::vector H_vec = {1, 1, 0, // row 0 0, 1, 1}; cudaqx::tensor H; @@ -878,3 +878,83 @@ error(0.05) D0 D1 EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text, opts), std::runtime_error); } + +// --------------------------------------------------------------------------- +// Tests for enqueue_syndrome decode_result_type routing +// --------------------------------------------------------------------------- + +// Verify that enqueue_syndrome uses decode() output directly as corrections +// when get_result_type() == decode_to_obs, bypassing the O_sparse projection. +TEST(EnqueueSyndrome, ObsFrameDecoderUsesResultDirectly) { + // H: 2 syndrome measurements, 4 physical errors + cudaqx::tensor H_tensor({2, 4}); + H_tensor.at({0, 0}) = 1; + H_tensor.at({1, 1}) = 1; + cudaqx::heterogeneous_map params; + params.insert("decode_to_obs", true); + auto dec = cudaq::qec::decoder::get("sample_decoder", H_tensor, params); + + // D_sparse maps the two enqueued syndrome bits directly to two detector bits. + dec->set_D_sparse(std::vector>{{0}, {1}}); + // Two observables; cols 0/1 are within block_size=4 for validation only. + dec->set_O_sparse(std::vector>{{0}, {1}}); + + bool did_decode = dec->enqueue_syndrome(std::vector{1, 0}); + EXPECT_TRUE(did_decode); + + const uint8_t *corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 1u); + EXPECT_EQ(corr[1], 0u); +} + +// Verify that corrections XOR-accumulate correctly across multiple shots and +// that clear_corrections() resets them between shots. +TEST(EnqueueSyndrome, ObsFrameMultiShotAccumulation) { + cudaqx::tensor H_tensor({2, 4}); + H_tensor.at({0, 0}) = 1; + H_tensor.at({1, 1}) = 1; + cudaqx::heterogeneous_map params; + params.insert("decode_to_obs", true); + auto dec = cudaq::qec::decoder::get("sample_decoder", H_tensor, params); + + dec->set_D_sparse(std::vector>{{0}, {1}}); + dec->set_O_sparse(std::vector>{{0}, {1}}); + + // Shot 1: obs[0]=1, obs[1]=0 -> corrections become [1, 0] + EXPECT_TRUE(dec->enqueue_syndrome(std::vector{1, 0})); + const uint8_t *corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 1u); + EXPECT_EQ(corr[1], 0u); + + // Shot 2 (no reset): obs[0]=1, obs[1]=1 -> corrections XOR to [0, 1] + EXPECT_TRUE(dec->enqueue_syndrome(std::vector{1, 1})); + corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 0u); + EXPECT_EQ(corr[1], 1u); + + // After clear, corrections reset to [0, 0] + dec->clear_corrections(); + corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 0u); + EXPECT_EQ(corr[1], 0u); +} + +// Verify that a result size mismatch against num_observables throws for +// decode_to_obs decoders. +TEST(EnqueueSyndrome, ObsFrameSizeMismatchThrows) { + cudaqx::tensor H_tensor({3, 4}); + H_tensor.at({0, 0}) = 1; + H_tensor.at({1, 1}) = 1; + H_tensor.at({2, 2}) = 1; + cudaqx::heterogeneous_map params; + params.insert("decode_to_obs", true); + auto dec = cudaq::qec::decoder::get("sample_decoder", H_tensor, params); + + dec->set_D_sparse(std::vector>{{0}, {1}, {2}}); + dec->set_O_sparse( + std::vector>{{0}, {1}}); // 2 observables + + // sample_decoder returns all three detector bits in decode_to_obs mode. + EXPECT_THROW(dec->enqueue_syndrome(std::vector{1, 0, 1}), + std::runtime_error); +} diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index aa7a37be..ccc2b870 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -6,6 +6,7 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ +#include "../lib/realtime/realtime_decoding.h" #include "cudaq/qec/decoder.h" #include "cudaq/qec/pcm_utils.h" #include "cudaq/qec/realtime/decoding_config.h" @@ -188,6 +189,131 @@ TEST(DecoderYAMLTest, SingleLUTDecoder) { test_decoder_creation(multi_config); } +cudaq::qec::decoding::config::decoder_config +create_test_decoder_config_trt(int id) { + cudaq::qec::decoding::config::decoder_config config = + create_test_empty_decoder_config(id); + config.type = "trt_decoder"; + + cudaqx::tensor O({2, config.block_size}); + O.at({0, 1}) = 1; + O.at({1, 3}) = 1; + config.O_sparse = cudaq::qec::pcm_to_sparse_vec(O); + + config.decoder_custom_args = + cudaq::qec::decoding::config::trt_decoder_config(); + auto &trt_config = std::get( + config.decoder_custom_args); + trt_config.onnx_load_path = "/tmp/predecoder.onnx"; + trt_config.engine_save_path = "/tmp/predecoder.engine"; + trt_config.precision = "best"; + trt_config.memory_workspace = 1ULL << 20; + trt_config.batch_size = 4; + trt_config.use_cuda_graph = false; + trt_config.global_decoder = "pymatching"; + auto pymatching_params = cudaq::qec::decoding::config::pymatching_config(); + pymatching_params.merge_strategy = "smallest_weight"; + pymatching_params.error_rate_vec = + std::vector(config.block_size, 0.1); + trt_config.global_decoder_params = pymatching_params; + + return config; +} + +TEST(DecoderYAMLTest, TrtDecoderConfigRoundTrip) { + cudaq::qec::decoding::config::multi_decoder_config multi_config; + multi_config.decoders.push_back(create_test_decoder_config_trt(0)); + + test_decoder_yaml_roundtrip(multi_config); + const auto &trt_config = + std::get( + multi_config.decoders[0].decoder_custom_args); + EXPECT_TRUE( + std::holds_alternative( + trt_config.global_decoder_params)); +} + +TEST(DecoderYAMLTest, TrtDecoderConfigToHeterogeneousMap) { + auto config = create_test_decoder_config_trt(0); + auto params = config.decoder_custom_args_to_heterogeneous_map(); + + EXPECT_EQ(params.get("onnx_load_path"), "/tmp/predecoder.onnx"); + EXPECT_EQ(params.get("engine_save_path"), + "/tmp/predecoder.engine"); + EXPECT_EQ(params.get("precision"), "best"); + EXPECT_EQ(params.get("memory_workspace"), 1ULL << 20); + EXPECT_EQ(params.get("batch_size"), 4u); + EXPECT_FALSE(params.get("use_cuda_graph")); + EXPECT_EQ(params.get("global_decoder"), "pymatching"); + + auto global_params = + params.get("global_decoder_params"); + EXPECT_EQ(global_params.get("merge_strategy"), + "smallest_weight"); + EXPECT_EQ(global_params.get>("error_rate_vec").size(), + config.block_size); +} + +TEST(DecoderYAMLTest, TrtDecoderRealtimeParamsIncludeObservableMatrix) { + auto config = create_test_decoder_config_trt(0); + auto params = cudaq::qec::decoding::host::prepare_decoder_params(config); + + auto O = params.get>("O"); + EXPECT_EQ(O.shape()[0], 2u); + EXPECT_EQ(O.shape()[1], config.block_size); + EXPECT_EQ(O.at({0, 1}), 1); + EXPECT_EQ(O.at({1, 3}), 1); + + auto global_params = + params.get("global_decoder_params"); + auto global_O = global_params.get>("O"); + EXPECT_EQ(global_O.shape()[0], 2u); + EXPECT_EQ(global_O.shape()[1], config.block_size); +} + +TEST(DecoderYAMLTest, TrtDecoderMonostateGlobalDecoderParams) { + auto config = create_test_decoder_config_trt(0); + auto &trt_config = std::get( + config.decoder_custom_args); + trt_config.global_decoder = "pymatching"; + trt_config.global_decoder_params = std::monostate{}; + + auto params = config.decoder_custom_args_to_heterogeneous_map(); + EXPECT_FALSE(params.contains("global_decoder_params")); + + cudaq::qec::decoding::config::multi_decoder_config multi_config; + multi_config.decoders.push_back(config); + test_decoder_yaml_roundtrip(multi_config); + + params = cudaq::qec::decoding::host::prepare_decoder_params(config); + EXPECT_TRUE(params.contains("global_decoder_params")); + EXPECT_TRUE(params.contains("O")); + + config.O_sparse.clear(); + params = cudaq::qec::decoding::host::prepare_decoder_params(config); + EXPECT_TRUE(params.contains("global_decoder_params")); + EXPECT_FALSE(params.contains("O")); +} + +TEST(DecoderYAMLTest, TrtDecoderParamsWithoutDecoderThrows) { + cudaqx::heterogeneous_map map; + map.insert("onnx_load_path", std::string("/tmp/predecoder.onnx")); + cudaqx::heterogeneous_map gd_params; + gd_params.insert("merge_strategy", std::string("smallest_weight")); + map.insert("global_decoder_params", gd_params); + EXPECT_THROW( + cudaq::qec::decoding::config::trt_decoder_config::from_heterogeneous_map( + map), + std::runtime_error); + + cudaq::qec::decoding::config::trt_decoder_config trt_config; + trt_config.onnx_load_path = "/tmp/predecoder.onnx"; + auto pymatching_params = cudaq::qec::decoding::config::pymatching_config(); + pymatching_params.merge_strategy = "smallest_weight"; + trt_config.global_decoder_params = pymatching_params; + EXPECT_THROW(trt_config.to_heterogeneous_map(), std::runtime_error); +} + TEST(DecoderYAMLTest, SlidingWindowDecoder) { std::size_t n_rounds = 4; std::size_t n_errs_per_round = 30; diff --git a/libs/qec/utils/replay_decoder_logs.py b/libs/qec/utils/replay_decoder_logs.py index 7a7d934d..803a4cf6 100644 --- a/libs/qec/utils/replay_decoder_logs.py +++ b/libs/qec/utils/replay_decoder_logs.py @@ -39,7 +39,8 @@ def sparse_to_dense(sparse_list, num_rows, num_cols, dtype=numpy.uint8): # decoder is created, a dummy decode call is made to "warm up" the decoder, so # you may see more decode calls than shots. def parse_decoder_log(decoder_log_file, log_detectors_sparse, log_errors_sparse, - log_observables_dense, decoder_id_list): + log_observables_sparse, log_observables_dense, + log_result_types, decoder_id_list): # running id of the last decoder seen (needed since the decoder id is not # included in the 1 very verbose decode log message). last_decoder_id = -1 @@ -55,32 +56,48 @@ def parse_decoder_log(decoder_log_file, log_detectors_sparse, log_errors_sparse, line = line.split("[DecoderStats]")[1] # print(line) if "InputDetectors:" in line: # this is a decode call - line = line.split(" ") - for elem in line: + fields = {} + for elem in line.split(" "): if ":" in elem: - key, value = elem.split(":") - # print(key, value) - if key == "InputDetectors": - if value == "": - log_detectors_sparse.append([]) - else: - log_detectors_sparse.append( - [int(x) for x in value.split(",")]) - if last_decoder_id == -1: - print( - f"Error: last_decoder_id is -1. This is a fatal error processing the log file." - ) - exit(1) - decoder_id_list.append(last_decoder_id) - elif key == "Errors": - if value == "": - log_errors_sparse.append([]) - else: - log_errors_sparse.append( - [int(x) for x in value.split(",")]) - elif key == "ObservableCorrectionsThisCall": - log_observables_dense.append( - [int(x) for x in value.split(",")]) + key, value = elem.split(":", 1) + fields[key] = value + + value = fields["InputDetectors"] + if value == "": + log_detectors_sparse.append([]) + else: + log_detectors_sparse.append( + [int(x) for x in value.split(",")]) + if last_decoder_id == -1: + print( + f"Error: last_decoder_id is -1. This is a fatal error processing the log file." + ) + exit(1) + decoder_id_list.append(last_decoder_id) + + value = fields.get("Errors", "") + if value == "": + log_errors_sparse.append([]) + else: + log_errors_sparse.append( + [int(x) for x in value.split(",")]) + + value = fields.get("Observables", "") + if value == "": + log_observables_sparse.append([]) + else: + log_observables_sparse.append( + [int(x) for x in value.split(",")]) + + value = fields["ObservableCorrectionsThisCall"] + log_observables_dense.append( + [int(x) for x in value.split(",")]) + + value = fields.get("ResultType", "errs") + result_types = {x for x in value.split(",") if x} + if not result_types: + result_types = {"errs"} + log_result_types.append(result_types) # ---------------------------------------------------------------------------- # @@ -115,6 +132,9 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): decoder_custom_args[key] = float(value) elif type(value) == bool: decoder_custom_args[key] = bool(value) + # Replaying an obs-frame decoder reconstructs the full composite + # decoder here, so trt_decoder replay needs TensorRT, a GPU, and + # the referenced ONNX/engine artifacts. LUT replay is lighter. decoders.append( qec.get_decoder(decoder['type'], H, **decoder_custom_args)) print(f"Decoder {decoder_id} created.") @@ -149,14 +169,17 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): log_detectors_sparse = [] # Detection events seen in the log file. log_errors_sparse = [] # Errors seen in the log file. replay_errors_sparse = [] # Errors seen in the replay. +log_observables_sparse = [] # Observable results seen in the log file. log_observables_dense = [] # Observable flips seen in the log file. replay_observables_dense = [] # Observable flips calculated in the replay. +log_result_types = [] # ResultType tokens seen in each log decode call. decoders = [] O_per_decoder = [] parse_decoder_log(args.decoder_log, log_detectors_sparse, log_errors_sparse, - log_observables_dense, decoder_id_list) + log_observables_sparse, log_observables_dense, + log_result_types, decoder_id_list) parse_decoder_config(args.config, decoders, O_per_decoder) # Basic error checking @@ -171,6 +194,7 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): # Now loop through the syndromes and compare the results. decode_call_idx = 0 replay_error_mismatch = 0 +replay_observable_result_mismatch = 0 replay_observable_mismatch = 0 print(f'Processing {len(log_detectors_sparse)} decode calls.') for s, o in zip(log_detectors_sparse, log_observables_dense): @@ -181,28 +205,60 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): for idx in s: syndrome[idx] = 1 result = decoders[decoder_id_list[decode_call_idx]].decode(syndrome) - dec_err_sparse = [ + result_types = log_result_types[decode_call_idx] + + mismatch_flag = False + decoded_sparse = [ i for i in range(len(result.result)) if result.result[i] > 0.5 ] - replay_errors_sparse.append(dec_err_sparse) - mismatch_flag = False - if dec_err_sparse != log_errors_sparse[decode_call_idx]: - replay_error_mismatch += 1 - mismatch_flag = True - if args.verbose_on_mismatch: - print( - f"Replay mismatch in error in decode_call_idx {decode_call_idx}" - ) - print(f"Decoded errors : {dec_err_sparse}") - print(f"Expected errors: {log_errors_sparse[decode_call_idx]}") - dec_err_dense = numpy.array(result.result, dtype=numpy.uint8) - O_replay = ( - O_per_decoder[decoder_id_list[decode_call_idx]] @ dec_err_dense % - 2).astype(numpy.uint8) + if "errs" in result_types: + dec_err_sparse = decoded_sparse + replay_errors_sparse.append(dec_err_sparse) + if dec_err_sparse != log_errors_sparse[decode_call_idx]: + replay_error_mismatch += 1 + mismatch_flag = True + if args.verbose_on_mismatch: + print( + f"Replay mismatch in error in decode_call_idx {decode_call_idx}" + ) + print(f"Decoded errors : {dec_err_sparse}") + print(f"Expected errors: {log_errors_sparse[decode_call_idx]}") + dec_err_dense = numpy.array(result.result, dtype=numpy.uint8) + O_replay = ( + O_per_decoder[decoder_id_list[decode_call_idx]] @ dec_err_dense % + 2).astype(numpy.uint8) + decoded_observables_sparse = [ + i for i in range(len(O_replay)) if O_replay[i] + ] + elif "obs" in result_types: + replay_errors_sparse.append([]) + O_replay = numpy.array([1 if x > 0.5 else 0 for x in result.result], + dtype=numpy.uint8) + decoded_observables_sparse = decoded_sparse + else: + print(f"Error: unsupported ResultType set {sorted(result_types)} " + f"in decode_call_idx {decode_call_idx}.") + exit(1) + + if "obs" in result_types: + expected_observables_sparse = log_observables_sparse[decode_call_idx] + if decoded_observables_sparse != expected_observables_sparse: + replay_observable_result_mismatch += 1 + mismatch_flag = True + if args.verbose_on_mismatch: + print( + f"Replay mismatch in observable result in decode_call_idx {decode_call_idx}" + ) + print( + f"Decoded observable result : {decoded_observables_sparse}") + print( + f"Expected observable result: {expected_observables_sparse}" + ) + replay_observables_dense.append(O_replay) O_log = numpy.array(log_observables_dense[decode_call_idx], dtype=numpy.uint8) - if (O_replay != O_log).any(): + if O_replay.shape != O_log.shape or (O_replay != O_log).any(): replay_observable_mismatch += 1 mismatch_flag = True if args.verbose_on_mismatch: @@ -220,6 +276,9 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): print() print(f"Number of error mismatches during replay: {replay_error_mismatch}") +print( + f"Number of observable result mismatches during replay: {replay_observable_result_mismatch}" +) print( f"Number of observable mismatches during replay: {replay_observable_mismatch}" ) diff --git a/test_surface_code_trt.py b/test_surface_code_trt.py new file mode 100644 index 00000000..83145636 --- /dev/null +++ b/test_surface_code_trt.py @@ -0,0 +1,91 @@ +# ============================================================================ # +# Copyright (c) 2026 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +# This is a draft test script that should be improved and run by the CI (where possible). + +import stim +import cudaq_qec as qec +from beliefmatching.belief_matching import detector_error_model_to_check_matrices +import numpy as np +import time + +d = 13 +e = 0.003 +# Generate a Stim circuit for the surface code +# circuit = stim.Circuit.generated( +# "surface_code:rotated_memory_z", +# distance=d, +# rounds=1*d, +# after_clifford_depolarization=e, +# before_round_data_depolarization=e, +# before_measure_flip_probability=e, +# after_reset_flip_probability=e, +# ) + +circuit = stim.Circuit.from_file("/workspaces/pre-decoder/circuit_Z.stim") + +# Get the Detector Error Model (DEM) from the Stim circuit +dem = circuit.detector_error_model(decompose_errors=True, + approximate_disjoint_errors=True) +# print(dem) +# exit() + +matrices = detector_error_model_to_check_matrices(dem) +# H = matrices.check_matrix +# O = matrices.observables_matrix +H = matrices.edge_check_matrix +O = matrices.edge_observables_matrix +# priors = matrices.priors +edge_probs = matrices.hyperedge_to_edge_matrix @ matrices.priors +eps = 1e-14 +edge_probs[edge_probs > 1 - eps] = 1 - eps +edge_probs[edge_probs < eps] = eps +priors = edge_probs +print(f"Shape of priors: {priors.shape}") +print(f"Shape of H: {H.shape}") +print(f"Shape of O: {O.shape}") +print( + f"Shape of matrices.hyperedge_to_edge_matrix: {matrices.hyperedge_to_edge_matrix.shape}" +) +H_dense = H.todense(order="C").astype(np.uint8) +O_dense = O.todense(order="C").astype(np.uint8) + +# If there is a global decoder, the H_dense will be passed to the global +# decoder. Additionally, when there is a global decoder, it is assumed that the +# last portion of the syndrome corresponds to the boundary detectors. +dec = qec.get_decoder( + "trt_decoder", + H_dense, + O=O_dense, + #engine_load_path="/workspaces/pre-decoder/predecoder_memory_d13_T13_Z.engine", + onnx_load_path="/workspaces/pre-decoder/predecoder_memory_d13_T13_Z.onnx", + use_cuda_graph=False, + batch_size=7, + global_decoder="pymatching", + global_decoder_params={ + "merge_strategy": "independent", + "O": O_dense, + }) + +sampler = circuit.compile_detector_sampler(seed=42) +dets, obs = sampler.sample(2048, separate_observables=True) +# Print the shape of dets and obs +print(f"Shape of dets: {dets.shape}") +print(f"Shape of obs: {obs.shape}") + +results = dec.decode_batch(dets) +# for i in range(min(20, len(results))): +# print(f"Result {i}: {results[i].result[0]}, len: {len(results[i].result)}, obs: {obs[i]}") + +num_mismatches = 0 +for i in range(min(len(results), len(obs))): + if results[i].result[0] != obs[i]: + num_mismatches += 1 +print( + f"Number of mismatches: {num_mismatches} out of {min(len(results), len(obs))}" +)