From 1cd5d1004ec79d62dd106fb598dd4d2e7b15f2b6 Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Wed, 29 Apr 2026 23:57:20 +0000 Subject: [PATCH 01/19] (qec): Extend trt_decoder with global decoder chaining + observables output Add a "predecoder" execution mode to the TensorRT decoder so it can be chained with a second decoder (e.g. PyMatching) and return logical-frame observables directly. The TRT model is assumed to emit a single output that concatenates [pre_L (num_observables entries), residual_dets (rest)]. New constructor parameters: - "batch_size": required when the ONNX model has a dynamic batch dim. Used to size the optimization profile and pre-allocate I/O buffers. - "global_decoder" + "global_decoder_params": optional decoder name and params for a follow-up decoder run on the residual_dets portion of the TRT output. Created with the same H passed to trt_decoder. - "O": observables matrix (num_observables x block_size). Enables decode()/decode_batch() to return the predicted logical frame. Number of observables is inferred from O.shape()[0]. Decode behavior matrix: - no global_decoder, no O -> raw TRT output (unchanged). - no global_decoder, O -> return the pre_L prefix only. - global_decoder, no O -> entire output -> global_decoder.result. - global_decoder, O -> residual -> global_decoder; return pre_L XOR global_decoder.logical_frame. Constructor validation when O is set: - output_size_per_sample >= num_observables, and - when global_decoder_ is set, output_size_per_sample == num_observables + global_decoder.syndrome_size. Other changes: - Dynamic batch support: setInputShape per call when the model's batch dim is -1; ONNX builder now installs a min/opt/max optimization profile when "batch_size" is provided. - Split decode_batch into a typed decode_batch_impl for cleaner dtype dispatch (engine I/O dtypes float32 / uint8 unchanged). - Better INFO logging: total non-zero input vs residual detector counts per batch to help diagnose predecoder behavior. Signed-off-by: Ben Howe --- .../plugins/trt_decoder/trt_decoder.cpp | 418 +++++++++++++++--- test_surface_code_trt.py | 91 ++++ 2 files changed, 441 insertions(+), 68 deletions(-) create mode 100644 test_surface_code_trt.py diff --git a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp index b31883ea..4d971171 100644 --- a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp +++ b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include #include @@ -117,6 +119,22 @@ static Logger gLogger; /// model in the desired precision instead. /// - "memory_workspace": Memory workspace size in bytes (optional, default: /// 1GB) +/// - "batch_size": Required when the ONNX model has a dynamic batch dim +/// (-1). Used to size the optimization profile and I/O buffers. +/// - "global_decoder": Optional name of a decoder to run after TRT +/// (e.g. DEM decoder). The TRT model is assumed to have detectors as +/// inputs and either (a) residual detectors as the only output, or +/// (b) when "O" is also provided, the concatenation [pre_L, +/// residual_dets] as the only output. +/// - "global_decoder_params": Optional parameters for the global decoder. The +/// decoder is created with the same H passed to the trt_decoder constructor. +/// - "O": Observables matrix (num_observables x block_size). Calls to +/// decode() and decode_batch() will return the logical frame of the +/// observables. Requires that the TRT model emits the concatenation +/// [pre_L (num_observables entries), residual_dets (rest)] as a single +/// output. When a global_decoder is also set, the final result is +/// pre_L XOR global_decoder(residual_dets); otherwise only the pre_L +/// prefix is returned. /// /// Note: Only one of onnx_load_path or engine_load_path should be specified, /// not both. @@ -127,6 +145,18 @@ namespace cudaq::qec { // ============================================================================ namespace { + +// Helpers for templated I/O: binarize TRT output (float or uint8) to 0/1 +// for counting and for the decoder API (float_t). +inline bool trt_io_nonzero(float val) { return val >= 0.5f; } +inline bool trt_io_nonzero(uint8_t val) { return val != 0; } +inline float_t trt_io_to_binary(float val) { + return (val >= 0.5f) ? static_cast(1.0) : static_cast(0.0); +} +inline float_t trt_io_to_binary(uint8_t val) { + return (val != 0) ? static_cast(1.0) : static_cast(0.0); +} + // Traditional TensorRT execution without CUDA graphs struct TraditionalExecutor { void execute(nvinfer1::IExecutionContext *context, cudaStream_t stream, @@ -378,6 +408,17 @@ class trt_decoder : public decoder { size_t syndrome_size_per_sample_ = 0; size_t output_size_per_sample_ = 0; + // Optional global decoder (e.g. DEM decoder) applied after TRT + postprocess + std::unique_ptr global_decoder_; + cudaqx::heterogeneous_map global_decoder_params_; + + // When true, decode()/decode_batch() return the predicted logical-frame + // observables. The TRT model must emit the concatenation + // [pre_L (num_observables_ entries), residual_dets (rest)] as its single + // output. Enabled by passing the "O" (observables) parameter. + bool decode_to_observables_ = false; + size_t num_observables_ = 0; + public: trt_decoder(const cudaqx::tensor &H, const cudaqx::heterogeneous_map ¶ms); @@ -398,6 +439,12 @@ class trt_decoder : public decoder { private: void check_cuda(); + + /// Typed decode_batch: IoType matches the engine's I/O dtype + /// (currently float or uint8_t). + template + std::vector + decode_batch_impl(const std::vector> &syndromes) const; }; // ============================================================================ @@ -419,11 +466,27 @@ struct trt_decoder::Impl { void *buffers[2] = {nullptr, nullptr}; cudaStream_t stream; + // Dynamic batch: set input shape before each inference (only when dynamic) + bool has_dynamic_batch_ = false; + std::string input_name_; + nvinfer1::Dims input_dims_{0, {}}; + // Executor (chosen once at construction, never changes) std::variant executor; - // Execute inference (variant dispatch) - void execute_inference() { + /// actual_batch is used only when has_dynamic_batch_. + void execute_inference(size_t actual_batch = 0) { + if (has_dynamic_batch_) { + nvinfer1::Dims dims; + dims.nbDims = input_dims_.nbDims; + dims.d[0] = static_cast(actual_batch); + for (int i = 1; i < input_dims_.nbDims; ++i) + dims.d[i] = input_dims_.d[i]; + if (!context->setInputShape(input_name_.c_str(), dims)) { + throw std::runtime_error("setInputShape failed for batch size " + + std::to_string(actual_batch)); + } + } std::visit( [&](auto &exec) { exec.execute(context.get(), stream, buffers[input_index], @@ -543,6 +606,12 @@ trt_decoder::trt_decoder(const cudaqx::tensor &H, if (impl_->input_index == -1 || impl_->output_index == -1) { throw std::runtime_error("Failed to identify input/output tensors"); } + if (n_bindings != 2) { + throw std::runtime_error( + "TensorRT decoder expects exactly 2 I/O tensors (1 input + 1 " + "output), got " + + std::to_string(n_bindings)); + } const char *inputTensorName = impl_->engine->getIOTensorName(impl_->input_index); @@ -573,27 +642,55 @@ trt_decoder::trt_decoder(const cudaqx::tensor &H, auto inputDims = impl_->engine->getTensorShape(inputTensorName); - // Extract batch size from first dimension - // If first dimension is -1, it's dynamic (not supported for batching) - if (inputDims.nbDims > 0 && inputDims.d[0] > 0) { - model_batch_size_ = static_cast(inputDims.d[0]); - } else { - model_batch_size_ = 1; - } + impl_->has_dynamic_batch_ = (inputDims.nbDims > 0 && inputDims.d[0] == -1); - // Calculate total input size and per-sample size - impl_->input_size = 1; - for (int j = 0; j < inputDims.nbDims; ++j) - impl_->input_size *= inputDims.d[j]; + if (impl_->has_dynamic_batch_) { + if (!params.contains("batch_size")) { + // FIXME - should we just default to 1 or throw an error? + throw std::runtime_error( + "TensorRT decoder: model has dynamic batch dimension but " + "'batch_size' was not set in params (required for allocation)"); + } + model_batch_size_ = params.get("batch_size"); + if (model_batch_size_ < 1) { + throw std::runtime_error( + "TensorRT decoder: batch_size must be >= 1, got " + + std::to_string(model_batch_size_)); + } + impl_->input_name_ = impl_->engine->getIOTensorName(impl_->input_index); + impl_->input_dims_ = inputDims; + syndrome_size_per_sample_ = 1; + for (int j = 1; j < inputDims.nbDims; ++j) + syndrome_size_per_sample_ *= inputDims.d[j]; + impl_->input_size = + static_cast(model_batch_size_ * syndrome_size_per_sample_); + } else { + if (inputDims.nbDims > 0 && inputDims.d[0] > 0) { + model_batch_size_ = static_cast(inputDims.d[0]); + } else { + model_batch_size_ = 1; + } - syndrome_size_per_sample_ = impl_->input_size / model_batch_size_; + // Calculate total input size and per-sample size + impl_->input_size = 1; + for (int j = 0; j < inputDims.nbDims; ++j) + impl_->input_size *= inputDims.d[j]; + syndrome_size_per_sample_ = impl_->input_size / model_batch_size_; + } auto outputDims = impl_->engine->getTensorShape(outputTensorName); - impl_->output_size = 1; - for (int j = 0; j < outputDims.nbDims; ++j) - impl_->output_size *= outputDims.d[j]; - - output_size_per_sample_ = impl_->output_size / model_batch_size_; + output_size_per_sample_ = 1; + for (int j = 1; j < outputDims.nbDims; ++j) + output_size_per_sample_ *= (outputDims.d[j] > 0 ? outputDims.d[j] : 1); + if (outputDims.nbDims > 0 && outputDims.d[0] > 0) { + impl_->output_size = 1; + for (int j = 0; j < outputDims.nbDims; ++j) + impl_->output_size *= outputDims.d[j]; + output_size_per_sample_ = impl_->output_size / model_batch_size_; + } else { + impl_->output_size = + static_cast(model_batch_size_ * output_size_per_sample_); + } CUDAQ_INFO("TensorRT model configuration: batch_size={}, " "syndrome_size_per_sample={}, output_size_per_sample={}", @@ -653,6 +750,67 @@ trt_decoder::trt_decoder(const cudaqx::tensor &H, CUDAQ_INFO("TensorRT decoder initialized with traditional execution"); } + // Optional global decoder (e.g. DEM decoder), similar to sliding_window's + // inner_decoder. When set, decode_batch will run: syndrome->trainX->TRT + // ->postprocess->global_decoder->results. + if (params.contains("global_decoder") && + params.contains("global_decoder_params")) { + std::string global_decoder_name = + params.get("global_decoder"); + global_decoder_params_ = + params.get("global_decoder_params"); + if (!global_decoder_name.empty()) { + global_decoder_ = + decoder::get(global_decoder_name, H, global_decoder_params_); + CUDAQ_INFO("TensorRT decoder: global_decoder '{}' attached", + global_decoder_name); + } + } + + if (params.contains("O")) { + auto O = params.get>("O"); + if (O.rank() != 2) { + throw std::runtime_error( + "trt_decoder: O must be a 2-dimensional tensor (num_observables x " + "block_size)"); + } + if (O.shape()[1] != block_size) { + throw std::runtime_error( + "trt_decoder: O second dimension must equal H block_size (got " + + std::to_string(O.shape()[1]) + ", block_size " + + std::to_string(block_size) + ")"); + } + decode_to_observables_ = true; + num_observables_ = O.shape()[0]; + + // The TRT model output must encode [pre_L (num_observables_ entries), + // residual_dets (rest)]. Validate sizing where we can. + if (output_size_per_sample_ < num_observables_) { + throw std::runtime_error( + "trt_decoder: TRT output_size_per_sample (" + + std::to_string(output_size_per_sample_) + + ") is smaller than num_observables (" + + std::to_string(num_observables_) + + "); model output must be [pre_L, residual_dets]."); + } + if (global_decoder_) { + const size_t expected = + num_observables_ + global_decoder_->get_syndrome_size(); + if (output_size_per_sample_ != expected) { + throw std::runtime_error( + "trt_decoder: TRT output_size_per_sample (" + + std::to_string(output_size_per_sample_) + + ") must equal num_observables + global_decoder.syndrome_size " + "(" + + std::to_string(expected) + + ") for the [pre_L, residual_dets] split."); + } + } + CUDAQ_INFO("TensorRT decoder: decode_to_observables enabled " + "(num_observables={})", + num_observables_); + } + // Decoder is now fully configured and ready for inference decoder_ready_ = true; @@ -719,14 +877,6 @@ trt_decoder::decode_batch(const std::vector> &syndromes) { } } - // Check if number of syndromes is an integral multiple of batch size - if (syndromes.size() % model_batch_size_ != 0) { - throw std::runtime_error( - "Number of syndromes (" + std::to_string(syndromes.size()) + - ") must be an integral multiple of the model batch size (" + - std::to_string(model_batch_size_) + ")"); - } - if (!decoder_ready_) { // Return unconverged results if decoder is not ready CUDAQ_WARN( @@ -735,66 +885,158 @@ trt_decoder::decode_batch(const std::vector> &syndromes) { syndromes.size()); std::vector results(syndromes.size()); + const size_t result_size = + decode_to_observables_ ? num_observables_ : output_size_per_sample_; for (auto &result : results) { result.converged = false; - result.result.resize(output_size_per_sample_, 0.0); + result.result.resize(result_size, 0.0); } return results; } + // Dispatch on the actual engine input dtype (uint8 or float). + if (impl_->input_dtype == nvinfer1::DataType::kUINT8) + return decode_batch_impl(syndromes); + return decode_batch_impl(syndromes); +} + +template +std::vector trt_decoder::decode_batch_impl( + const std::vector> &syndromes) const { std::vector results; results.reserve(syndromes.size()); + // Output split for the predecoder pattern: when decode_to_observables_ is + // on the TRT output is [pre_L (num_observables_), residual_dets (rest)]. + const size_t pre_L_size = decode_to_observables_ ? num_observables_ : 0; + const size_t residual_size = output_size_per_sample_ - pre_L_size; + try { - // Process syndromes in batches of model_batch_size_ + size_t total_input_nonzero = 0; + size_t total_residual_nonzero = 0; + const bool log_residual_counts = + cudaq::details::should_log(cudaq::details::LogLevel::info); + for (size_t batch_start = 0; batch_start < syndromes.size(); batch_start += model_batch_size_) { - // Prepare input and copy to GPU (type dispatched from engine metadata) - auto copy_input = [&](auto type_tag) { - using T = decltype(type_tag); - std::vector input_host(impl_->input_size); - for (size_t batch_idx = 0; batch_idx < model_batch_size_; ++batch_idx) { - const auto &syndrome = syndromes[batch_start + batch_idx]; - for (size_t i = 0; i < syndrome_size_per_sample_; ++i) + const size_t actual_batch = + std::min(model_batch_size_, syndromes.size() - batch_start); + + // Prepare input batch. For float input we preserve soft (raw) values; + // for uint8 we binarize to 0/1. + std::vector input_host(impl_->input_size); + for (size_t batch_idx = 0; batch_idx < actual_batch; ++batch_idx) { + const auto &syndrome = syndromes[batch_start + batch_idx]; + for (size_t i = 0; i < syndrome_size_per_sample_; ++i) { + if constexpr (std::is_same_v) { input_host[batch_idx * syndrome_size_per_sample_ + i] = - static_cast(syndrome[i]); + static_cast(syndrome[i]); + } else { + input_host[batch_idx * syndrome_size_per_sample_ + i] = + static_cast(syndrome[i] >= 0.5f ? 1 : 0); + } + } + } + + HANDLE_CUDA_ERROR(cudaMemcpy( + impl_->buffers[impl_->input_index], input_host.data(), + impl_->input_size * sizeof(IoType), cudaMemcpyHostToDevice)); + + impl_->execute_inference(actual_batch); + + std::vector output_host(impl_->output_size); + HANDLE_CUDA_ERROR(cudaMemcpy( + output_host.data(), impl_->buffers[impl_->output_index], + impl_->output_size * sizeof(IoType), cudaMemcpyDeviceToHost)); + + if (log_residual_counts) { + const size_t input_elems = actual_batch * syndrome_size_per_sample_; + for (size_t i = 0; i < input_elems; ++i) + if (trt_io_nonzero(input_host[i])) + total_input_nonzero++; + // Count non-zero entries in just the residual portion of the output. + for (size_t batch_idx = 0; batch_idx < actual_batch; ++batch_idx) { + const IoType *row = output_host.data() + + batch_idx * output_size_per_sample_ + pre_L_size; + for (size_t i = 0; i < residual_size; ++i) + if (trt_io_nonzero(row[i])) + total_residual_nonzero++; } - HANDLE_CUDA_ERROR( - cudaMemcpy(impl_->buffers[impl_->input_index], input_host.data(), - impl_->input_size * sizeof(T), cudaMemcpyHostToDevice)); - }; - if (impl_->input_dtype == nvinfer1::DataType::kUINT8) - copy_input(uint8_t{}); - else - copy_input(float{}); - - // Execute inference - impl_->execute_inference(); - - // Copy output from GPU and extract results (type dispatched) - auto extract_output = [&](auto type_tag) { - using T = decltype(type_tag); - std::vector output_host(impl_->output_size); - HANDLE_CUDA_ERROR( - cudaMemcpy(output_host.data(), impl_->buffers[impl_->output_index], - impl_->output_size * sizeof(T), cudaMemcpyDeviceToHost)); - for (size_t batch_idx = 0; batch_idx < model_batch_size_; ++batch_idx) { + } + + if (global_decoder_) { + // Build the global-decoder input from the residual portion of the + // TRT output. + const size_t global_syndrome_size = + global_decoder_->get_syndrome_size(); + if (residual_size != global_syndrome_size) { + throw std::runtime_error( + "trt_decoder: residual portion of TRT output (" + + std::to_string(residual_size) + + ") != global_decoder.syndrome_size (" + + std::to_string(global_syndrome_size) + ")"); + } + std::vector> residual_soft( + actual_batch, std::vector(global_syndrome_size, 0.0f)); + for (size_t batch_idx = 0; batch_idx < actual_batch; ++batch_idx) { + const IoType *res = output_host.data() + + batch_idx * output_size_per_sample_ + pre_L_size; + float_t *out = residual_soft[batch_idx].data(); + for (size_t i = 0; i < global_syndrome_size; ++i) + out[i] = trt_io_to_binary(res[i]); + } + std::vector global_results = + global_decoder_->decode_batch(residual_soft); + + if (decode_to_observables_) { + // Combine pre_L (the prefix of the TRT output) with the global + // decoder's logical-frame prediction via XOR. + for (size_t batch_idx = 0; batch_idx < actual_batch; ++batch_idx) { + decoder_result combined; + combined.converged = global_results[batch_idx].converged; + combined.result.resize(num_observables_, 0.0f); + const IoType *pre_L_row = + output_host.data() + batch_idx * output_size_per_sample_; + const std::vector &g = global_results[batch_idx].result; + for (size_t k = 0; k < num_observables_; ++k) { + const uint8_t a = trt_io_nonzero(pre_L_row[k]) ? 1u : 0u; + const uint8_t b = (k < g.size() && g[k] >= 0.5f) ? 1u : 0u; + combined.result[k] = static_cast(a ^ b); + } + results.push_back(std::move(combined)); + } + } else { + for (decoder_result &r : global_results) + results.push_back(std::move(r)); + } + } else { + // No global decoder. If decode_to_observables_ is set, return only + // the pre_L prefix; otherwise return the full TRT output. + const size_t out_per_sample = + decode_to_observables_ ? num_observables_ : output_size_per_sample_; + for (size_t batch_idx = 0; batch_idx < actual_batch; ++batch_idx) { decoder_result result; result.converged = true; - result.result.resize(output_size_per_sample_); - std::transform( - output_host.begin() + batch_idx * output_size_per_sample_, - output_host.begin() + (batch_idx + 1) * output_size_per_sample_, - result.result.begin(), - [](T val) { return static_cast(val); }); + result.result.resize(out_per_sample); + const IoType *row = + output_host.data() + batch_idx * output_size_per_sample_; + for (size_t i = 0; i < out_per_sample; ++i) { + if constexpr (std::is_same_v) { + result.result[i] = static_cast(row[i]); + } else { + result.result[i] = trt_io_to_binary(row[i]); + } + } results.push_back(std::move(result)); } - }; - if (impl_->output_dtype == nvinfer1::DataType::kUINT8) - extract_output(uint8_t{}); - else - extract_output(float{}); + } + } + + if (log_residual_counts) { + CUDAQ_INFO("TRT decoder: total non-zero input detectors = {}, total " + "non-zero residual detectors = {}", + total_input_nonzero, total_residual_nonzero); } } catch (const std::exception &e) { @@ -808,6 +1050,12 @@ trt_decoder::decode_batch(const std::vector> &syndromes) { return results; } +// Explicit instantiations for the supported single-output engine I/O dtypes. +template std::vector trt_decoder::decode_batch_impl( + const std::vector> &) const; +template std::vector trt_decoder::decode_batch_impl( + const std::vector> &) const; + trt_decoder::~trt_decoder() = default; void trt_decoder::check_cuda() { @@ -928,6 +1176,40 @@ build_engine_from_onnx(const std::string &onnx_model_path, } parse_precision(precision, config.get()); + // The following is required for cases when using .onnx file with dynamic + // batch dimension. + if (params.contains("batch_size") && network->getNbInputs() > 0) { + nvinfer1::ITensor *input = network->getInput(0); + nvinfer1::Dims dims = input->getDimensions(); + if (dims.nbDims > 0 && dims.d[0] == -1) { + const size_t batch_size = params.get("batch_size"); + if (batch_size < 1) { + throw std::runtime_error("batch_size must be >= 1, got " + + std::to_string(batch_size)); + } + nvinfer1::IOptimizationProfile *profile = + builder->createOptimizationProfile(); + nvinfer1::Dims minDims = dims; + nvinfer1::Dims optDims = dims; + nvinfer1::Dims maxDims = dims; + minDims.d[0] = 1; + optDims.d[0] = static_cast(batch_size); + maxDims.d[0] = static_cast(batch_size); + if (!profile->setDimensions( + input->getName(), nvinfer1::OptProfileSelector::kMIN, minDims) || + !profile->setDimensions( + input->getName(), nvinfer1::OptProfileSelector::kOPT, optDims) || + !profile->setDimensions( + input->getName(), nvinfer1::OptProfileSelector::kMAX, maxDims)) { + throw std::runtime_error( + "Failed to set optimization profile dimensions for batch"); + } + config->addOptimizationProfile(profile); + CUDAQ_INFO("TensorRT optimization profile: batch min=1, opt=max={}", + batch_size); + } + } + // Build engine auto engine = std::unique_ptr( builder->buildEngineWithConfig(*network, *config)); diff --git a/test_surface_code_trt.py b/test_surface_code_trt.py new file mode 100644 index 00000000..83145636 --- /dev/null +++ b/test_surface_code_trt.py @@ -0,0 +1,91 @@ +# ============================================================================ # +# Copyright (c) 2026 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +# This is a draft test script that should be improved and run by the CI (where possible). + +import stim +import cudaq_qec as qec +from beliefmatching.belief_matching import detector_error_model_to_check_matrices +import numpy as np +import time + +d = 13 +e = 0.003 +# Generate a Stim circuit for the surface code +# circuit = stim.Circuit.generated( +# "surface_code:rotated_memory_z", +# distance=d, +# rounds=1*d, +# after_clifford_depolarization=e, +# before_round_data_depolarization=e, +# before_measure_flip_probability=e, +# after_reset_flip_probability=e, +# ) + +circuit = stim.Circuit.from_file("/workspaces/pre-decoder/circuit_Z.stim") + +# Get the Detector Error Model (DEM) from the Stim circuit +dem = circuit.detector_error_model(decompose_errors=True, + approximate_disjoint_errors=True) +# print(dem) +# exit() + +matrices = detector_error_model_to_check_matrices(dem) +# H = matrices.check_matrix +# O = matrices.observables_matrix +H = matrices.edge_check_matrix +O = matrices.edge_observables_matrix +# priors = matrices.priors +edge_probs = matrices.hyperedge_to_edge_matrix @ matrices.priors +eps = 1e-14 +edge_probs[edge_probs > 1 - eps] = 1 - eps +edge_probs[edge_probs < eps] = eps +priors = edge_probs +print(f"Shape of priors: {priors.shape}") +print(f"Shape of H: {H.shape}") +print(f"Shape of O: {O.shape}") +print( + f"Shape of matrices.hyperedge_to_edge_matrix: {matrices.hyperedge_to_edge_matrix.shape}" +) +H_dense = H.todense(order="C").astype(np.uint8) +O_dense = O.todense(order="C").astype(np.uint8) + +# If there is a global decoder, the H_dense will be passed to the global +# decoder. Additionally, when there is a global decoder, it is assumed that the +# last portion of the syndrome corresponds to the boundary detectors. +dec = qec.get_decoder( + "trt_decoder", + H_dense, + O=O_dense, + #engine_load_path="/workspaces/pre-decoder/predecoder_memory_d13_T13_Z.engine", + onnx_load_path="/workspaces/pre-decoder/predecoder_memory_d13_T13_Z.onnx", + use_cuda_graph=False, + batch_size=7, + global_decoder="pymatching", + global_decoder_params={ + "merge_strategy": "independent", + "O": O_dense, + }) + +sampler = circuit.compile_detector_sampler(seed=42) +dets, obs = sampler.sample(2048, separate_observables=True) +# Print the shape of dets and obs +print(f"Shape of dets: {dets.shape}") +print(f"Shape of obs: {obs.shape}") + +results = dec.decode_batch(dets) +# for i in range(min(20, len(results))): +# print(f"Result {i}: {results[i].result[0]}, len: {len(results[i].result)}, obs: {obs[i]}") + +num_mismatches = 0 +for i in range(min(len(results), len(obs))): + if results[i].result[0] != obs[i]: + num_mismatches += 1 +print( + f"Number of mismatches: {num_mismatches} out of {min(len(results), len(obs))}" +) From 84b56f58a489b0ef31fa92321fad9c928e7f0a7b Mon Sep 17 00:00:00 2001 From: Scott Thornton Date: Fri, 1 May 2026 20:16:56 +0000 Subject: [PATCH 02/19] Add composite TRT decoder realtime demo Add a realtime test/demo that initializes the TensorRT decoder from an ONNX predecoder model with PyMatching configured as the global decoder. The driver loads detector, observable, parity-check, observable, and prior data from the Stim export bundle, decodes samples through the composite TRT+PyMatching path, and reports latency, throughput, correctness, and residual-syndrome diagnostics. Register the new test_trt_decoder_composite target when TensorRT, realtime, and the TRT decoder plugin are available. Signed-off-by: Scott Thornton --- libs/qec/unittests/realtime/CMakeLists.txt | 66 ++ .../realtime/test_trt_decoder_composite.cpp | 588 ++++++++++++++++++ 2 files changed, 654 insertions(+) create mode 100644 libs/qec/unittests/realtime/test_trt_decoder_composite.cpp diff --git a/libs/qec/unittests/realtime/CMakeLists.txt b/libs/qec/unittests/realtime/CMakeLists.txt index 53496180..68972ecc 100644 --- a/libs/qec/unittests/realtime/CMakeLists.txt +++ b/libs/qec/unittests/realtime/CMakeLists.txt @@ -76,6 +76,72 @@ else() "Skipping test_realtime_predecoder_w_pymatching.") endif() +# =========================================================================== # +# test_trt_decoder_composite +# =========================================================================== # +# Software-only demo: composite TensorRT decoder plugin with PyMatching as the +# global decoder. Uses the same Stim-exported detector, observable, H, O, and +# priors files as test_realtime_predecoder_w_pymatching, but exercises the +# trt_decoder plugin's built-in [pre_L, residual] -> global_decoder composition. + +if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY AND TENSORRT_ONNX_PARSER_LIBRARY + AND CUDAQ_REALTIME_INCLUDE_DIR + AND TARGET cudaq-qec-trt-decoder) + + get_filename_component(_cuda_bin_trt_comp "${CMAKE_CUDA_COMPILER}" DIRECTORY) + get_filename_component(_cuda_root_trt_comp "${_cuda_bin_trt_comp}" DIRECTORY) + set(_cuda_cccl_include_trt_comp "${_cuda_root_trt_comp}/include/cccl") + + add_executable(test_trt_decoder_composite + test_trt_decoder_composite.cpp + predecoder_pipeline_common.cpp + ) + + set_target_properties(test_trt_decoder_composite PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON + ) + + target_compile_definitions(test_trt_decoder_composite PRIVATE + ONNX_MODEL_DIR="${CMAKE_CURRENT_SOURCE_DIR}/../../lib/realtime" + ) + + target_include_directories(test_trt_decoder_composite PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} + ${_cuda_cccl_include_trt_comp} + ${CUDAToolkit_INCLUDE_DIRS} + ${TENSORRT_INCLUDE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_SOURCE_DIR}/libs/core/include + ${CUDAQ_REALTIME_INCLUDE_DIR} + ) + + target_link_libraries(test_trt_decoder_composite PRIVATE + CUDA::cudart + cudaq-qec + cudaq-qec-trt-decoder + cudaq::cudaq + ) + + if(TARGET cudaq-qec-pymatching) + target_link_libraries(test_trt_decoder_composite PRIVATE cudaq-qec-pymatching) + endif() + + get_filename_component(_cudaq_rt_lib_dir_trt_comp "${CUDAQ_REALTIME_LIBRARY}" DIRECTORY) + set_target_properties(test_trt_decoder_composite PROPERTIES + BUILD_RPATH "${CMAKE_BINARY_DIR}/lib;${CMAKE_BINARY_DIR}/lib/decoder-plugins;${_cudaq_rt_lib_dir_trt_comp}" + INSTALL_RPATH "${CMAKE_BINARY_DIR}/lib;${CMAKE_BINARY_DIR}/lib/decoder-plugins;${_cudaq_rt_lib_dir_trt_comp}" + ) + + if(ENABLE_NVTX) + target_compile_definitions(test_trt_decoder_composite PRIVATE ENABLE_NVTX) + endif() + +else() + message(WARNING "TensorRT or cudaq-qec-trt-decoder not found. " + "Skipping test_trt_decoder_composite.") +endif() + # =========================================================================== # # hololink_predecoder_bridge # =========================================================================== # diff --git a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp new file mode 100644 index 00000000..79457b6f --- /dev/null +++ b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp @@ -0,0 +1,588 @@ +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2026 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +/******************************************************************************* + * Composite TensorRT Decoder Demo + * + * This is a software-only trt_decoder demo of test_predecoder_w_pymatching + * It consumes the same Stim-exported files: + * detectors.bin, observables.bin, H_csr.bin, O_csr.bin, priors.bin + * + * Instead of manually wiring ai_predecoder_service -> PyMatching, it creates + *the trt_decoder plugin from an ONNX model and asks it to run PyMatching as the + * global decoder: + * + * input detectors -> TRT predecoder -> [pre_L, residual syndromes] + * -> PyMatching(H, O, priors) -> final logical frame + * + * Usage: + * test_trt_decoder_composite [d7|d13|d13_r104|d21|d21_r42|d31] + * --data-dir DIR [--max-samples=N] [--onnx-path=FILE] + * [--engine-save-path=FILE] [--batch-size=N] [--warmup=N] + * [--no-cuda-graph] [--no-raw-diagnostics] + ******************************************************************************/ + +#include "predecoder_pipeline_common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using hrclock = std::chrono::high_resolution_clock; + +struct DemoConfig { + std::string data_dir; + std::string onnx_path; + std::string engine_save_path; + int max_samples = 0; + int warmup_count = 20; + size_t batch_size = 1; + bool use_cuda_graph = true; + bool raw_diagnostics = true; +}; + +bool starts_with(const std::string &s, const std::string &prefix) { + return s.rfind(prefix, 0) == 0; +} + +std::string value_after_equals(const std::string &arg, + const std::string &prefix) { + return arg.substr(prefix.size()); +} + +bool parse_bool(const std::string &v) { + if (v == "1" || v == "true" || v == "TRUE" || v == "on" || v == "ON" || + v == "yes" || v == "YES") + return true; + if (v == "0" || v == "false" || v == "FALSE" || v == "off" || v == "OFF" || + v == "no" || v == "NO") + return false; + throw std::runtime_error("Expected boolean value, got '" + v + "'"); +} + +bool file_exists(const std::string &path) { + std::ifstream f(path, std::ios::binary); + return f.good(); +} + +std::string replace_extension(const std::string &path, + const std::string &new_ext) { + auto slash = path.find_last_of('/'); + auto dot = path.find_last_of('.'); + if (dot == std::string::npos || (slash != std::string::npos && dot < slash)) + return path + new_ext; + return path.substr(0, dot) + new_ext; +} + +void print_usage(const char *argv0) { + std::cerr + << "Usage: " << argv0 + << " [d7|d13|d13_r104|d21|d21_r42|d31] --data-dir DIR [options]\n\n" + << "Options:\n" + << " --data-dir DIR Directory with " + "detectors/observables/H/O\n" + << " --max-samples=N Limit samples decoded (0 = all)\n" + << " --warmup=N Samples excluded from latency stats\n" + << " --onnx-path=FILE Override full ONNX path\n" + << " --engine-save-path=FILE Where the built TRT engine is saved\n" + << " --batch-size=N TRT dynamic batch profile size (default " + "1)\n" + << " --use-cuda-graph=0|1 Enable CUDA graph executor (default 1)\n" + << " --no-cuda-graph Shorthand for --use-cuda-graph=0\n" + << " --no-raw-diagnostics Skip extra TRT-only pass for predecoder " + "stats\n" + << "\nPipelineConfig overrides are also accepted:\n" + << " --distance=N --num-rounds=N --onnx-filename=FILE --label=NAME\n"; +} + +DemoConfig parse_demo_config(int argc, char *argv[]) { + DemoConfig cfg; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg == "--data-dir" && i + 1 < argc) { + cfg.data_dir = argv[++i]; + } else if (starts_with(arg, "--data-dir=")) { + cfg.data_dir = value_after_equals(arg, "--data-dir="); + } else if (starts_with(arg, "--max-samples=")) { + cfg.max_samples = std::stoi(value_after_equals(arg, "--max-samples=")); + } else if (arg == "--max-samples" && i + 1 < argc) { + cfg.max_samples = std::stoi(argv[++i]); + } else if (starts_with(arg, "--warmup=")) { + cfg.warmup_count = std::stoi(value_after_equals(arg, "--warmup=")); + } else if (arg == "--warmup" && i + 1 < argc) { + cfg.warmup_count = std::stoi(argv[++i]); + } else if (starts_with(arg, "--onnx-path=")) { + cfg.onnx_path = value_after_equals(arg, "--onnx-path="); + } else if (arg == "--onnx-path" && i + 1 < argc) { + cfg.onnx_path = argv[++i]; + } else if (starts_with(arg, "--engine-save-path=")) { + cfg.engine_save_path = value_after_equals(arg, "--engine-save-path="); + } else if (arg == "--engine-save-path" && i + 1 < argc) { + cfg.engine_save_path = argv[++i]; + } else if (starts_with(arg, "--batch-size=")) { + cfg.batch_size = static_cast( + std::stoul(value_after_equals(arg, "--batch-size="))); + } else if (arg == "--batch-size" && i + 1 < argc) { + cfg.batch_size = static_cast(std::stoul(argv[++i])); + } else if (starts_with(arg, "--use-cuda-graph=")) { + cfg.use_cuda_graph = + parse_bool(value_after_equals(arg, "--use-cuda-graph=")); + } else if (arg == "--no-cuda-graph") { + cfg.use_cuda_graph = false; + } else if (arg == "--no-raw-diagnostics") { + cfg.raw_diagnostics = false; + } + } + return cfg; +} + +std::vector sample_to_syndrome(const TestData &data, + int sample_idx) { + std::vector syndrome(data.num_detectors); + const int32_t *sample = data.sample(sample_idx); + for (uint32_t i = 0; i < data.num_detectors; ++i) + syndrome[i] = static_cast(sample[i] != 0 ? 1.0 : 0.0); + return syndrome; +} + +int count_input_nonzero(const TestData &data, int sample_idx) { + const int32_t *sample = data.sample(sample_idx); + int count = 0; + for (uint32_t i = 0; i < data.num_detectors; ++i) + count += (sample[i] != 0); + return count; +} + +int bit_from_float(cudaq::qec::float_t v) { return v >= 0.5 ? 1 : 0; } + +double percentile(const std::vector &sorted, double p) { + if (sorted.empty()) + return 0.0; + double idx = (p / 100.0) * (sorted.size() - 1); + size_t lo = static_cast(idx); + size_t hi = std::min(lo + 1, sorted.size() - 1); + double frac = idx - static_cast(lo); + return sorted[lo] * (1.0 - frac) + sorted[hi] * frac; +} + +struct CompositeStats { + int decoded = 0; + int converged = 0; + int missing = 0; + int first_obs_mismatches = 0; + int any_obs_mismatches = 0; + int ground_truth_ones = 0; + int result_ones = 0; + std::vector first_obs_pred; + std::vector latencies_us; + double wall_us = 0.0; +}; + +struct RawDiagnostics { + bool ran = false; + int decoded = 0; + int malformed = 0; + int predecoder_only_mismatches = 0; + int64_t total_input_nonzero = 0; + int64_t total_residual_nonzero = 0; + int64_t total_pre_l = 0; + int64_t total_pymatch_frame = 0; +}; + +CompositeStats run_composite_decoder(cudaq::qec::decoder &decoder, + const TestData &test_data, int n_samples, + size_t num_observables) { + CompositeStats stats; + stats.first_obs_pred.assign(n_samples, -1); + stats.latencies_us.reserve(n_samples); + + const size_t check_observables = + std::min(num_observables, static_cast(test_data.num_observables)); + + auto wall_start = std::chrono::steady_clock::now(); + for (int i = 0; i < n_samples; ++i) { + auto syndrome = sample_to_syndrome(test_data, i); + + auto t0 = hrclock::now(); + auto result = decoder.decode(syndrome); + auto t1 = hrclock::now(); + + stats.latencies_us.push_back( + std::chrono::duration_cast>( + t1 - t0) + .count()); + + if (result.result.size() < check_observables || check_observables == 0) { + stats.missing++; + continue; + } + + stats.decoded++; + if (result.converged) + stats.converged++; + + bool any_mismatch = false; + for (size_t obs = 0; obs < check_observables; ++obs) { + int pred = bit_from_float(result.result[obs]); + int truth = test_data.observable(i, static_cast(obs)); + if (obs == 0) { + stats.first_obs_pred[i] = pred; + stats.ground_truth_ones += truth != 0; + stats.result_ones += pred != 0; + if (pred != truth) + stats.first_obs_mismatches++; + } + any_mismatch |= (pred != truth); + } + if (any_mismatch) + stats.any_obs_mismatches++; + } + auto wall_end = std::chrono::steady_clock::now(); + stats.wall_us = + std::chrono::duration_cast>( + wall_end - wall_start) + .count(); + return stats; +} + +RawDiagnostics run_raw_diagnostics(cudaq::qec::decoder &raw_decoder, + const TestData &test_data, + const std::vector &final_pred, + int n_samples, size_t num_observables, + size_t residual_detectors) { + RawDiagnostics stats; + stats.ran = true; + + const size_t expected_output = num_observables + residual_detectors; + for (int i = 0; i < n_samples; ++i) { + auto syndrome = sample_to_syndrome(test_data, i); + auto raw = raw_decoder.decode(syndrome); + if (raw.result.size() < expected_output || num_observables == 0) { + stats.malformed++; + continue; + } + + stats.decoded++; + stats.total_input_nonzero += count_input_nonzero(test_data, i); + + int pre_l = bit_from_float(raw.result[0]); + int truth = test_data.observable(i, 0); + if (pre_l != truth) + stats.predecoder_only_mismatches++; + stats.total_pre_l += pre_l; + + if (i < static_cast(final_pred.size()) && final_pred[i] >= 0) + stats.total_pymatch_frame += (final_pred[i] ^ pre_l); + + for (size_t k = 0; k < residual_detectors; ++k) + stats.total_residual_nonzero += + bit_from_float(raw.result[num_observables + k]); + } + return stats; +} + +} // namespace + +int main(int argc, char *argv[]) { + std::string config_name = "d7"; + if (argc > 1 && std::string(argv[1]).substr(0, 2) != "--") + config_name = argv[1]; + + if (argc > 1 && + (std::string(argv[1]) == "--help" || std::string(argv[1]) == "-h")) { + print_usage(argv[0]); + return 0; + } + + auto config_opt = PipelineConfig::from_name(config_name); + if (!config_opt) { + print_usage(argv[0]); + return 1; + } + + PipelineConfig config = *config_opt; + config.apply_cli_overrides(argc, argv); + DemoConfig demo_cfg = parse_demo_config(argc, argv); + + if (demo_cfg.data_dir.empty()) { + std::cerr + << "ERROR: --data-dir is required for composite TRT decoder demo.\n"; + print_usage(argv[0]); + return 1; + } + if (demo_cfg.batch_size < 1) { + std::cerr << "ERROR: --batch-size must be >= 1.\n"; + return 1; + } + + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess || device_count == 0) { + std::cerr << "ERROR: no CUDA device available.\n"; + return 1; + } + + std::string onnx_path = + demo_cfg.onnx_path.empty() ? config.onnx_path() : demo_cfg.onnx_path; + std::string engine_save_path = demo_cfg.engine_save_path.empty() + ? replace_extension(onnx_path, ".engine") + : demo_cfg.engine_save_path; + + if (!file_exists(onnx_path)) { + std::cerr << "ERROR: ONNX file not found: " << onnx_path << "\n"; + return 1; + } + + TestData test_data = load_test_data(demo_cfg.data_dir); + if (!test_data.loaded()) { + std::cerr << "ERROR: failed to load detector/observable test data from " + << demo_cfg.data_dir << "\n"; + return 1; + } + + StimData stim = load_stim_data(demo_cfg.data_dir); + if (!stim.H.loaded()) { + std::cerr << "ERROR: H_csr.bin is required in " << demo_cfg.data_dir + << "\n"; + return 1; + } + if (!stim.O.loaded()) { + std::cerr << "ERROR: O_csr.bin is required in " << demo_cfg.data_dir + << "\n"; + return 1; + } + if (stim.O.nrows == 0) { + std::cerr << "ERROR: O_csr.bin contains zero observables.\n"; + return 1; + } + if (test_data.num_observables < stim.O.nrows) { + std::cerr << "ERROR: observables.bin has " << test_data.num_observables + << " observable column(s), but O_csr.bin has " << stim.O.nrows + << " row(s).\n"; + return 1; + } + + auto H = stim.H.to_dense(); + auto O = stim.O.to_dense(); + + cudaqx::heterogeneous_map pm_params; + pm_params.insert("merge_strategy", std::string("smallest_weight")); + pm_params.insert("O", O); + if (!stim.priors.empty()) { + if (stim.priors.size() != stim.H.ncols) { + std::cerr << "ERROR: priors.bin has " << stim.priors.size() + << " entries, but H has " << stim.H.ncols << " columns.\n"; + return 1; + } + pm_params.insert("error_rate_vec", stim.priors); + } + + cudaqx::heterogeneous_map trt_params; + trt_params.insert("onnx_load_path", onnx_path); + trt_params.insert("engine_save_path", engine_save_path); + trt_params.insert("batch_size", demo_cfg.batch_size); + trt_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); + trt_params.insert("global_decoder", std::string("pymatching")); + trt_params.insert("global_decoder_params", pm_params); + trt_params.insert("O", O); + + std::cout << "--- Initializing Composite TensorRT Decoder (" << config.label + << ") ---\n"; + std::cout << "[Setup] ONNX: " << onnx_path << "\n"; + std::cout << "[Setup] Engine save: " << engine_save_path << "\n"; + std::cout << "[Setup] Data dir: " << demo_cfg.data_dir << "\n"; + std::cout << "[Setup] H: " << stim.H.nrows << " x " << stim.H.ncols + << " (" << stim.H.nnz << " nnz)\n"; + std::cout << "[Setup] O: " << stim.O.nrows << " x " << stim.O.ncols + << " (" << stim.O.nnz << " nnz)\n"; + std::cout << "[Setup] Samples: " << test_data.num_samples + << ", detectors/sample=" << test_data.num_detectors + << ", observables/sample=" << test_data.num_observables << "\n"; + std::cout << "[Setup] PyMatching: merge_strategy=smallest_weight" + << (stim.priors.empty() ? ", no priors\n" : ", priors loaded\n"); + + std::unique_ptr composite_decoder; + try { + composite_decoder = cudaq::qec::decoder::get("trt_decoder", H, trt_params); + } catch (const std::exception &e) { + std::cerr << "ERROR: failed to create composite trt_decoder: " << e.what() + << "\n"; + return 1; + } + + const int available_samples = static_cast(test_data.num_samples); + const int n_samples = (demo_cfg.max_samples > 0) + ? std::min(demo_cfg.max_samples, available_samples) + : available_samples; + if (n_samples <= 0) { + std::cerr << "ERROR: no samples selected.\n"; + return 1; + } + + std::cout << "[Run] Decoding " << n_samples + << " sample(s) through composite TRT+PyMatching decoder...\n"; + + CompositeStats stats = run_composite_decoder(*composite_decoder, test_data, + n_samples, stim.O.nrows); + + RawDiagnostics raw_stats; + if (demo_cfg.raw_diagnostics) { + cudaqx::heterogeneous_map raw_params; + if (file_exists(engine_save_path)) { + raw_params.insert("engine_load_path", engine_save_path); + } else { + std::cerr << "[WARN] Engine file was not found after composite init; " + "raw diagnostics will rebuild from ONNX.\n"; + raw_params.insert("onnx_load_path", onnx_path); + raw_params.insert("engine_save_path", engine_save_path); + } + raw_params.insert("batch_size", demo_cfg.batch_size); + raw_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); + + try { + auto raw_decoder = cudaq::qec::decoder::get("trt_decoder", H, raw_params); + raw_stats = + run_raw_diagnostics(*raw_decoder, test_data, stats.first_obs_pred, + n_samples, stim.O.nrows, stim.H.nrows); + } catch (const std::exception &e) { + std::cerr << "[WARN] Raw TRT diagnostics skipped: " << e.what() << "\n"; + } + } + + int warmup = std::min(demo_cfg.warmup_count, + static_cast(stats.latencies_us.size())); + std::vector steady_latencies(stats.latencies_us.begin() + warmup, + stats.latencies_us.end()); + std::sort(steady_latencies.begin(), steady_latencies.end()); + double mean = 0.0; + for (double v : steady_latencies) + mean += v; + mean = steady_latencies.empty() ? 0.0 : mean / steady_latencies.size(); + double stddev = 0.0; + for (double v : steady_latencies) + stddev += (v - mean) * (v - mean); + stddev = steady_latencies.empty() + ? 0.0 + : std::sqrt(stddev / steady_latencies.size()); + double throughput = + stats.wall_us > 0.0 + ? (static_cast(stats.decoded) * 1e6 / stats.wall_us) + : 0.0; + double ler = stats.decoded > 0 + ? static_cast(stats.first_obs_mismatches) / + static_cast(stats.decoded) + : 0.0; + double any_obs_ler = stats.decoded > 0 + ? static_cast(stats.any_obs_mismatches) / + static_cast(stats.decoded) + : 0.0; + + std::cout << std::fixed; + std::cout + << "\n================================================================\n"; + std::cout << " Composite TRT Decoder Benchmark: " << config.label << "\n"; + std::cout + << "================================================================\n"; + std::cout << " Submitted: " << n_samples << "\n"; + std::cout << " Decoded: " << stats.decoded << "\n"; + std::cout << " Missing/malformed: " << stats.missing << "\n"; + std::cout << " Converged: " << stats.converged << "\n"; + std::cout << std::setprecision(1); + std::cout << " Wall time: " << stats.wall_us / 1000.0 << " ms\n"; + std::cout << " Throughput: " << throughput << " samples/s\n"; + std::cout + << " ---------------------------------------------------------------\n"; + std::cout << " Latency (us) [steady-state, " << steady_latencies.size() + << " samples after " << warmup << " warmup]\n"; + if (!steady_latencies.empty()) { + std::cout << " min = " << std::setw(10) << steady_latencies.front() + << "\n"; + std::cout << " p50 = " << std::setw(10) + << percentile(steady_latencies, 50) << "\n"; + std::cout << " mean = " << std::setw(10) << mean << "\n"; + std::cout << " p90 = " << std::setw(10) + << percentile(steady_latencies, 90) << "\n"; + std::cout << " p95 = " << std::setw(10) + << percentile(steady_latencies, 95) << "\n"; + std::cout << " p99 = " << std::setw(10) + << percentile(steady_latencies, 99) << "\n"; + std::cout << " max = " << std::setw(10) << steady_latencies.back() + << "\n"; + std::cout << " stddev = " << std::setw(10) << stddev << "\n"; + } + + std::cout + << " ---------------------------------------------------------------\n"; + std::cout << std::setprecision(4); + std::cout << " Correctness [observable 0]:\n"; + std::cout << " Composite mismatches: " << stats.first_obs_mismatches + << " LER: " << ler << "\n"; + std::cout << " Any-observable mismatches: " << stats.any_obs_mismatches + << " rate: " << any_obs_ler << "\n"; + std::cout << " Composite ones: " << stats.result_ones << "/" + << stats.decoded << "\n"; + std::cout << " Ground truth ones: " << stats.ground_truth_ones << "/" + << stats.decoded << "\n"; + + if (raw_stats.ran && raw_stats.decoded > 0) { + double pred_ler = + static_cast(raw_stats.predecoder_only_mismatches) / + static_cast(raw_stats.decoded); + double avg_input_nz = static_cast(raw_stats.total_input_nonzero) / + static_cast(raw_stats.decoded); + double avg_residual_nz = + static_cast(raw_stats.total_residual_nonzero) / + static_cast(raw_stats.decoded); + double input_density = avg_input_nz / test_data.num_detectors; + double residual_density = avg_residual_nz / stim.H.nrows; + double reduction = + input_density > 0.0 ? (1.0 - residual_density / input_density) : 0.0; + + std::cout + << " " + "---------------------------------------------------------------\n"; + std::cout << " Raw TRT diagnostics (" << raw_stats.decoded << " samples, " + << raw_stats.malformed << " malformed):\n"; + std::cout << " Predecoder-only mismatches: " + << raw_stats.predecoder_only_mismatches << " LER: " << pred_ler + << "\n"; + std::cout << std::setprecision(3); + std::cout << " Avg logical_pred: " + << static_cast(raw_stats.total_pre_l) / raw_stats.decoded + << "\n"; + std::cout << " Avg PyMatching frame flip: " + << static_cast(raw_stats.total_pymatch_frame) / + raw_stats.decoded + << "\n"; + std::cout << std::setprecision(1); + std::cout << " Input density: " << avg_input_nz << " / " + << test_data.num_detectors << " (" << std::setprecision(4) + << input_density << ")\n"; + std::cout << std::setprecision(1); + std::cout << " Residual density: " << avg_residual_nz << " / " + << stim.H.nrows << " (" << std::setprecision(4) + << residual_density << ")\n"; + std::cout << std::setprecision(1); + std::cout << " Reduction: " << reduction * 100.0 << "%\n"; + } + + std::cout + << "================================================================\n"; + std::cout << "Done.\n"; + return 0; +} From 0418e4545125545668a6afb93552cd7ea97bb4bb Mon Sep 17 00:00:00 2001 From: Scott Thornton Date: Fri, 8 May 2026 15:22:29 +0000 Subject: [PATCH 03/19] Expand TRT decoder YAML config for composite decoding Add YAML/config support for TRT decoder runtime options including batch size, CUDA graph execution, global decoder selection, and PyMatching-specific global decoder parameters. Wire realtime decoder construction so TRT configs receive the top-level observable matrix from O_sparse, and pass the same O matrix into PyMatching global decoder params for composite observable decoding. Expose the new config fields through Python bindings and heterogeneous_map round-tripping. Extend YAML tests for TRT config round-trip, runtime parameter conversion, and O_sparse-to-O injection. Update test_trt_decoder_composite to support an optional --config-yaml path, allowing the existing composite demo to construct and run a real TRT+PyMatching decoder directly from YAML while preserving the original manual CLI path. Signed-off-by: Scott Thornton --- .../cudaq/qec/realtime/decoding_config.h | 17 + libs/qec/lib/realtime/config.cpp | 57 +++ libs/qec/lib/realtime/realtime_decoding.cpp | 33 +- libs/qec/lib/realtime/realtime_decoding.h | 4 + .../python/bindings/py_decoding_config.cpp | 26 ++ libs/qec/python/bindings/type_casters.h | 4 + libs/qec/unittests/realtime/CMakeLists.txt | 1 + .../realtime/test_trt_decoder_composite.cpp | 370 ++++++++++++++---- libs/qec/unittests/test_decoders_yaml.cpp | 77 ++++ 9 files changed, 502 insertions(+), 87 deletions(-) diff --git a/libs/qec/include/cudaq/qec/realtime/decoding_config.h b/libs/qec/include/cudaq/qec/realtime/decoding_config.h index 3dcc1196..a08a598a 100644 --- a/libs/qec/include/cudaq/qec/realtime/decoding_config.h +++ b/libs/qec/include/cudaq/qec/realtime/decoding_config.h @@ -89,12 +89,29 @@ struct single_error_lut_config { from_heterogeneous_map(const cudaqx::heterogeneous_map &map); }; +struct pymatching_decoder_config { + std::optional merge_strategy; + std::optional> error_rate_vec; + + bool operator==(const pymatching_decoder_config &) const = default; + + __attribute__((visibility("default"))) cudaqx::heterogeneous_map + to_heterogeneous_map() const; + + __attribute__((visibility("default"))) static pymatching_decoder_config + from_heterogeneous_map(const cudaqx::heterogeneous_map &map); +}; + struct trt_decoder_config { std::optional onnx_load_path; std::optional engine_load_path; std::optional engine_save_path; std::optional precision; std::optional memory_workspace; + std::optional batch_size; + std::optional use_cuda_graph; + std::optional global_decoder; + std::optional global_decoder_params; bool operator==(const trt_decoder_config &) const = default; diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index 0dc43cb4..5b08d253 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -188,6 +188,23 @@ single_error_lut_config single_error_lut_config::from_heterogeneous_map( return config; } +// ------ pymatching_decoder_config ------ +cudaqx::heterogeneous_map +pymatching_decoder_config::to_heterogeneous_map() const { + cudaqx::heterogeneous_map config_map; + INSERT_ARG(merge_strategy); + INSERT_ARG(error_rate_vec); + return config_map; +} + +pymatching_decoder_config pymatching_decoder_config::from_heterogeneous_map( + const cudaqx::heterogeneous_map &map) { + pymatching_decoder_config config; + GET_ARG(merge_strategy); + GET_ARG(error_rate_vec); + return config; +} + // ------ trt_decoder_config ------ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { cudaqx::heterogeneous_map config_map; @@ -197,6 +214,18 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { INSERT_ARG(engine_save_path); INSERT_ARG(precision); INSERT_ARG(memory_workspace); + INSERT_ARG(batch_size); + INSERT_ARG(use_cuda_graph); + INSERT_ARG(global_decoder); + if (global_decoder_params.has_value()) { + config_map.insert("global_decoder_params", + global_decoder_params->to_heterogeneous_map()); + } else if (global_decoder.has_value()) { + // trt_decoder attaches a global decoder only when both the decoder name and + // parameter map are present. An empty map is valid for decoders that do not + // need extra parameters. + config_map.insert("global_decoder_params", cudaqx::heterogeneous_map()); + } return config_map; } @@ -209,6 +238,20 @@ trt_decoder_config trt_decoder_config::from_heterogeneous_map( GET_ARG(engine_save_path); GET_ARG(precision); GET_ARG(memory_workspace); + GET_ARG(batch_size); + GET_ARG(use_cuda_graph); + GET_ARG(global_decoder); + if (map.contains("global_decoder_params")) { + try { + config.global_decoder_params = + map.get("global_decoder_params"); + } catch (...) { + auto nested_map = + map.get("global_decoder_params"); + config.global_decoder_params = + pymatching_decoder_config::from_heterogeneous_map(nested_map); + } + } return config; } @@ -342,6 +385,16 @@ struct MappingTraits { cudaq::qec::decoding::config::single_error_lut_config &config) {} }; +template <> +struct MappingTraits { + static void + mapping(IO &io, + cudaq::qec::decoding::config::pymatching_decoder_config &config) { + io.mapOptional("merge_strategy", config.merge_strategy); + io.mapOptional("error_rate_vec", config.error_rate_vec); + } +}; + template <> struct MappingTraits { static void @@ -351,6 +404,10 @@ struct MappingTraits { io.mapOptional("engine_save_path", config.engine_save_path); io.mapOptional("precision", config.precision); io.mapOptional("memory_workspace", config.memory_workspace); + io.mapOptional("batch_size", config.batch_size); + io.mapOptional("use_cuda_graph", config.use_cuda_graph); + io.mapOptional("global_decoder", config.global_decoder); + io.mapOptional("global_decoder_params", config.global_decoder_params); } }; diff --git a/libs/qec/lib/realtime/realtime_decoding.cpp b/libs/qec/lib/realtime/realtime_decoding.cpp index b611af2c..0e627786 100644 --- a/libs/qec/lib/realtime/realtime_decoding.cpp +++ b/libs/qec/lib/realtime/realtime_decoding.cpp @@ -12,6 +12,7 @@ #include "cudaq/qec/pcm_utils.h" #include "cudaq/qec/realtime/decoding_config.h" #include "cudaq/runtime/logger/logger.h" +#include #include // Optional syndrome capture callback for --save_syndrome feature @@ -42,6 +43,35 @@ static std::vector pack_syndrome_bits(const uint8_t *syndromes, namespace cudaq::qec::decoding::host { +cudaqx::heterogeneous_map prepare_decoder_params( + const cudaq::qec::decoding::config::decoder_config &decoder_config) { + auto params = decoder_config.decoder_custom_args_to_heterogeneous_map(); + if (decoder_config.type != "trt_decoder" || decoder_config.O_sparse.empty()) + return params; + + const auto num_observables = std::count(decoder_config.O_sparse.begin(), + decoder_config.O_sparse.end(), -1); + if (num_observables == 0) + return params; + + auto O = cudaq::qec::pcm_from_sparse_vec( + decoder_config.O_sparse, num_observables, decoder_config.block_size); + params.insert("O", O); + + if (params.contains("global_decoder") && + params.get("global_decoder") == "pymatching") { + cudaqx::heterogeneous_map global_decoder_params; + if (params.contains("global_decoder_params")) { + global_decoder_params = + params.get("global_decoder_params"); + } + global_decoder_params.insert("O", O); + params.insert("global_decoder_params", global_decoder_params); + } + + return params; +} + int configure_decoders( cudaq::qec::decoding::config::multi_decoder_config &config) { CUDAQ_INFO("Initializing decoders..."); @@ -87,8 +117,7 @@ int configure_decoders( decoder_config.syndrome_size, decoder_config.block_size); auto new_decoder = cudaq::qec::get_decoder( - decoder_config.type, pcm, - decoder_config.decoder_custom_args_to_heterogeneous_map()); + decoder_config.type, pcm, prepare_decoder_params(decoder_config)); new_decoder->set_decoder_id(decoder_config.id); // Count the number of -1's in the O_sparse vector. That is the number of // rows (observables) in the observable matrix. diff --git a/libs/qec/lib/realtime/realtime_decoding.h b/libs/qec/lib/realtime/realtime_decoding.h index 38eff215..c41669e2 100644 --- a/libs/qec/lib/realtime/realtime_decoding.h +++ b/libs/qec/lib/realtime/realtime_decoding.h @@ -19,6 +19,10 @@ __attribute__((visibility("default"))) void enqueue_syndromes(std::size_t decoder_id, uint8_t *syndromes, std::uint64_t syndrome_length, std::uint64_t tag); +__attribute__((visibility("default"))) cudaqx::heterogeneous_map +prepare_decoder_params( + const cudaq::qec::decoding::config::decoder_config &decoder_config); + __attribute__((visibility("default"))) void get_corrections(std::size_t decoder_id, uint8_t *corrections, std::uint64_t correction_length, bool reset); diff --git a/libs/qec/python/bindings/py_decoding_config.cpp b/libs/qec/python/bindings/py_decoding_config.cpp index 162656e7..0cb23196 100644 --- a/libs/qec/python/bindings/py_decoding_config.cpp +++ b/libs/qec/python/bindings/py_decoding_config.cpp @@ -108,6 +108,27 @@ void bindDecodingConfig(nb::module_ &mod) { &multi_error_lut_config::from_heterogeneous_map, nb::arg("map")); + // pymatching_decoder_config + nb::class_( + mod_cfg, "pymatching_decoder_config", "PyMatching decoder configuration.") + .def(nb::init<>()) + .def( + "__init__", + [](config::pymatching_decoder_config &self, + const cudaqx::heterogeneous_map &map) { + new (&self) pymatching_decoder_config( + pymatching_decoder_config::from_heterogeneous_map(map)); + }, + nb::arg("map")) + .def_rw("merge_strategy", &pymatching_decoder_config::merge_strategy) + .def_rw("error_rate_vec", &pymatching_decoder_config::error_rate_vec) + .def("to_heterogeneous_map", + &pymatching_decoder_config::to_heterogeneous_map, + nb::rv_policy::move) + .def_static("from_heterogeneous_map", + &pymatching_decoder_config::from_heterogeneous_map, + nb::arg("map")); + // trt_decoder_config nb::class_(mod_cfg, "trt_decoder_config", "TensorRT decoder configuration.") @@ -125,6 +146,11 @@ void bindDecodingConfig(nb::module_ &mod) { .def_rw("engine_save_path", &trt_decoder_config::engine_save_path) .def_rw("precision", &trt_decoder_config::precision) .def_rw("memory_workspace", &trt_decoder_config::memory_workspace) + .def_rw("batch_size", &trt_decoder_config::batch_size) + .def_rw("use_cuda_graph", &trt_decoder_config::use_cuda_graph) + .def_rw("global_decoder", &trt_decoder_config::global_decoder) + .def_rw("global_decoder_params", + &trt_decoder_config::global_decoder_params) .def("to_heterogeneous_map", &trt_decoder_config::to_heterogeneous_map, nb::rv_policy::move) .def_static("from_heterogeneous_map", diff --git a/libs/qec/python/bindings/type_casters.h b/libs/qec/python/bindings/type_casters.h index 658bdd7b..851d3d23 100644 --- a/libs/qec/python/bindings/type_casters.h +++ b/libs/qec/python/bindings/type_casters.h @@ -194,6 +194,10 @@ struct type_caster { cudaq::qec::decoding::config::single_error_lut_config>( &val)) { result[key.c_str()] = nb::cast(single_cfg->to_heterogeneous_map()); + } else if (auto *pm_cfg = std::any_cast< + cudaq::qec::decoding::config::pymatching_decoder_config>( + &val)) { + result[key.c_str()] = nb::cast(pm_cfg->to_heterogeneous_map()); } else if (auto *sw_cfg = std::any_cast< cudaq::qec::decoding::config::sliding_window_config>( &val)) { diff --git a/libs/qec/unittests/realtime/CMakeLists.txt b/libs/qec/unittests/realtime/CMakeLists.txt index 68972ecc..8230421f 100644 --- a/libs/qec/unittests/realtime/CMakeLists.txt +++ b/libs/qec/unittests/realtime/CMakeLists.txt @@ -119,6 +119,7 @@ if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY AND TENSORRT_ONNX_PARSER_LIBRARY target_link_libraries(test_trt_decoder_composite PRIVATE CUDA::cudart cudaq-qec + cudaq-qec-realtime-decoding cudaq-qec-trt-decoder cudaq::cudaq ) diff --git a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp index 79457b6f..56714f3a 100644 --- a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp +++ b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp @@ -25,9 +25,16 @@ * --data-dir DIR [--max-samples=N] [--onnx-path=FILE] * [--engine-save-path=FILE] [--batch-size=N] [--warmup=N] * [--no-cuda-graph] [--no-raw-diagnostics] + * + * test_trt_decoder_composite --data-dir DIR --config-yaml FILE + * [--decoder-id=N] [--max-samples=N] [--warmup=N] + * [--no-raw-diagnostics] ******************************************************************************/ #include "predecoder_pipeline_common.h" +#include "../../lib/realtime/realtime_decoding.h" +#include "cudaq/qec/pcm_utils.h" +#include "cudaq/qec/realtime/decoding_config.h" #include #include @@ -36,6 +43,8 @@ #include #include #include +#include +#include #include #include #include @@ -48,8 +57,10 @@ using hrclock = std::chrono::high_resolution_clock; struct DemoConfig { std::string data_dir; + std::string config_yaml_path; std::string onnx_path; std::string engine_save_path; + int64_t decoder_id = -1; int max_samples = 0; int warmup_count = 20; size_t batch_size = 1; @@ -99,6 +110,9 @@ void print_usage(const char *argv0) { "detectors/observables/H/O\n" << " --max-samples=N Limit samples decoded (0 = all)\n" << " --warmup=N Samples excluded from latency stats\n" + << " --config-yaml=FILE Build composite decoder from YAML " + "config\n" + << " --decoder-id=N Decoder ID to select from YAML config\n" << " --onnx-path=FILE Override full ONNX path\n" << " --engine-save-path=FILE Where the built TRT engine is saved\n" << " --batch-size=N TRT dynamic batch profile size (default " @@ -127,6 +141,14 @@ DemoConfig parse_demo_config(int argc, char *argv[]) { cfg.warmup_count = std::stoi(value_after_equals(arg, "--warmup=")); } else if (arg == "--warmup" && i + 1 < argc) { cfg.warmup_count = std::stoi(argv[++i]); + } else if (starts_with(arg, "--config-yaml=")) { + cfg.config_yaml_path = value_after_equals(arg, "--config-yaml="); + } else if (arg == "--config-yaml" && i + 1 < argc) { + cfg.config_yaml_path = argv[++i]; + } else if (starts_with(arg, "--decoder-id=")) { + cfg.decoder_id = std::stoll(value_after_equals(arg, "--decoder-id=")); + } else if (arg == "--decoder-id" && i + 1 < argc) { + cfg.decoder_id = std::stoll(argv[++i]); } else if (starts_with(arg, "--onnx-path=")) { cfg.onnx_path = value_after_equals(arg, "--onnx-path="); } else if (arg == "--onnx-path" && i + 1 < argc) { @@ -152,6 +174,59 @@ DemoConfig parse_demo_config(int argc, char *argv[]) { return cfg; } +std::string read_text_file(const std::string &path) { + std::ifstream in(path); + if (!in.good()) + throw std::runtime_error("Failed to open " + path); + std::ostringstream buffer; + buffer << in.rdbuf(); + return buffer.str(); +} + +size_t sparse_vec_nnz(const std::vector &sparse) { + return static_cast(std::count_if(sparse.begin(), sparse.end(), + [](int64_t v) { return v >= 0; })); +} + +size_t sparse_vec_rows(const std::vector &sparse) { + return static_cast(std::count(sparse.begin(), sparse.end(), -1)); +} + +template +void copy_param_if_present(const cudaqx::heterogeneous_map &src, + cudaqx::heterogeneous_map &dst, + const std::string &key) { + if (src.contains(key)) + dst.insert(key, src.get(key)); +} + +bool build_raw_trt_params(const cudaqx::heterogeneous_map &trt_params, + cudaqx::heterogeneous_map &raw_params) { + bool has_model_source = false; + if (trt_params.contains("engine_load_path")) { + raw_params.insert("engine_load_path", + trt_params.get("engine_load_path")); + has_model_source = true; + } else if (trt_params.contains("engine_save_path") && + file_exists(trt_params.get("engine_save_path"))) { + raw_params.insert("engine_load_path", + trt_params.get("engine_save_path")); + has_model_source = true; + } else if (trt_params.contains("onnx_load_path")) { + raw_params.insert("onnx_load_path", + trt_params.get("onnx_load_path")); + copy_param_if_present(trt_params, raw_params, + "engine_save_path"); + has_model_source = true; + } + + copy_param_if_present(trt_params, raw_params, "batch_size"); + copy_param_if_present(trt_params, raw_params, "use_cuda_graph"); + copy_param_if_present(trt_params, raw_params, "memory_workspace"); + copy_param_if_present(trt_params, raw_params, "precision"); + return has_model_source; +} + std::vector sample_to_syndrome(const TestData &data, int sample_idx) { std::vector syndrome(data.num_detectors); @@ -205,6 +280,23 @@ struct RawDiagnostics { int64_t total_pymatch_frame = 0; }; +struct DecoderSetup { + std::unique_ptr decoder; + cudaqx::tensor H; + cudaqx::heterogeneous_map trt_params; + std::string label; + std::string init_mode; + std::string config_yaml_path; + std::string onnx_path; + std::string engine_save_path; + size_t H_rows = 0; + size_t H_cols = 0; + size_t H_nnz = 0; + size_t O_rows = 0; + size_t O_cols = 0; + size_t O_nnz = 0; +}; + CompositeStats run_composite_decoder(cudaq::qec::decoder &decoder, const TestData &test_data, int n_samples, size_t num_observables) { @@ -261,6 +353,122 @@ CompositeStats run_composite_decoder(cudaq::qec::decoder &decoder, return stats; } +const cudaq::qec::decoding::config::decoder_config &select_yaml_decoder( + const cudaq::qec::decoding::config::multi_decoder_config &config, + int64_t decoder_id) { + if (decoder_id >= 0) { + auto it = std::find_if(config.decoders.begin(), config.decoders.end(), + [&](const auto &decoder_config) { + return decoder_config.id == decoder_id; + }); + if (it == config.decoders.end()) + throw std::runtime_error("Decoder ID " + std::to_string(decoder_id) + + " not found in YAML config."); + return *it; + } + + if (config.decoders.size() != 1) + throw std::runtime_error("YAML config contains " + + std::to_string(config.decoders.size()) + + " decoders; pass --decoder-id to select one."); + return config.decoders.front(); +} + +DecoderSetup create_decoder_from_yaml(const DemoConfig &demo_cfg) { + using cudaq::qec::decoding::config::multi_decoder_config; + + auto config = multi_decoder_config::from_yaml_str( + read_text_file(demo_cfg.config_yaml_path)); + const auto &decoder_config = select_yaml_decoder(config, demo_cfg.decoder_id); + if (decoder_config.type != "trt_decoder") { + throw std::runtime_error("YAML decoder type must be trt_decoder, got '" + + decoder_config.type + "'."); + } + + DecoderSetup setup; + setup.label = "yaml decoder " + std::to_string(decoder_config.id); + setup.init_mode = "YAML config"; + setup.config_yaml_path = demo_cfg.config_yaml_path; + setup.H_rows = static_cast(decoder_config.syndrome_size); + setup.H_cols = static_cast(decoder_config.block_size); + setup.H_nnz = sparse_vec_nnz(decoder_config.H_sparse); + setup.O_rows = sparse_vec_rows(decoder_config.O_sparse); + setup.O_cols = static_cast(decoder_config.block_size); + setup.O_nnz = sparse_vec_nnz(decoder_config.O_sparse); + + setup.H = cudaq::qec::pcm_from_sparse_vec(decoder_config.H_sparse, + decoder_config.syndrome_size, + decoder_config.block_size); + setup.trt_params = + cudaq::qec::decoding::host::prepare_decoder_params(decoder_config); + if (setup.trt_params.contains("onnx_load_path")) + setup.onnx_path = setup.trt_params.get("onnx_load_path"); + if (setup.trt_params.contains("engine_save_path")) + setup.engine_save_path = + setup.trt_params.get("engine_save_path"); + if (setup.trt_params.contains("engine_load_path") && + setup.engine_save_path.empty()) + setup.engine_save_path = + setup.trt_params.get("engine_load_path"); + + setup.decoder = + cudaq::qec::decoder::get(decoder_config.type, setup.H, setup.trt_params); + return setup; +} + +DecoderSetup create_decoder_from_cli(const PipelineConfig &config, + const DemoConfig &demo_cfg, + const StimData &stim) { + std::string onnx_path = + demo_cfg.onnx_path.empty() ? config.onnx_path() : demo_cfg.onnx_path; + std::string engine_save_path = demo_cfg.engine_save_path.empty() + ? replace_extension(onnx_path, ".engine") + : demo_cfg.engine_save_path; + + if (!file_exists(onnx_path)) + throw std::runtime_error("ONNX file not found: " + onnx_path); + + auto H = stim.H.to_dense(); + auto O = stim.O.to_dense(); + + cudaqx::heterogeneous_map pm_params; + pm_params.insert("merge_strategy", std::string("smallest_weight")); + pm_params.insert("O", O); + if (!stim.priors.empty()) { + if (stim.priors.size() != stim.H.ncols) { + throw std::runtime_error( + "priors.bin has " + std::to_string(stim.priors.size()) + + " entries, but H has " + std::to_string(stim.H.ncols) + " columns."); + } + pm_params.insert("error_rate_vec", stim.priors); + } + + DecoderSetup setup; + setup.H = H; + setup.label = config.label; + setup.init_mode = "manual CLI args"; + setup.onnx_path = onnx_path; + setup.engine_save_path = engine_save_path; + setup.H_rows = stim.H.nrows; + setup.H_cols = stim.H.ncols; + setup.H_nnz = stim.H.nnz; + setup.O_rows = stim.O.nrows; + setup.O_cols = stim.O.ncols; + setup.O_nnz = stim.O.nnz; + + setup.trt_params.insert("onnx_load_path", onnx_path); + setup.trt_params.insert("engine_save_path", engine_save_path); + setup.trt_params.insert("batch_size", demo_cfg.batch_size); + setup.trt_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); + setup.trt_params.insert("global_decoder", std::string("pymatching")); + setup.trt_params.insert("global_decoder_params", pm_params); + setup.trt_params.insert("O", O); + + setup.decoder = + cudaq::qec::decoder::get("trt_decoder", setup.H, setup.trt_params); + return setup; +} + RawDiagnostics run_raw_diagnostics(cudaq::qec::decoder &raw_decoder, const TestData &test_data, const std::vector &final_pred, @@ -337,17 +545,6 @@ int main(int argc, char *argv[]) { return 1; } - std::string onnx_path = - demo_cfg.onnx_path.empty() ? config.onnx_path() : demo_cfg.onnx_path; - std::string engine_save_path = demo_cfg.engine_save_path.empty() - ? replace_extension(onnx_path, ".engine") - : demo_cfg.engine_save_path; - - if (!file_exists(onnx_path)) { - std::cerr << "ERROR: ONNX file not found: " << onnx_path << "\n"; - return 1; - } - TestData test_data = load_test_data(demo_cfg.data_dir); if (!test_data.loaded()) { std::cerr << "ERROR: failed to load detector/observable test data from " @@ -355,74 +552,76 @@ int main(int argc, char *argv[]) { return 1; } - StimData stim = load_stim_data(demo_cfg.data_dir); - if (!stim.H.loaded()) { - std::cerr << "ERROR: H_csr.bin is required in " << demo_cfg.data_dir + const bool use_yaml_config = !demo_cfg.config_yaml_path.empty(); + std::optional stim; + if (!use_yaml_config) { + stim = load_stim_data(demo_cfg.data_dir); + if (!stim->H.loaded()) { + std::cerr << "ERROR: H_csr.bin is required in " << demo_cfg.data_dir + << "\n"; + return 1; + } + if (!stim->O.loaded()) { + std::cerr << "ERROR: O_csr.bin is required in " << demo_cfg.data_dir + << "\n"; + return 1; + } + if (stim->O.nrows == 0) { + std::cerr << "ERROR: O_csr.bin contains zero observables.\n"; + return 1; + } + } + + DecoderSetup setup; + try { + std::cout << "--- Initializing Composite TensorRT Decoder (" + << (use_yaml_config ? demo_cfg.config_yaml_path : config.label) + << ") ---\n"; + setup = use_yaml_config ? create_decoder_from_yaml(demo_cfg) + : create_decoder_from_cli(config, demo_cfg, *stim); + } catch (const std::exception &e) { + std::cerr << "ERROR: failed to create composite trt_decoder: " << e.what() << "\n"; return 1; } - if (!stim.O.loaded()) { - std::cerr << "ERROR: O_csr.bin is required in " << demo_cfg.data_dir - << "\n"; + + if (setup.O_rows == 0) { + std::cerr << "ERROR: observable matrix contains zero observables.\n"; return 1; } - if (stim.O.nrows == 0) { - std::cerr << "ERROR: O_csr.bin contains zero observables.\n"; + if (test_data.num_detectors != setup.H_rows) { + std::cerr << "ERROR: detectors.bin has " << test_data.num_detectors + << " detectors, but decoder H has " << setup.H_rows << " rows.\n"; return 1; } - if (test_data.num_observables < stim.O.nrows) { + if (test_data.num_observables < setup.O_rows) { std::cerr << "ERROR: observables.bin has " << test_data.num_observables - << " observable column(s), but O_csr.bin has " << stim.O.nrows + << " observable column(s), but decoder O has " << setup.O_rows << " row(s).\n"; return 1; } - auto H = stim.H.to_dense(); - auto O = stim.O.to_dense(); - - cudaqx::heterogeneous_map pm_params; - pm_params.insert("merge_strategy", std::string("smallest_weight")); - pm_params.insert("O", O); - if (!stim.priors.empty()) { - if (stim.priors.size() != stim.H.ncols) { - std::cerr << "ERROR: priors.bin has " << stim.priors.size() - << " entries, but H has " << stim.H.ncols << " columns.\n"; - return 1; - } - pm_params.insert("error_rate_vec", stim.priors); - } - - cudaqx::heterogeneous_map trt_params; - trt_params.insert("onnx_load_path", onnx_path); - trt_params.insert("engine_save_path", engine_save_path); - trt_params.insert("batch_size", demo_cfg.batch_size); - trt_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); - trt_params.insert("global_decoder", std::string("pymatching")); - trt_params.insert("global_decoder_params", pm_params); - trt_params.insert("O", O); - - std::cout << "--- Initializing Composite TensorRT Decoder (" << config.label - << ") ---\n"; - std::cout << "[Setup] ONNX: " << onnx_path << "\n"; - std::cout << "[Setup] Engine save: " << engine_save_path << "\n"; + std::cout << "[Setup] Init mode: " << setup.init_mode << "\n"; + if (!setup.config_yaml_path.empty()) + std::cout << "[Setup] YAML: " << setup.config_yaml_path << "\n"; + if (!setup.onnx_path.empty()) + std::cout << "[Setup] ONNX: " << setup.onnx_path << "\n"; + if (!setup.engine_save_path.empty()) + std::cout << "[Setup] Engine: " << setup.engine_save_path << "\n"; std::cout << "[Setup] Data dir: " << demo_cfg.data_dir << "\n"; - std::cout << "[Setup] H: " << stim.H.nrows << " x " << stim.H.ncols - << " (" << stim.H.nnz << " nnz)\n"; - std::cout << "[Setup] O: " << stim.O.nrows << " x " << stim.O.ncols - << " (" << stim.O.nnz << " nnz)\n"; + std::cout << "[Setup] H: " << setup.H_rows << " x " << setup.H_cols + << " (" << setup.H_nnz << " nnz)\n"; + std::cout << "[Setup] O: " << setup.O_rows << " x " << setup.O_cols + << " (" << setup.O_nnz << " nnz)\n"; std::cout << "[Setup] Samples: " << test_data.num_samples << ", detectors/sample=" << test_data.num_detectors << ", observables/sample=" << test_data.num_observables << "\n"; - std::cout << "[Setup] PyMatching: merge_strategy=smallest_weight" - << (stim.priors.empty() ? ", no priors\n" : ", priors loaded\n"); - - std::unique_ptr composite_decoder; - try { - composite_decoder = cudaq::qec::decoder::get("trt_decoder", H, trt_params); - } catch (const std::exception &e) { - std::cerr << "ERROR: failed to create composite trt_decoder: " << e.what() - << "\n"; - return 1; + std::cout << "[Setup] PyMatching: "; + if (use_yaml_config) { + std::cout << "from YAML global_decoder_params\n"; + } else { + std::cout << "merge_strategy=smallest_weight" + << (stim->priors.empty() ? ", no priors\n" : ", priors loaded\n"); } const int available_samples = static_cast(test_data.num_samples); @@ -437,30 +636,31 @@ int main(int argc, char *argv[]) { std::cout << "[Run] Decoding " << n_samples << " sample(s) through composite TRT+PyMatching decoder...\n"; - CompositeStats stats = run_composite_decoder(*composite_decoder, test_data, - n_samples, stim.O.nrows); + CompositeStats stats = + run_composite_decoder(*setup.decoder, test_data, n_samples, setup.O_rows); RawDiagnostics raw_stats; if (demo_cfg.raw_diagnostics) { cudaqx::heterogeneous_map raw_params; - if (file_exists(engine_save_path)) { - raw_params.insert("engine_load_path", engine_save_path); + if (!build_raw_trt_params(setup.trt_params, raw_params)) { + std::cerr << "[WARN] Raw TRT diagnostics skipped: no raw TRT model " + "source is available.\n"; } else { - std::cerr << "[WARN] Engine file was not found after composite init; " - "raw diagnostics will rebuild from ONNX.\n"; - raw_params.insert("onnx_load_path", onnx_path); - raw_params.insert("engine_save_path", engine_save_path); - } - raw_params.insert("batch_size", demo_cfg.batch_size); - raw_params.insert("use_cuda_graph", demo_cfg.use_cuda_graph); - - try { - auto raw_decoder = cudaq::qec::decoder::get("trt_decoder", H, raw_params); - raw_stats = - run_raw_diagnostics(*raw_decoder, test_data, stats.first_obs_pred, - n_samples, stim.O.nrows, stim.H.nrows); - } catch (const std::exception &e) { - std::cerr << "[WARN] Raw TRT diagnostics skipped: " << e.what() << "\n"; + if (setup.trt_params.contains("engine_save_path") && + !file_exists(setup.trt_params.get("engine_save_path"))) { + std::cerr << "[WARN] Engine file was not found after composite init; " + "raw diagnostics will rebuild from ONNX.\n"; + } + + try { + auto raw_decoder = + cudaq::qec::decoder::get("trt_decoder", setup.H, raw_params); + raw_stats = + run_raw_diagnostics(*raw_decoder, test_data, stats.first_obs_pred, + n_samples, setup.O_rows, setup.H_rows); + } catch (const std::exception &e) { + std::cerr << "[WARN] Raw TRT diagnostics skipped: " << e.what() << "\n"; + } } } @@ -495,7 +695,7 @@ int main(int argc, char *argv[]) { std::cout << std::fixed; std::cout << "\n================================================================\n"; - std::cout << " Composite TRT Decoder Benchmark: " << config.label << "\n"; + std::cout << " Composite TRT Decoder Benchmark: " << setup.label << "\n"; std::cout << "================================================================\n"; std::cout << " Submitted: " << n_samples << "\n"; @@ -549,7 +749,7 @@ int main(int argc, char *argv[]) { static_cast(raw_stats.total_residual_nonzero) / static_cast(raw_stats.decoded); double input_density = avg_input_nz / test_data.num_detectors; - double residual_density = avg_residual_nz / stim.H.nrows; + double residual_density = avg_residual_nz / setup.H_rows; double reduction = input_density > 0.0 ? (1.0 - residual_density / input_density) : 0.0; @@ -575,7 +775,7 @@ int main(int argc, char *argv[]) { << input_density << ")\n"; std::cout << std::setprecision(1); std::cout << " Residual density: " << avg_residual_nz << " / " - << stim.H.nrows << " (" << std::setprecision(4) + << setup.H_rows << " (" << std::setprecision(4) << residual_density << ")\n"; std::cout << std::setprecision(1); std::cout << " Reduction: " << reduction * 100.0 << "%\n"; diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index aa7a37be..c1a56796 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -6,6 +6,7 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ +#include "../lib/realtime/realtime_decoding.h" #include "cudaq/qec/decoder.h" #include "cudaq/qec/pcm_utils.h" #include "cudaq/qec/realtime/decoding_config.h" @@ -188,6 +189,82 @@ TEST(DecoderYAMLTest, SingleLUTDecoder) { test_decoder_creation(multi_config); } +cudaq::qec::decoding::config::decoder_config +create_test_decoder_config_trt(int id) { + cudaq::qec::decoding::config::decoder_config config = + create_test_empty_decoder_config(id); + config.type = "trt_decoder"; + + cudaqx::tensor O({2, config.block_size}); + O.at({0, 1}) = 1; + O.at({1, 3}) = 1; + config.O_sparse = cudaq::qec::pcm_to_sparse_vec(O); + + config.decoder_custom_args = + cudaq::qec::decoding::config::trt_decoder_config(); + auto &trt_config = std::get( + config.decoder_custom_args); + trt_config.onnx_load_path = "/tmp/predecoder.onnx"; + trt_config.engine_save_path = "/tmp/predecoder.engine"; + trt_config.precision = "best"; + trt_config.memory_workspace = 1ULL << 20; + trt_config.batch_size = 4; + trt_config.use_cuda_graph = false; + trt_config.global_decoder = "pymatching"; + trt_config.global_decoder_params = + cudaq::qec::decoding::config::pymatching_decoder_config(); + trt_config.global_decoder_params->merge_strategy = "smallest_weight"; + trt_config.global_decoder_params->error_rate_vec = + std::vector(config.block_size, 0.1); + + return config; +} + +TEST(DecoderYAMLTest, TrtDecoderConfigRoundTrip) { + cudaq::qec::decoding::config::multi_decoder_config multi_config; + multi_config.decoders.push_back(create_test_decoder_config_trt(0)); + + test_decoder_yaml_roundtrip(multi_config); +} + +TEST(DecoderYAMLTest, TrtDecoderConfigToHeterogeneousMap) { + auto config = create_test_decoder_config_trt(0); + auto params = config.decoder_custom_args_to_heterogeneous_map(); + + EXPECT_EQ(params.get("onnx_load_path"), "/tmp/predecoder.onnx"); + EXPECT_EQ(params.get("engine_save_path"), + "/tmp/predecoder.engine"); + EXPECT_EQ(params.get("precision"), "best"); + EXPECT_EQ(params.get("memory_workspace"), 1ULL << 20); + EXPECT_EQ(params.get("batch_size"), 4u); + EXPECT_FALSE(params.get("use_cuda_graph")); + EXPECT_EQ(params.get("global_decoder"), "pymatching"); + + auto global_params = + params.get("global_decoder_params"); + EXPECT_EQ(global_params.get("merge_strategy"), + "smallest_weight"); + EXPECT_EQ(global_params.get>("error_rate_vec").size(), + config.block_size); +} + +TEST(DecoderYAMLTest, TrtDecoderRealtimeParamsIncludeObservableMatrix) { + auto config = create_test_decoder_config_trt(0); + auto params = cudaq::qec::decoding::host::prepare_decoder_params(config); + + auto O = params.get>("O"); + EXPECT_EQ(O.shape()[0], 2u); + EXPECT_EQ(O.shape()[1], config.block_size); + EXPECT_EQ(O.at({0, 1}), 1); + EXPECT_EQ(O.at({1, 3}), 1); + + auto global_params = + params.get("global_decoder_params"); + auto global_O = global_params.get>("O"); + EXPECT_EQ(global_O.shape()[0], 2u); + EXPECT_EQ(global_O.shape()[1], config.block_size); +} + TEST(DecoderYAMLTest, SlidingWindowDecoder) { std::size_t n_rounds = 4; std::size_t n_errs_per_round = 30; From 7d091ce8810ff323c31af672acb7d4bca1cd664f Mon Sep 17 00:00:00 2001 From: Scott Thornton Date: Tue, 12 May 2026 00:45:16 +0000 Subject: [PATCH 04/19] Refactor TRT global decoder params to variant config Replace the TRT decoder's hardcoded optional PyMatching global decoder params with a tagged global_decoder_config variant. Preserve PyMatching as the current supported concrete config while using std::monostate for the unset case. Update heterogeneous-map conversion, YAML mapping, and Python bindings so the existing PyMatching YAML/Python surface continues to round-trip. Extend the YAML unit test to verify the PyMatching variant arm is selected and still produces the expected runtime parameter map. Signed-off-by: Scott Thornton --- .../cudaq/qec/realtime/decoding_config.h | 5 +- libs/qec/lib/realtime/config.cpp | 84 +++++++++++++++++-- .../python/bindings/py_decoding_config.cpp | 21 ++++- libs/qec/python/bindings/type_casters.h | 12 +++ libs/qec/unittests/test_decoders_yaml.cpp | 13 ++- 5 files changed, 121 insertions(+), 14 deletions(-) diff --git a/libs/qec/include/cudaq/qec/realtime/decoding_config.h b/libs/qec/include/cudaq/qec/realtime/decoding_config.h index a08a598a..1b0df079 100644 --- a/libs/qec/include/cudaq/qec/realtime/decoding_config.h +++ b/libs/qec/include/cudaq/qec/realtime/decoding_config.h @@ -102,6 +102,9 @@ struct pymatching_decoder_config { from_heterogeneous_map(const cudaqx::heterogeneous_map &map); }; +using global_decoder_config = + std::variant; + struct trt_decoder_config { std::optional onnx_load_path; std::optional engine_load_path; @@ -111,7 +114,7 @@ struct trt_decoder_config { std::optional batch_size; std::optional use_cuda_graph; std::optional global_decoder; - std::optional global_decoder_params; + global_decoder_config global_decoder_params; bool operator==(const trt_decoder_config &) const = default; diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index 5b08d253..81f90d1e 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -205,6 +205,32 @@ pymatching_decoder_config pymatching_decoder_config::from_heterogeneous_map( return config; } +cudaqx::heterogeneous_map global_decoder_config_to_heterogeneous_map( + const global_decoder_config &global_decoder_params) { + if (std::holds_alternative(global_decoder_params)) { + return cudaqx::heterogeneous_map(); + } + + if (std::holds_alternative( + global_decoder_params)) { + return std::get(global_decoder_params) + .to_heterogeneous_map(); + } + + throw std::runtime_error("Unsupported global decoder parameters."); +} + +global_decoder_config global_decoder_config_from_heterogeneous_map( + const cudaqx::heterogeneous_map &map, + const std::optional &global_decoder) { + if (global_decoder.has_value() && global_decoder.value() != "pymatching") { + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + } + + return pymatching_decoder_config::from_heterogeneous_map(map); +} + // ------ trt_decoder_config ------ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { cudaqx::heterogeneous_map config_map; @@ -217,9 +243,14 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { INSERT_ARG(batch_size); INSERT_ARG(use_cuda_graph); INSERT_ARG(global_decoder); - if (global_decoder_params.has_value()) { - config_map.insert("global_decoder_params", - global_decoder_params->to_heterogeneous_map()); + if (!std::holds_alternative(global_decoder_params)) { + if (global_decoder.has_value() && global_decoder.value() != "pymatching") { + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + } + config_map.insert( + "global_decoder_params", + global_decoder_config_to_heterogeneous_map(global_decoder_params)); } else if (global_decoder.has_value()) { // trt_decoder attaches a global decoder only when both the decoder name and // parameter map are present. An empty map is valid for decoders that do not @@ -244,12 +275,18 @@ trt_decoder_config trt_decoder_config::from_heterogeneous_map( if (map.contains("global_decoder_params")) { try { config.global_decoder_params = - map.get("global_decoder_params"); + map.get("global_decoder_params"); } catch (...) { - auto nested_map = - map.get("global_decoder_params"); - config.global_decoder_params = - pymatching_decoder_config::from_heterogeneous_map(nested_map); + try { + config.global_decoder_params = + map.get("global_decoder_params"); + } catch (...) { + auto nested_map = + map.get("global_decoder_params"); + config.global_decoder_params = + global_decoder_config_from_heterogeneous_map(nested_map, + config.global_decoder); + } } } @@ -395,6 +432,30 @@ struct MappingTraits { } }; +template <> +struct MappingTraits { + static void + mapping(IO &io, cudaq::qec::decoding::config::global_decoder_config &config) { + using namespace cudaq::qec::decoding::config; + + if (io.outputting()) { + if (std::holds_alternative(config)) { + return; + } + + auto &pymatching_config = std::get(config); + io.mapOptional("merge_strategy", pymatching_config.merge_strategy); + io.mapOptional("error_rate_vec", pymatching_config.error_rate_vec); + return; + } + + pymatching_decoder_config pymatching_config; + io.mapOptional("merge_strategy", pymatching_config.merge_strategy); + io.mapOptional("error_rate_vec", pymatching_config.error_rate_vec); + config = std::move(pymatching_config); + } +}; + template <> struct MappingTraits { static void @@ -408,6 +469,13 @@ struct MappingTraits { io.mapOptional("use_cuda_graph", config.use_cuda_graph); io.mapOptional("global_decoder", config.global_decoder); io.mapOptional("global_decoder_params", config.global_decoder_params); + + if (!std::holds_alternative(config.global_decoder_params) && + config.global_decoder.has_value() && + config.global_decoder.value() != "pymatching") { + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); + } } }; diff --git a/libs/qec/python/bindings/py_decoding_config.cpp b/libs/qec/python/bindings/py_decoding_config.cpp index 0cb23196..1baea0d2 100644 --- a/libs/qec/python/bindings/py_decoding_config.cpp +++ b/libs/qec/python/bindings/py_decoding_config.cpp @@ -149,8 +149,25 @@ void bindDecodingConfig(nb::module_ &mod) { .def_rw("batch_size", &trt_decoder_config::batch_size) .def_rw("use_cuda_graph", &trt_decoder_config::use_cuda_graph) .def_rw("global_decoder", &trt_decoder_config::global_decoder) - .def_rw("global_decoder_params", - &trt_decoder_config::global_decoder_params) + .def_prop_rw( + "global_decoder_params", + [](const trt_decoder_config &self) + -> std::optional { + if (std::holds_alternative( + self.global_decoder_params)) { + return std::get( + self.global_decoder_params); + } + return std::nullopt; + }, + [](trt_decoder_config &self, + std::optional value) { + if (value.has_value()) { + self.global_decoder_params = value.value(); + } else { + self.global_decoder_params = std::monostate(); + } + }) .def("to_heterogeneous_map", &trt_decoder_config::to_heterogeneous_map, nb::rv_policy::move) .def_static("from_heterogeneous_map", diff --git a/libs/qec/python/bindings/type_casters.h b/libs/qec/python/bindings/type_casters.h index 851d3d23..e3c5f80f 100644 --- a/libs/qec/python/bindings/type_casters.h +++ b/libs/qec/python/bindings/type_casters.h @@ -198,6 +198,18 @@ struct type_caster { cudaq::qec::decoding::config::pymatching_decoder_config>( &val)) { result[key.c_str()] = nb::cast(pm_cfg->to_heterogeneous_map()); + } else if (auto *global_cfg = std::any_cast< + cudaq::qec::decoding::config::global_decoder_config>( + &val)) { + if (std::holds_alternative(*global_cfg)) { + result[key.c_str()] = nb::none(); + } else { + result[key.c_str()] = nb::cast( + std::get< + cudaq::qec::decoding::config::pymatching_decoder_config>( + *global_cfg) + .to_heterogeneous_map()); + } } else if (auto *sw_cfg = std::any_cast< cudaq::qec::decoding::config::sliding_window_config>( &val)) { diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index c1a56796..e803d932 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -211,11 +211,12 @@ create_test_decoder_config_trt(int id) { trt_config.batch_size = 4; trt_config.use_cuda_graph = false; trt_config.global_decoder = "pymatching"; - trt_config.global_decoder_params = + auto pymatching_params = cudaq::qec::decoding::config::pymatching_decoder_config(); - trt_config.global_decoder_params->merge_strategy = "smallest_weight"; - trt_config.global_decoder_params->error_rate_vec = + pymatching_params.merge_strategy = "smallest_weight"; + pymatching_params.error_rate_vec = std::vector(config.block_size, 0.1); + trt_config.global_decoder_params = pymatching_params; return config; } @@ -225,6 +226,12 @@ TEST(DecoderYAMLTest, TrtDecoderConfigRoundTrip) { multi_config.decoders.push_back(create_test_decoder_config_trt(0)); test_decoder_yaml_roundtrip(multi_config); + const auto &trt_config = + std::get( + multi_config.decoders[0].decoder_custom_args); + EXPECT_TRUE(std::holds_alternative< + cudaq::qec::decoding::config::pymatching_decoder_config>( + trt_config.global_decoder_params)); } TEST(DecoderYAMLTest, TrtDecoderConfigToHeterogeneousMap) { From 26be6b48c450861c5a5631f7f9ec59a3c56425bc Mon Sep 17 00:00:00 2001 From: Scott Thornton Date: Fri, 29 May 2026 17:14:23 +0000 Subject: [PATCH 05/19] Restore optional None setter helper Signed-off-by: Scott Thornton --- libs/qec/python/bindings/py_decoding_config.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/qec/python/bindings/py_decoding_config.cpp b/libs/qec/python/bindings/py_decoding_config.cpp index 3c75bc3c..db7a2734 100644 --- a/libs/qec/python/bindings/py_decoding_config.cpp +++ b/libs/qec/python/bindings/py_decoding_config.cpp @@ -29,6 +29,9 @@ void bindDecodingConfig(nb::module_ &mod) { auto mod_cfg = qecmod.def_submodule("config", "Realtime decoding configuration"); + // Allow Python None to clear std::optional fields. + const auto setter_accepts_none = nb::for_setter(nb::arg("value").none()); + // srelay_bp_config nb::class_(mod_cfg, "srelay_bp_config", "Relay-BP decoder configuration.") From 1dff1f1df628654d03c9cddf4dc7d93902b101ab Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Tue, 9 Jun 2026 11:46:30 -0700 Subject: [PATCH 06/19] Add method to set decoding result type Signed-off-by: Melody Ren --- libs/qec/include/cudaq/qec/decoder.h | 30 +++++++ libs/qec/lib/decoder.cpp | 61 ++++++++++---- .../qec/unittests/decoders/sample_decoder.cpp | 9 +- libs/qec/unittests/test_decoders.cpp | 82 ++++++++++++++++++- 4 files changed, 162 insertions(+), 20 deletions(-) diff --git a/libs/qec/include/cudaq/qec/decoder.h b/libs/qec/include/cudaq/qec/decoder.h index 1a47b9b0..0a2f4c0c 100644 --- a/libs/qec/include/cudaq/qec/decoder.h +++ b/libs/qec/include/cudaq/qec/decoder.h @@ -144,6 +144,22 @@ class decoder std::unique_ptr pimpl; public: + /// @brief Indicates whether decode() returns a full error frame (length + /// block_size) or an already-projected observable frame (length + /// num_observables). Decoders that accept an "O" observable matrix in their + /// constructor params should call set_result_type(decode_to_obs); all others + /// default to decode_to_errs. + /// + /// Note: even in decode_to_obs mode, set_O_sparse() must still be called so + /// that enqueue_syndrome() knows num_observables and can size the corrections + /// buffer correctly. + enum decode_result_type { + decode_to_errs, ///< result.size() == block_size; enqueue_syndrome projects + ///< via O_sparse + decode_to_obs, ///< result.size() == num_observables; enqueue_syndrome uses + ///< result directly; set_O_sparse() still required + }; + decoder() = delete; /// @brief Constructor @@ -234,6 +250,12 @@ class decoder // Note: all of the current realtime decoding API is designed to be used with // hard syndromes. + /// @brief Returns the type of result produced by decode(). + /// Defaults to decode_to_errs. Decoders that project to observables + /// internally (i.e., constructed with an "O" param) should call + /// set_result_type(decode_to_obs) in their constructor. + decode_result_type get_result_type() const { return result_type_; } + /// @brief Get the number of measurement syndromes per decode call. This /// depends on D_sparse, so you must have called set_D_sparse() first. uint32_t get_num_msyn_per_decode() const; @@ -315,6 +337,11 @@ class decoder virtual std::string get_version() const; protected: + /// @brief Sets the result type. Call in the constructor when an "O" + /// observable matrix is detected in the decoder params. Must be called + /// before the first enqueue_syndrome(). + void set_result_type(decode_result_type type) { result_type_ = type; } + /// @brief For a classical `[n,k]` code, this is `n`. std::size_t block_size = 0; @@ -329,6 +356,9 @@ class decoder /// @brief The decoder's D matrix in sparse format std::vector> D_sparse; + +private: + decode_result_type result_type_ = decode_result_type::decode_to_errs; }; /// @brief Convert a single soft probability to a hard 0/1 decision. diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 71fa4afb..56bac8f3 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -362,36 +362,61 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, return false; } } + // Process the results. + // TODO - should this interrogate the decoded_result.converged flag? + const auto result_type = get_result_type(); + const auto num_observables = get_num_observables(); + const std::size_t expected_result_size = + result_type == decode_result_type::decode_to_obs ? num_observables + : block_size; if ((!pimpl->is_sliding_window && - decoded_result.result.size() != block_size) || + decoded_result.result.size() != expected_result_size) || (pimpl->is_sliding_window && !decoded_result.result.empty() && - decoded_result.result.size() != block_size)) { - throw std::runtime_error( - fmt::format("Decoder result size ({}) does not match block_size ({})", - decoded_result.result.size(), block_size)); + decoded_result.result.size() != expected_result_size)) { + throw std::runtime_error(fmt::format( + "Decoder result size ({}) does not match expected size ({}) for " + "result type {}", + decoded_result.result.size(), expected_result_size, + result_type == decode_result_type::decode_to_obs ? "decode_to_obs" + : "decode_to_errs")); } if (should_log) { log_t2 = std::chrono::high_resolution_clock::now(); - for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) - if (decoded_result.result[e]) - log_errors.push_back(e); + // TODO: log_errors is meaningful only for decode_to_errs; revisit on + // the logging pass. + if (result_type != decode_result_type::decode_to_obs) { + for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) + if (decoded_result.result[e]) + log_errors.push_back(e); + } } - // Process the results. - // TODO - should this interrogate the decoded_result.converged flag? - auto num_observables = O_sparse.size(); - // For each observable - for (std::size_t i = 0; i < num_observables; i++) { - // For each error that flips this observable - for (auto col : O_sparse[i]) { - // If the decoder predicted that this error occurred - if (decoded_result.result[col]) { - // Flip the correction for this observable + if (result_type == decode_result_type::decode_to_obs) { + // Observable-frame path: decoder already projected to observables via its + // internal "O" matrix; use the result directly. + for (std::size_t i = 0; i < num_observables; i++) { + if (decoded_result.result[i]) { pimpl->corrections[i] ^= 1; if (should_log) log_observable_corrections[i] ^= 1; } } + } else { + // Error-frame path: decoder returns block-sized error vector; project to + // observables via O_sparse. + // For each observable + for (std::size_t i = 0; i < num_observables; i++) { + // For each error that flips this observable + for (auto col : O_sparse[i]) { + // If the decoder predicted that this error occurred + if (decoded_result.result[col]) { + // Flip the correction for this observable + pimpl->corrections[i] ^= 1; + if (should_log) + log_observable_corrections[i] ^= 1; + } + } + } } if (should_log) { log_t3 = std::chrono::high_resolution_clock::now(); diff --git a/libs/qec/unittests/decoders/sample_decoder.cpp b/libs/qec/unittests/decoders/sample_decoder.cpp index da953411..0d357b5d 100644 --- a/libs/qec/unittests/decoders/sample_decoder.cpp +++ b/libs/qec/unittests/decoders/sample_decoder.cpp @@ -16,18 +16,25 @@ namespace cudaq::qec { /// @brief This is a sample (dummy) decoder that demonstrates how to build a /// bare bones custom decoder based on the `cudaq::qec::decoder` interface. class sample_decoder : public decoder { +private: + bool decode_to_obs = false; + public: sample_decoder(const cudaq::qec::sparse_binary_matrix &H, const cudaqx::heterogeneous_map ¶ms) : decoder(H) { // Decoder-specific constructor arguments can be placed in `params`. + decode_to_obs = params.get("decode_to_obs", decode_to_obs); + if (decode_to_obs) + set_result_type(decode_result_type::decode_to_obs); } virtual decoder_result decode(const std::vector &syndrome) { // This is a simple decoder that simply results decoder_result result; result.converged = true; - result.result = std::vector(block_size, 0.0f); + result.result = + decode_to_obs ? syndrome : std::vector(block_size, 0.0f); return result; } diff --git a/libs/qec/unittests/test_decoders.cpp b/libs/qec/unittests/test_decoders.cpp index 672c4e4b..2ef4cbaa 100644 --- a/libs/qec/unittests/test_decoders.cpp +++ b/libs/qec/unittests/test_decoders.cpp @@ -165,7 +165,7 @@ TEST(DecoderPlugins, SingleErrorLutExample_DecodesSingletonColumnSyndromes) { constexpr std::size_t block_size = 3; constexpr std::size_t syndrome_size = 2; // | 1 1 0 | - // | 0 1 1 | — single-bit columns are weight-1 syndrome patterns. + // | 0 1 1 | - single-bit columns are weight-1 syndrome patterns. std::vector H_vec = {1, 1, 0, // row 0 0, 1, 1}; cudaqx::tensor H; @@ -878,3 +878,83 @@ error(0.05) D0 D1 EXPECT_THROW(cudaq::qec::get_decoder("single_error_lut", dem_text, opts), std::runtime_error); } + +// --------------------------------------------------------------------------- +// Tests for enqueue_syndrome decode_result_type routing +// --------------------------------------------------------------------------- + +// Verify that enqueue_syndrome uses decode() output directly as corrections +// when get_result_type() == decode_to_obs, bypassing the O_sparse projection. +TEST(EnqueueSyndrome, ObsFrameDecoderUsesResultDirectly) { + // H: 2 syndrome measurements, 4 physical errors + cudaqx::tensor H_tensor({2, 4}); + H_tensor.at({0, 0}) = 1; + H_tensor.at({1, 1}) = 1; + cudaqx::heterogeneous_map params; + params.insert("decode_to_obs", true); + auto dec = cudaq::qec::decoder::get("sample_decoder", H_tensor, params); + + // D_sparse maps the two enqueued syndrome bits directly to two detector bits. + dec->set_D_sparse(std::vector>{{0}, {1}}); + // Two observables; cols 0/1 are within block_size=4 for validation only. + dec->set_O_sparse(std::vector>{{0}, {1}}); + + bool did_decode = dec->enqueue_syndrome(std::vector{1, 0}); + EXPECT_TRUE(did_decode); + + const uint8_t *corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 1u); + EXPECT_EQ(corr[1], 0u); +} + +// Verify that corrections XOR-accumulate correctly across multiple shots and +// that clear_corrections() resets them between shots. +TEST(EnqueueSyndrome, ObsFrameMultiShotAccumulation) { + cudaqx::tensor H_tensor({2, 4}); + H_tensor.at({0, 0}) = 1; + H_tensor.at({1, 1}) = 1; + cudaqx::heterogeneous_map params; + params.insert("decode_to_obs", true); + auto dec = cudaq::qec::decoder::get("sample_decoder", H_tensor, params); + + dec->set_D_sparse(std::vector>{{0}, {1}}); + dec->set_O_sparse(std::vector>{{0}, {1}}); + + // Shot 1: obs[0]=1, obs[1]=0 -> corrections become [1, 0] + EXPECT_TRUE(dec->enqueue_syndrome(std::vector{1, 0})); + const uint8_t *corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 1u); + EXPECT_EQ(corr[1], 0u); + + // Shot 2 (no reset): obs[0]=1, obs[1]=1 -> corrections XOR to [0, 1] + EXPECT_TRUE(dec->enqueue_syndrome(std::vector{1, 1})); + corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 0u); + EXPECT_EQ(corr[1], 1u); + + // After clear, corrections reset to [0, 0] + dec->clear_corrections(); + corr = dec->get_obs_corrections(); + EXPECT_EQ(corr[0], 0u); + EXPECT_EQ(corr[1], 0u); +} + +// Verify that a result size mismatch against num_observables throws for +// decode_to_obs decoders. +TEST(EnqueueSyndrome, ObsFrameSizeMismatchThrows) { + cudaqx::tensor H_tensor({3, 4}); + H_tensor.at({0, 0}) = 1; + H_tensor.at({1, 1}) = 1; + H_tensor.at({2, 2}) = 1; + cudaqx::heterogeneous_map params; + params.insert("decode_to_obs", true); + auto dec = cudaq::qec::decoder::get("sample_decoder", H_tensor, params); + + dec->set_D_sparse(std::vector>{{0}, {1}, {2}}); + dec->set_O_sparse( + std::vector>{{0}, {1}}); // 2 observables + + // sample_decoder returns all three detector bits in decode_to_obs mode. + EXPECT_THROW(dec->enqueue_syndrome(std::vector{1, 0, 1}), + std::runtime_error); +} From c9b2c08ac0d4448d6dc5c3a2667afabf3ce1718a Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 09:31:05 -0700 Subject: [PATCH 07/19] Mark composite trt_decoder as decode_to_obs The trt_decoder constructed with an "O" observable matrix projects to observables internally, so it must report decode_result_type::decode_to_obs to enqueue_syndrome(). Set the result type where decode_to_observables_ is enabled, and assert it in the composite test. Signed-off-by: Melody Ren --- libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp | 1 + libs/qec/unittests/realtime/test_trt_decoder_composite.cpp | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp index 614fcd98..a3162993 100644 --- a/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp +++ b/libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp @@ -782,6 +782,7 @@ trt_decoder::trt_decoder(const cudaq::qec::sparse_binary_matrix &H, } decode_to_observables_ = true; num_observables_ = O.shape()[0]; + set_result_type(decode_result_type::decode_to_obs); // The TRT model output must encode [pre_L (num_observables_ entries), // residual_dets (rest)]. Validate sizing where we can. diff --git a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp index 56714f3a..5a94035a 100644 --- a/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp +++ b/libs/qec/unittests/realtime/test_trt_decoder_composite.cpp @@ -600,6 +600,12 @@ int main(int argc, char *argv[]) { << " row(s).\n"; return 1; } + if (setup.decoder->get_result_type() != + cudaq::qec::decoder::decode_result_type::decode_to_obs) { + std::cerr << "ERROR: composite trt_decoder must report decode_to_obs " + "when constructed with O.\n"; + return 1; + } std::cout << "[Setup] Init mode: " << setup.init_mode << "\n"; if (!setup.config_yaml_path.empty()) From 1ff493238ae10b7d0444fe11d4cb5f20c5185796 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 12:10:48 -0700 Subject: [PATCH 08/19] enable logging for obs path in enqueue_syndrome Signed-off-by: Melody Ren --- libs/qec/lib/decoder.cpp | 14 ++- libs/qec/utils/replay_decoder_logs.py | 140 ++++++++++++++++++-------- 2 files changed, 106 insertions(+), 48 deletions(-) diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 56bac8f3..ac7d40e1 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -292,6 +292,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, std::vector log_msyn; std::vector log_detectors; std::vector log_errors; + std::vector log_observables; std::vector log_observable_corrections; // The four time points are used to measure the duration of each of 3 steps. std::chrono::time_point log_t0, log_t1, @@ -305,6 +306,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, if (should_log) { log_t0 = std::chrono::high_resolution_clock::now(); log_errors.reserve(syndrome_length); + log_observables.reserve(O_sparse.size()); log_observable_corrections.resize(O_sparse.size()); } @@ -365,6 +367,8 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, // Process the results. // TODO - should this interrogate the decoded_result.converged flag? const auto result_type = get_result_type(); + const auto *result_type_str = + result_type == decode_result_type::decode_to_obs ? "obs" : "errs"; const auto num_observables = get_num_observables(); const std::size_t expected_result_size = result_type == decode_result_type::decode_to_obs ? num_observables @@ -383,8 +387,6 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, if (should_log) { log_t2 = std::chrono::high_resolution_clock::now(); - // TODO: log_errors is meaningful only for decode_to_errs; revisit on - // the logging pass. if (result_type != decode_result_type::decode_to_obs) { for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) if (decoded_result.result[e]) @@ -396,6 +398,8 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, // internal "O" matrix; use the result directly. for (std::size_t i = 0; i < num_observables; i++) { if (decoded_result.result[i]) { + if (should_log) + log_observables.push_back(i); pimpl->corrections[i] ^= 1; if (should_log) log_observable_corrections[i] ^= 1; @@ -426,13 +430,15 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, pimpl->log_counter++; auto s = fmt::format( "[DecoderStats][{}] Counter:{} DecoderId:{} InputMsyn:{} " - "InputDetectors:{} Converged:{} Errors:{} " + "InputDetectors:{} Converged:{} ResultType:{} Errors:{} " + "Observables:{} " "ObservableCorrectionsThisCall:{} ObservableCorrectionsTotal:{} " "Dur1:{:.1f}us Dur2:{:.1f}us Dur3:{:.1f}us", static_cast(this), pimpl->log_counter, pimpl->decoder_id, fmt::join(log_msyn, ","), fmt::join(log_detectors, ","), decoded_result.converged ? 1 : 0, - fmt::join(log_errors, ","), + result_type_str, fmt::join(log_errors, ","), + fmt::join(log_observables, ","), fmt::join(log_observable_corrections, ","), fmt::join(std::vector(pimpl->corrections.begin(), pimpl->corrections.end()), diff --git a/libs/qec/utils/replay_decoder_logs.py b/libs/qec/utils/replay_decoder_logs.py index 7a7d934d..9096c80e 100644 --- a/libs/qec/utils/replay_decoder_logs.py +++ b/libs/qec/utils/replay_decoder_logs.py @@ -39,7 +39,8 @@ def sparse_to_dense(sparse_list, num_rows, num_cols, dtype=numpy.uint8): # decoder is created, a dummy decode call is made to "warm up" the decoder, so # you may see more decode calls than shots. def parse_decoder_log(decoder_log_file, log_detectors_sparse, log_errors_sparse, - log_observables_dense, decoder_id_list): + log_observables_sparse, log_observables_dense, + log_result_types, decoder_id_list): # running id of the last decoder seen (needed since the decoder id is not # included in the 1 very verbose decode log message). last_decoder_id = -1 @@ -55,32 +56,48 @@ def parse_decoder_log(decoder_log_file, log_detectors_sparse, log_errors_sparse, line = line.split("[DecoderStats]")[1] # print(line) if "InputDetectors:" in line: # this is a decode call - line = line.split(" ") - for elem in line: + fields = {} + for elem in line.split(" "): if ":" in elem: - key, value = elem.split(":") - # print(key, value) - if key == "InputDetectors": - if value == "": - log_detectors_sparse.append([]) - else: - log_detectors_sparse.append( - [int(x) for x in value.split(",")]) - if last_decoder_id == -1: - print( - f"Error: last_decoder_id is -1. This is a fatal error processing the log file." - ) - exit(1) - decoder_id_list.append(last_decoder_id) - elif key == "Errors": - if value == "": - log_errors_sparse.append([]) - else: - log_errors_sparse.append( - [int(x) for x in value.split(",")]) - elif key == "ObservableCorrectionsThisCall": - log_observables_dense.append( - [int(x) for x in value.split(",")]) + key, value = elem.split(":", 1) + fields[key] = value + + value = fields["InputDetectors"] + if value == "": + log_detectors_sparse.append([]) + else: + log_detectors_sparse.append( + [int(x) for x in value.split(",")]) + if last_decoder_id == -1: + print( + f"Error: last_decoder_id is -1. This is a fatal error processing the log file." + ) + exit(1) + decoder_id_list.append(last_decoder_id) + + value = fields.get("Errors", "") + if value == "": + log_errors_sparse.append([]) + else: + log_errors_sparse.append( + [int(x) for x in value.split(",")]) + + value = fields.get("Observables", "") + if value == "": + log_observables_sparse.append([]) + else: + log_observables_sparse.append( + [int(x) for x in value.split(",")]) + + value = fields["ObservableCorrectionsThisCall"] + log_observables_dense.append( + [int(x) for x in value.split(",")]) + + value = fields.get("ResultType", "errs") + result_types = {x for x in value.split(",") if x} + if not result_types: + result_types = {"errs"} + log_result_types.append(result_types) # ---------------------------------------------------------------------------- # @@ -115,6 +132,9 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): decoder_custom_args[key] = float(value) elif type(value) == bool: decoder_custom_args[key] = bool(value) + # Replaying an obs-frame decoder reconstructs the full composite + # decoder here, so trt_decoder replay needs TensorRT, a GPU, and + # the referenced ONNX/engine artifacts. LUT replay is lighter. decoders.append( qec.get_decoder(decoder['type'], H, **decoder_custom_args)) print(f"Decoder {decoder_id} created.") @@ -149,14 +169,17 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): log_detectors_sparse = [] # Detection events seen in the log file. log_errors_sparse = [] # Errors seen in the log file. replay_errors_sparse = [] # Errors seen in the replay. +log_observables_sparse = [] # Observable results seen in the log file. log_observables_dense = [] # Observable flips seen in the log file. replay_observables_dense = [] # Observable flips calculated in the replay. +log_result_types = [] # ResultType tokens seen in each log decode call. decoders = [] O_per_decoder = [] parse_decoder_log(args.decoder_log, log_detectors_sparse, log_errors_sparse, - log_observables_dense, decoder_id_list) + log_observables_sparse, log_observables_dense, + log_result_types, decoder_id_list) parse_decoder_config(args.config, decoders, O_per_decoder) # Basic error checking @@ -171,6 +194,7 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): # Now loop through the syndromes and compare the results. decode_call_idx = 0 replay_error_mismatch = 0 +replay_observable_result_mismatch = 0 replay_observable_mismatch = 0 print(f'Processing {len(log_detectors_sparse)} decode calls.') for s, o in zip(log_detectors_sparse, log_observables_dense): @@ -181,28 +205,53 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): for idx in s: syndrome[idx] = 1 result = decoders[decoder_id_list[decode_call_idx]].decode(syndrome) - dec_err_sparse = [ + result_types = log_result_types[decode_call_idx] + + mismatch_flag = False + decoded_sparse = [ i for i in range(len(result.result)) if result.result[i] > 0.5 ] - replay_errors_sparse.append(dec_err_sparse) - mismatch_flag = False - if dec_err_sparse != log_errors_sparse[decode_call_idx]: - replay_error_mismatch += 1 - mismatch_flag = True - if args.verbose_on_mismatch: - print( - f"Replay mismatch in error in decode_call_idx {decode_call_idx}" - ) - print(f"Decoded errors : {dec_err_sparse}") - print(f"Expected errors: {log_errors_sparse[decode_call_idx]}") - dec_err_dense = numpy.array(result.result, dtype=numpy.uint8) - O_replay = ( - O_per_decoder[decoder_id_list[decode_call_idx]] @ dec_err_dense % - 2).astype(numpy.uint8) + if "errs" in result_types: + dec_err_sparse = decoded_sparse + replay_errors_sparse.append(dec_err_sparse) + if dec_err_sparse != log_errors_sparse[decode_call_idx]: + replay_error_mismatch += 1 + mismatch_flag = True + if args.verbose_on_mismatch: + print( + f"Replay mismatch in error in decode_call_idx {decode_call_idx}" + ) + print(f"Decoded errors : {dec_err_sparse}") + print(f"Expected errors: {log_errors_sparse[decode_call_idx]}") + dec_err_dense = numpy.array(result.result, dtype=numpy.uint8) + O_replay = ( + O_per_decoder[decoder_id_list[decode_call_idx]] @ dec_err_dense % + 2).astype(numpy.uint8) + decoded_observables_sparse = [ + i for i in range(len(O_replay)) if O_replay[i] + ] + else: + replay_errors_sparse.append([]) + O_replay = numpy.array([1 if x > 0.5 else 0 for x in result.result], + dtype=numpy.uint8) + decoded_observables_sparse = decoded_sparse + + if "obs" in result_types: + expected_observables_sparse = log_observables_sparse[decode_call_idx] + if decoded_observables_sparse != expected_observables_sparse: + replay_observable_result_mismatch += 1 + mismatch_flag = True + if args.verbose_on_mismatch: + print( + f"Replay mismatch in observable result in decode_call_idx {decode_call_idx}" + ) + print(f"Decoded observable result : {decoded_observables_sparse}") + print(f"Expected observable result: {expected_observables_sparse}") + replay_observables_dense.append(O_replay) O_log = numpy.array(log_observables_dense[decode_call_idx], dtype=numpy.uint8) - if (O_replay != O_log).any(): + if O_replay.shape != O_log.shape or (O_replay != O_log).any(): replay_observable_mismatch += 1 mismatch_flag = True if args.verbose_on_mismatch: @@ -220,6 +269,9 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): print() print(f"Number of error mismatches during replay: {replay_error_mismatch}") +print( + f"Number of observable result mismatches during replay: {replay_observable_result_mismatch}" +) print( f"Number of observable mismatches during replay: {replay_observable_mismatch}" ) From 626e40d7a66cde20de96bfd52d39041225b33535 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 12:43:15 -0700 Subject: [PATCH 09/19] Fix trt_decoder_config monostate round-trip A monostate global_decoder_params (no global decoder attached) was being mutated into a default pymatching_decoder_config across a serialize -> deserialize cycle, through two independent serialization layers: 1. heterogeneous_map: to_heterogeneous_map() emitted an empty global_decoder_params map whenever global_decoder was set but the params were monostate, which read back as a pymatching config. 2. YAML MappingTraits (the path used by to_yaml_str/from_yaml_str, and thus by save_dem/load_dem): mapOptional emitted an empty 'global_decoder_params: {}' for the monostate case, which read back into a default pymatching config. Both layers now emit nothing for monostate. Any runtime need for an empty params map is handled in prepare_decoder_params (realtime_decoding.cpp), not in serialization. The heterogeneous_map path also rejects a params map that carries global_decoder_params without a global_decoder. Add regression tests: monostate round-trips unchanged through both YAML and heterogeneous_map and emits no params key; params-without-decoder throws. Signed-off-by: Melody Ren --- libs/qec/lib/realtime/config.cpp | 24 +++++++++---- libs/qec/unittests/test_decoders_yaml.cpp | 41 +++++++++++++++++++++++ 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index 424f857c..c542957c 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -251,12 +251,12 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { config_map.insert( "global_decoder_params", global_decoder_config_to_heterogeneous_map(global_decoder_params)); - } else if (global_decoder.has_value()) { - // trt_decoder attaches a global decoder only when both the decoder name and - // parameter map are present. An empty map is valid for decoders that do not - // need extra parameters. - config_map.insert("global_decoder_params", cudaqx::heterogeneous_map()); } + // Note: when global_decoder_params is monostate we intentionally emit + // nothing, even if global_decoder is set. Inventing an empty params map here + // would round-trip back as a default pymatching_decoder_config, mutating the + // config. Any runtime need for an empty params map is handled in + // prepare_decoder_params (realtime_decoding.cpp), not in serialization. return config_map; } @@ -273,6 +273,9 @@ trt_decoder_config trt_decoder_config::from_heterogeneous_map( GET_ARG(use_cuda_graph); GET_ARG(global_decoder); if (map.contains("global_decoder_params")) { + if (!config.global_decoder.has_value()) + throw std::runtime_error( + "global_decoder_params present but global_decoder is not set."); try { config.global_decoder_params = map.get("global_decoder_params"); @@ -468,7 +471,16 @@ struct MappingTraits { io.mapOptional("batch_size", config.batch_size); io.mapOptional("use_cuda_graph", config.use_cuda_graph); io.mapOptional("global_decoder", config.global_decoder); - io.mapOptional("global_decoder_params", config.global_decoder_params); + // Emit global_decoder_params only when it actually holds params. Mapping it + // unconditionally on output writes an empty `global_decoder_params: {}` for + // the monostate case, which deserializes back into a default + // pymatching_decoder_config -- mutating a monostate config across a YAML + // round-trip. On input we always map it: an absent key leaves the variant + // at its monostate default (mapOptional skips the nested mapping), and a + // present key is parsed into pymatching_decoder_config. + if (!io.outputting() || + !std::holds_alternative(config.global_decoder_params)) + io.mapOptional("global_decoder_params", config.global_decoder_params); if (!std::holds_alternative(config.global_decoder_params) && config.global_decoder.has_value() && diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index e803d932..7f8d30a4 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -272,6 +272,47 @@ TEST(DecoderYAMLTest, TrtDecoderRealtimeParamsIncludeObservableMatrix) { EXPECT_EQ(global_O.shape()[1], config.block_size); } +// Regression test for the round-trip fix: a trt config whose global_decoder is +// set but whose global_decoder_params is left as the default (monostate) must +// round-trip without inventing a default pymatching_decoder_config. Previously +// to_heterogeneous_map emitted an empty params map that deserialized back into a +// pymatching_decoder_config, silently mutating the config. +TEST(DecoderYAMLTest, TrtDecoderMonostateParamsRoundTrip) { + auto config = create_test_decoder_config_trt(0); + auto &trt_config = std::get( + config.decoder_custom_args); + // Keep the decoder name, but drop the params back to monostate. + trt_config.global_decoder = "pymatching"; + trt_config.global_decoder_params = std::monostate{}; + + // Serialization must NOT emit a global_decoder_params entry for monostate. + auto params = config.decoder_custom_args_to_heterogeneous_map(); + EXPECT_FALSE(params.contains("global_decoder_params")); + + // Full YAML round-trip must preserve monostate (not become pymatching). + cudaq::qec::decoding::config::multi_decoder_config multi_config; + multi_config.decoders.push_back(config); + test_decoder_yaml_roundtrip(multi_config); + const auto &rt = std::get( + multi_config.decoders[0].decoder_custom_args); + EXPECT_TRUE(std::holds_alternative(rt.global_decoder_params)); +} + +// Regression test for the round-trip fix: a params map carrying +// global_decoder_params but no global_decoder is malformed and must be rejected +// rather than silently constructing a default pymatching config. +TEST(DecoderYAMLTest, TrtDecoderParamsWithoutDecoderThrows) { + cudaqx::heterogeneous_map map; + map.insert("onnx_load_path", std::string("/tmp/predecoder.onnx")); + cudaqx::heterogeneous_map gd_params; + gd_params.insert("merge_strategy", std::string("smallest_weight")); + map.insert("global_decoder_params", gd_params); + EXPECT_THROW( + cudaq::qec::decoding::config::trt_decoder_config::from_heterogeneous_map( + map), + std::runtime_error); +} + TEST(DecoderYAMLTest, SlidingWindowDecoder) { std::size_t n_rounds = 4; std::size_t n_errs_per_round = 30; From bf77915b6165b7826d978e59301db374cca3af44 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 12:54:40 -0700 Subject: [PATCH 10/19] format trt decoder config round-trip test Signed-off-by: Melody Ren --- libs/qec/unittests/test_decoders_yaml.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index 7f8d30a4..eabcbc0d 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -275,8 +275,8 @@ TEST(DecoderYAMLTest, TrtDecoderRealtimeParamsIncludeObservableMatrix) { // Regression test for the round-trip fix: a trt config whose global_decoder is // set but whose global_decoder_params is left as the default (monostate) must // round-trip without inventing a default pymatching_decoder_config. Previously -// to_heterogeneous_map emitted an empty params map that deserialized back into a -// pymatching_decoder_config, silently mutating the config. +// to_heterogeneous_map emitted an empty params map that deserialized back into +// a pymatching_decoder_config, silently mutating the config. TEST(DecoderYAMLTest, TrtDecoderMonostateParamsRoundTrip) { auto config = create_test_decoder_config_trt(0); auto &trt_config = std::get( From 7d3f09f84ecdcc0c00df88c6ca55d797c57e0644 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 14:20:39 -0700 Subject: [PATCH 11/19] Restore no-O global decoder and reject params-without-decoder on write Two follow-ups to the monostate round-trip fix: 1. prepare_decoder_params now synthesizes an empty global_decoder_params map for a pymatching global decoder before the O_sparse early return. Since serialization stopped emitting an empty params map for monostate, a global decoder configured to run on residual detectors without an O matrix was no longer attached by the plugin (which requires both global_decoder and global_decoder_params keys). This is a documented, valid configuration, so restore it in runtime prep where it belongs. 2. trt_decoder_config::to_heterogeneous_map now throws when global_decoder_params is set but global_decoder is not, matching the rejection already enforced by from_heterogeneous_map. Add regression tests for both. Signed-off-by: Melody Ren --- libs/qec/lib/realtime/config.cpp | 6 +++- libs/qec/lib/realtime/realtime_decoding.cpp | 26 +++++++++++------ libs/qec/unittests/test_decoders_yaml.cpp | 31 +++++++++++++++++++++ 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index c542957c..55f42356 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -244,7 +244,11 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { INSERT_ARG(use_cuda_graph); INSERT_ARG(global_decoder); if (!std::holds_alternative(global_decoder_params)) { - if (global_decoder.has_value() && global_decoder.value() != "pymatching") { + if (!global_decoder.has_value()) { + throw std::runtime_error( + "global_decoder_params present but global_decoder is not set."); + } + if (global_decoder.value() != "pymatching") { throw std::runtime_error( "global_decoder_params currently supports only pymatching."); } diff --git a/libs/qec/lib/realtime/realtime_decoding.cpp b/libs/qec/lib/realtime/realtime_decoding.cpp index 0e627786..2a33adbc 100644 --- a/libs/qec/lib/realtime/realtime_decoding.cpp +++ b/libs/qec/lib/realtime/realtime_decoding.cpp @@ -46,7 +46,21 @@ namespace cudaq::qec::decoding::host { cudaqx::heterogeneous_map prepare_decoder_params( const cudaq::qec::decoding::config::decoder_config &decoder_config) { auto params = decoder_config.decoder_custom_args_to_heterogeneous_map(); - if (decoder_config.type != "trt_decoder" || decoder_config.O_sparse.empty()) + if (decoder_config.type != "trt_decoder") + return params; + + // The trt_decoder plugin attaches a pymatching global decoder only when both + // "global_decoder" and "global_decoder_params" are present. Serialization no + // longer emits an empty params map for the monostate (no-params) case, so + // synthesize one here -- before the O_sparse early return -- so that a global + // decoder running on residual detectors without an O matrix still attaches. + const bool has_pymatching_global = + params.contains("global_decoder") && + params.get("global_decoder") == "pymatching"; + if (has_pymatching_global && !params.contains("global_decoder_params")) + params.insert("global_decoder_params", cudaqx::heterogeneous_map()); + + if (decoder_config.O_sparse.empty()) return params; const auto num_observables = std::count(decoder_config.O_sparse.begin(), @@ -58,13 +72,9 @@ cudaqx::heterogeneous_map prepare_decoder_params( decoder_config.O_sparse, num_observables, decoder_config.block_size); params.insert("O", O); - if (params.contains("global_decoder") && - params.get("global_decoder") == "pymatching") { - cudaqx::heterogeneous_map global_decoder_params; - if (params.contains("global_decoder_params")) { - global_decoder_params = - params.get("global_decoder_params"); - } + if (has_pymatching_global) { + auto global_decoder_params = + params.get("global_decoder_params"); global_decoder_params.insert("O", O); params.insert("global_decoder_params", global_decoder_params); } diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index eabcbc0d..93c112f1 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -313,6 +313,37 @@ TEST(DecoderYAMLTest, TrtDecoderParamsWithoutDecoderThrows) { std::runtime_error); } +// A pymatching global decoder with no params and no O matrix must still attach: +// prepare_decoder_params synthesizes an empty global_decoder_params map before +// the O early-return so the plugin sees both keys it requires. +TEST(DecoderYAMLTest, TrtDecoderGlobalDecoderWithoutObservables) { + auto config = create_test_decoder_config_trt(0); + auto &trt_config = std::get( + config.decoder_custom_args); + trt_config.global_decoder = "pymatching"; + trt_config.global_decoder_params = std::monostate{}; + config.O_sparse.clear(); // no observables + + auto params = cudaq::qec::decoding::host::prepare_decoder_params(config); + EXPECT_TRUE(params.contains("global_decoder")); + EXPECT_TRUE(params.contains("global_decoder_params")); + EXPECT_FALSE(params.contains("O")); +} + +// Serialization rejects a malformed in-memory config that carries +// global_decoder_params but no global_decoder, symmetric with +// from_heterogeneous_map. +TEST(DecoderYAMLTest, TrtDecoderToMapParamsWithoutDecoderThrows) { + cudaq::qec::decoding::config::trt_decoder_config trt_config; + trt_config.onnx_load_path = "/tmp/predecoder.onnx"; + auto pymatching_params = + cudaq::qec::decoding::config::pymatching_decoder_config(); + pymatching_params.merge_strategy = "smallest_weight"; + trt_config.global_decoder_params = pymatching_params; // non-monostate + // global_decoder intentionally left unset. + EXPECT_THROW(trt_config.to_heterogeneous_map(), std::runtime_error); +} + TEST(DecoderYAMLTest, SlidingWindowDecoder) { std::size_t n_rounds = 4; std::size_t n_errs_per_round = 30; From 1cf7479489573e506b71ccd097892f6e2c914ee7 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 14:24:37 -0700 Subject: [PATCH 12/19] document realtime global decoder config support Signed-off-by: Melody Ren --- libs/qec/include/cudaq/qec/realtime/decoding_config.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/qec/include/cudaq/qec/realtime/decoding_config.h b/libs/qec/include/cudaq/qec/realtime/decoding_config.h index 1b0df079..561981ca 100644 --- a/libs/qec/include/cudaq/qec/realtime/decoding_config.h +++ b/libs/qec/include/cudaq/qec/realtime/decoding_config.h @@ -102,6 +102,10 @@ struct pymatching_decoder_config { from_heterogeneous_map(const cudaqx::heterogeneous_map &map); }; +// The realtime trt_decoder config currently models only PyMatching as a global +// decoder. Other global decoder plugins may be constructed through lower-level +// APIs when their parameters are supplied directly, but they are not serialized +// by this config variant yet. using global_decoder_config = std::variant; From 489175cd722869c9a912028b052a399d3344ce73 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 14:53:07 -0700 Subject: [PATCH 13/19] handle decoder result types explicitly Signed-off-by: Melody Ren --- libs/qec/lib/decoder.cpp | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index ac7d40e1..9c3e23ba 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -367,12 +367,26 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, // Process the results. // TODO - should this interrogate the decoded_result.converged flag? const auto result_type = get_result_type(); - const auto *result_type_str = - result_type == decode_result_type::decode_to_obs ? "obs" : "errs"; const auto num_observables = get_num_observables(); - const std::size_t expected_result_size = - result_type == decode_result_type::decode_to_obs ? num_observables - : block_size; + const char *result_type_str = nullptr; + const char *result_type_name = nullptr; + std::size_t expected_result_size = 0; + switch (result_type) { + case decode_result_type::decode_to_errs: + result_type_str = "errs"; + result_type_name = "decode_to_errs"; + expected_result_size = block_size; + break; + case decode_result_type::decode_to_obs: + result_type_str = "obs"; + result_type_name = "decode_to_obs"; + expected_result_size = num_observables; + break; + } + if (!result_type_name) + throw std::runtime_error( + fmt::format("Unsupported decoder result type ({})", + static_cast(result_type))); if ((!pimpl->is_sliding_window && decoded_result.result.size() != expected_result_size) || (pimpl->is_sliding_window && !decoded_result.result.empty() && @@ -381,19 +395,23 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, "Decoder result size ({}) does not match expected size ({}) for " "result type {}", decoded_result.result.size(), expected_result_size, - result_type == decode_result_type::decode_to_obs ? "decode_to_obs" - : "decode_to_errs")); + result_type_name)); } if (should_log) { log_t2 = std::chrono::high_resolution_clock::now(); - if (result_type != decode_result_type::decode_to_obs) { + switch (result_type) { + case decode_result_type::decode_to_errs: for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) if (decoded_result.result[e]) log_errors.push_back(e); + break; + case decode_result_type::decode_to_obs: + break; } } - if (result_type == decode_result_type::decode_to_obs) { + switch (result_type) { + case decode_result_type::decode_to_obs: // Observable-frame path: decoder already projected to observables via its // internal "O" matrix; use the result directly. for (std::size_t i = 0; i < num_observables; i++) { @@ -405,7 +423,8 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, log_observable_corrections[i] ^= 1; } } - } else { + break; + case decode_result_type::decode_to_errs: // Error-frame path: decoder returns block-sized error vector; project to // observables via O_sparse. // For each observable @@ -421,6 +440,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, } } } + break; } if (should_log) { log_t3 = std::chrono::high_resolution_clock::now(); From 7c003a8a6bfdcdda47fdc084d26654d00c528d72 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 14:57:28 -0700 Subject: [PATCH 14/19] consolidate trt decoder yaml tests Signed-off-by: Melody Ren --- libs/qec/unittests/test_decoders_yaml.cpp | 50 +++++------------------ 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index 93c112f1..b1fd5152 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -272,35 +272,30 @@ TEST(DecoderYAMLTest, TrtDecoderRealtimeParamsIncludeObservableMatrix) { EXPECT_EQ(global_O.shape()[1], config.block_size); } -// Regression test for the round-trip fix: a trt config whose global_decoder is -// set but whose global_decoder_params is left as the default (monostate) must -// round-trip without inventing a default pymatching_decoder_config. Previously -// to_heterogeneous_map emitted an empty params map that deserialized back into -// a pymatching_decoder_config, silently mutating the config. -TEST(DecoderYAMLTest, TrtDecoderMonostateParamsRoundTrip) { +TEST(DecoderYAMLTest, TrtDecoderMonostateGlobalDecoderParams) { auto config = create_test_decoder_config_trt(0); auto &trt_config = std::get( config.decoder_custom_args); - // Keep the decoder name, but drop the params back to monostate. trt_config.global_decoder = "pymatching"; trt_config.global_decoder_params = std::monostate{}; - // Serialization must NOT emit a global_decoder_params entry for monostate. auto params = config.decoder_custom_args_to_heterogeneous_map(); EXPECT_FALSE(params.contains("global_decoder_params")); - // Full YAML round-trip must preserve monostate (not become pymatching). cudaq::qec::decoding::config::multi_decoder_config multi_config; multi_config.decoders.push_back(config); test_decoder_yaml_roundtrip(multi_config); - const auto &rt = std::get( - multi_config.decoders[0].decoder_custom_args); - EXPECT_TRUE(std::holds_alternative(rt.global_decoder_params)); + + params = cudaq::qec::decoding::host::prepare_decoder_params(config); + EXPECT_TRUE(params.contains("global_decoder_params")); + EXPECT_TRUE(params.contains("O")); + + config.O_sparse.clear(); + params = cudaq::qec::decoding::host::prepare_decoder_params(config); + EXPECT_TRUE(params.contains("global_decoder_params")); + EXPECT_FALSE(params.contains("O")); } -// Regression test for the round-trip fix: a params map carrying -// global_decoder_params but no global_decoder is malformed and must be rejected -// rather than silently constructing a default pymatching config. TEST(DecoderYAMLTest, TrtDecoderParamsWithoutDecoderThrows) { cudaqx::heterogeneous_map map; map.insert("onnx_load_path", std::string("/tmp/predecoder.onnx")); @@ -311,36 +306,13 @@ TEST(DecoderYAMLTest, TrtDecoderParamsWithoutDecoderThrows) { cudaq::qec::decoding::config::trt_decoder_config::from_heterogeneous_map( map), std::runtime_error); -} - -// A pymatching global decoder with no params and no O matrix must still attach: -// prepare_decoder_params synthesizes an empty global_decoder_params map before -// the O early-return so the plugin sees both keys it requires. -TEST(DecoderYAMLTest, TrtDecoderGlobalDecoderWithoutObservables) { - auto config = create_test_decoder_config_trt(0); - auto &trt_config = std::get( - config.decoder_custom_args); - trt_config.global_decoder = "pymatching"; - trt_config.global_decoder_params = std::monostate{}; - config.O_sparse.clear(); // no observables - auto params = cudaq::qec::decoding::host::prepare_decoder_params(config); - EXPECT_TRUE(params.contains("global_decoder")); - EXPECT_TRUE(params.contains("global_decoder_params")); - EXPECT_FALSE(params.contains("O")); -} - -// Serialization rejects a malformed in-memory config that carries -// global_decoder_params but no global_decoder, symmetric with -// from_heterogeneous_map. -TEST(DecoderYAMLTest, TrtDecoderToMapParamsWithoutDecoderThrows) { cudaq::qec::decoding::config::trt_decoder_config trt_config; trt_config.onnx_load_path = "/tmp/predecoder.onnx"; auto pymatching_params = cudaq::qec::decoding::config::pymatching_decoder_config(); pymatching_params.merge_strategy = "smallest_weight"; - trt_config.global_decoder_params = pymatching_params; // non-monostate - // global_decoder intentionally left unset. + trt_config.global_decoder_params = pymatching_params; EXPECT_THROW(trt_config.to_heterogeneous_map(), std::runtime_error); } From 62dd9258361e822b3e0dd93b39dcd1afae37f91a Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 15:25:08 -0700 Subject: [PATCH 15/19] clean up enqueue_syndrome result handling bloat Signed-off-by: Melody Ren --- libs/qec/lib/decoder.cpp | 55 ++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 31 deletions(-) diff --git a/libs/qec/lib/decoder.cpp b/libs/qec/lib/decoder.cpp index 9c3e23ba..6e95a776 100644 --- a/libs/qec/lib/decoder.cpp +++ b/libs/qec/lib/decoder.cpp @@ -398,48 +398,41 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome, result_type_name)); } - if (should_log) { + // Flip an observable correction and mirror it into the per-call log so the + // logged flips stay faithful to the applied corrections. + auto flip_correction = [&](std::size_t i) { + pimpl->corrections[i] ^= 1; + if (should_log) + log_observable_corrections[i] ^= 1; + }; + + if (should_log) log_t2 = std::chrono::high_resolution_clock::now(); - switch (result_type) { - case decode_result_type::decode_to_errs: - for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) - if (decoded_result.result[e]) - log_errors.push_back(e); - break; - case decode_result_type::decode_to_obs: - break; - } - } + switch (result_type) { case decode_result_type::decode_to_obs: // Observable-frame path: decoder already projected to observables via its // internal "O" matrix; use the result directly. - for (std::size_t i = 0; i < num_observables; i++) { + for (std::size_t i = 0; i < num_observables; i++) if (decoded_result.result[i]) { if (should_log) log_observables.push_back(i); - pimpl->corrections[i] ^= 1; - if (should_log) - log_observable_corrections[i] ^= 1; + flip_correction(i); } - } break; case decode_result_type::decode_to_errs: - // Error-frame path: decoder returns block-sized error vector; project to - // observables via O_sparse. - // For each observable - for (std::size_t i = 0; i < num_observables; i++) { - // For each error that flips this observable - for (auto col : O_sparse[i]) { - // If the decoder predicted that this error occurred - if (decoded_result.result[col]) { - // Flip the correction for this observable - pimpl->corrections[i] ^= 1; - if (should_log) - log_observable_corrections[i] ^= 1; - } - } - } + // Error-frame path: decoder returns a block-sized error vector; project + // to observables via O_sparse. + if (should_log) + for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++) + if (decoded_result.result[e]) + log_errors.push_back(e); + // For each observable, flip its correction once for each predicted error + // that flips it (net parity over O_sparse[i]). + for (std::size_t i = 0; i < num_observables; i++) + for (auto col : O_sparse[i]) + if (decoded_result.result[col]) + flip_correction(i); break; } if (should_log) { From f70f2fc6717e9b0da619ef8b6d36f6d033ff7bd1 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 15:56:30 -0700 Subject: [PATCH 16/19] Fix replay result type handling Signed-off-by: Melody Ren --- libs/qec/utils/replay_decoder_logs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/qec/utils/replay_decoder_logs.py b/libs/qec/utils/replay_decoder_logs.py index 9096c80e..97acc640 100644 --- a/libs/qec/utils/replay_decoder_logs.py +++ b/libs/qec/utils/replay_decoder_logs.py @@ -230,11 +230,17 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): decoded_observables_sparse = [ i for i in range(len(O_replay)) if O_replay[i] ] - else: + elif "obs" in result_types: replay_errors_sparse.append([]) O_replay = numpy.array([1 if x > 0.5 else 0 for x in result.result], dtype=numpy.uint8) decoded_observables_sparse = decoded_sparse + else: + print( + f"Error: unsupported ResultType set {sorted(result_types)} " + f"in decode_call_idx {decode_call_idx}." + ) + exit(1) if "obs" in result_types: expected_observables_sparse = log_observables_sparse[decode_call_idx] From 4f3e149e58cefb76ad1c31efa733744cdcf696e7 Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 16:09:55 -0700 Subject: [PATCH 17/19] Fix replay script formatting Signed-off-by: Melody Ren --- libs/qec/utils/replay_decoder_logs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/libs/qec/utils/replay_decoder_logs.py b/libs/qec/utils/replay_decoder_logs.py index 97acc640..dee021e0 100644 --- a/libs/qec/utils/replay_decoder_logs.py +++ b/libs/qec/utils/replay_decoder_logs.py @@ -238,8 +238,7 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): else: print( f"Error: unsupported ResultType set {sorted(result_types)} " - f"in decode_call_idx {decode_call_idx}." - ) + f"in decode_call_idx {decode_call_idx}.") exit(1) if "obs" in result_types: From e15766151f08bb59c04a78eaab305883723abeea Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Fri, 12 Jun 2026 16:14:25 -0700 Subject: [PATCH 18/19] Apply yapf to replay script Signed-off-by: Melody Ren --- libs/qec/utils/replay_decoder_logs.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/libs/qec/utils/replay_decoder_logs.py b/libs/qec/utils/replay_decoder_logs.py index dee021e0..803a4cf6 100644 --- a/libs/qec/utils/replay_decoder_logs.py +++ b/libs/qec/utils/replay_decoder_logs.py @@ -236,9 +236,8 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): dtype=numpy.uint8) decoded_observables_sparse = decoded_sparse else: - print( - f"Error: unsupported ResultType set {sorted(result_types)} " - f"in decode_call_idx {decode_call_idx}.") + print(f"Error: unsupported ResultType set {sorted(result_types)} " + f"in decode_call_idx {decode_call_idx}.") exit(1) if "obs" in result_types: @@ -250,8 +249,11 @@ def parse_decoder_config(config_file, decoders, O_per_decoder): print( f"Replay mismatch in observable result in decode_call_idx {decode_call_idx}" ) - print(f"Decoded observable result : {decoded_observables_sparse}") - print(f"Expected observable result: {expected_observables_sparse}") + print( + f"Decoded observable result : {decoded_observables_sparse}") + print( + f"Expected observable result: {expected_observables_sparse}" + ) replay_observables_dense.append(O_replay) O_log = numpy.array(log_observables_dense[decode_call_idx], From 33057f54853982d3a40934de76047dc6900da72f Mon Sep 17 00:00:00 2001 From: Melody Ren Date: Wed, 17 Jun 2026 11:31:57 -0700 Subject: [PATCH 19/19] Unify PyMatching realtime config type Signed-off-by: Melody Ren --- .../cudaq/qec/realtime/decoding_config.h | 16 +---- libs/qec/lib/realtime/config.cpp | 61 ++++++------------- .../python/bindings/py_decoding_config.cpp | 31 ++-------- libs/qec/python/bindings/type_casters.h | 7 +-- libs/qec/unittests/test_decoders_yaml.cpp | 12 ++-- 5 files changed, 29 insertions(+), 98 deletions(-) diff --git a/libs/qec/include/cudaq/qec/realtime/decoding_config.h b/libs/qec/include/cudaq/qec/realtime/decoding_config.h index 53463d80..d8687f8b 100644 --- a/libs/qec/include/cudaq/qec/realtime/decoding_config.h +++ b/libs/qec/include/cudaq/qec/realtime/decoding_config.h @@ -94,19 +94,6 @@ struct single_error_lut_config { from_heterogeneous_map(const cudaqx::heterogeneous_map &map); }; -struct pymatching_decoder_config { - std::optional merge_strategy; - std::optional> error_rate_vec; - - bool operator==(const pymatching_decoder_config &) const = default; - - __attribute__((visibility("default"))) cudaqx::heterogeneous_map - to_heterogeneous_map() const; - - __attribute__((visibility("default"))) static pymatching_decoder_config - from_heterogeneous_map(const cudaqx::heterogeneous_map &map); -}; - struct pymatching_config { std::optional> error_rate_vec; std::optional merge_strategy; @@ -124,8 +111,7 @@ struct pymatching_config { // decoder. Other global decoder plugins may be constructed through lower-level // APIs when their parameters are supplied directly, but they are not serialized // by this config variant yet. -using global_decoder_config = - std::variant; +using global_decoder_config = std::variant; struct trt_decoder_config { std::optional onnx_load_path; diff --git a/libs/qec/lib/realtime/config.cpp b/libs/qec/lib/realtime/config.cpp index 58f010e6..c8737d7c 100644 --- a/libs/qec/lib/realtime/config.cpp +++ b/libs/qec/lib/realtime/config.cpp @@ -190,32 +190,14 @@ single_error_lut_config single_error_lut_config::from_heterogeneous_map( return config; } -// ------ pymatching_decoder_config ------ -cudaqx::heterogeneous_map -pymatching_decoder_config::to_heterogeneous_map() const { - cudaqx::heterogeneous_map config_map; - INSERT_ARG(merge_strategy); - INSERT_ARG(error_rate_vec); - return config_map; -} - -pymatching_decoder_config pymatching_decoder_config::from_heterogeneous_map( - const cudaqx::heterogeneous_map &map) { - pymatching_decoder_config config; - GET_ARG(merge_strategy); - GET_ARG(error_rate_vec); - return config; -} - cudaqx::heterogeneous_map global_decoder_config_to_heterogeneous_map( const global_decoder_config &global_decoder_params) { if (std::holds_alternative(global_decoder_params)) { return cudaqx::heterogeneous_map(); } - if (std::holds_alternative( - global_decoder_params)) { - return std::get(global_decoder_params) + if (std::holds_alternative(global_decoder_params)) { + return std::get(global_decoder_params) .to_heterogeneous_map(); } @@ -230,7 +212,7 @@ global_decoder_config global_decoder_config_from_heterogeneous_map( "global_decoder_params currently supports only pymatching."); } - return pymatching_decoder_config::from_heterogeneous_map(map); + return pymatching_config::from_heterogeneous_map(map); } // ------ pymatching_config ------ @@ -278,8 +260,8 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const { } // Note: when global_decoder_params is monostate we intentionally emit // nothing, even if global_decoder is set. Inventing an empty params map here - // would round-trip back as a default pymatching_decoder_config, mutating the - // config. Any runtime need for an empty params map is handled in + // would round-trip back as a default pymatching_config, mutating the config. + // Any runtime need for an empty params map is handled in // prepare_decoder_params (realtime_decoding.cpp), not in serialization. return config_map; @@ -300,13 +282,16 @@ trt_decoder_config trt_decoder_config::from_heterogeneous_map( if (!config.global_decoder.has_value()) throw std::runtime_error( "global_decoder_params present but global_decoder is not set."); + if (config.global_decoder.value() != "pymatching") + throw std::runtime_error( + "global_decoder_params currently supports only pymatching."); try { config.global_decoder_params = map.get("global_decoder_params"); } catch (...) { try { config.global_decoder_params = - map.get("global_decoder_params"); + map.get("global_decoder_params"); } catch (...) { auto nested_map = map.get("global_decoder_params"); @@ -450,16 +435,6 @@ struct MappingTraits { cudaq::qec::decoding::config::single_error_lut_config &config) {} }; -template <> -struct MappingTraits { - static void - mapping(IO &io, - cudaq::qec::decoding::config::pymatching_decoder_config &config) { - io.mapOptional("merge_strategy", config.merge_strategy); - io.mapOptional("error_rate_vec", config.error_rate_vec); - } -}; - template <> struct MappingTraits { static void @@ -471,16 +446,16 @@ struct MappingTraits { return; } - auto &pymatching_config = std::get(config); - io.mapOptional("merge_strategy", pymatching_config.merge_strategy); - io.mapOptional("error_rate_vec", pymatching_config.error_rate_vec); + auto ¶ms = std::get(config); + io.mapOptional("merge_strategy", params.merge_strategy); + io.mapOptional("error_rate_vec", params.error_rate_vec); return; } - pymatching_decoder_config pymatching_config; - io.mapOptional("merge_strategy", pymatching_config.merge_strategy); - io.mapOptional("error_rate_vec", pymatching_config.error_rate_vec); - config = std::move(pymatching_config); + pymatching_config params; + io.mapOptional("merge_strategy", params.merge_strategy); + io.mapOptional("error_rate_vec", params.error_rate_vec); + config = std::move(params); } }; @@ -508,10 +483,10 @@ struct MappingTraits { // Emit global_decoder_params only when it actually holds params. Mapping it // unconditionally on output writes an empty `global_decoder_params: {}` for // the monostate case, which deserializes back into a default - // pymatching_decoder_config -- mutating a monostate config across a YAML + // pymatching_config -- mutating a monostate config across a YAML // round-trip. On input we always map it: an absent key leaves the variant // at its monostate default (mapOptional skips the nested mapping), and a - // present key is parsed into pymatching_decoder_config. + // present key is parsed into pymatching_config. if (!io.outputting() || !std::holds_alternative(config.global_decoder_params)) io.mapOptional("global_decoder_params", config.global_decoder_params); diff --git a/libs/qec/python/bindings/py_decoding_config.cpp b/libs/qec/python/bindings/py_decoding_config.cpp index 9b484830..3bc3c454 100644 --- a/libs/qec/python/bindings/py_decoding_config.cpp +++ b/libs/qec/python/bindings/py_decoding_config.cpp @@ -111,27 +111,6 @@ void bindDecodingConfig(nb::module_ &mod) { &multi_error_lut_config::from_heterogeneous_map, nb::arg("map")); - // pymatching_decoder_config - nb::class_( - mod_cfg, "pymatching_decoder_config", "PyMatching decoder configuration.") - .def(nb::init<>()) - .def( - "__init__", - [](config::pymatching_decoder_config &self, - const cudaqx::heterogeneous_map &map) { - new (&self) pymatching_decoder_config( - pymatching_decoder_config::from_heterogeneous_map(map)); - }, - nb::arg("map")) - .def_rw("merge_strategy", &pymatching_decoder_config::merge_strategy) - .def_rw("error_rate_vec", &pymatching_decoder_config::error_rate_vec) - .def("to_heterogeneous_map", - &pymatching_decoder_config::to_heterogeneous_map, - nb::rv_policy::move) - .def_static("from_heterogeneous_map", - &pymatching_decoder_config::from_heterogeneous_map, - nb::arg("map")); - // trt_decoder_config nb::class_(mod_cfg, "trt_decoder_config", "TensorRT decoder configuration.") @@ -162,16 +141,14 @@ void bindDecodingConfig(nb::module_ &mod) { .def_prop_rw( "global_decoder_params", [](const trt_decoder_config &self) - -> std::optional { - if (std::holds_alternative( + -> std::optional { + if (std::holds_alternative( self.global_decoder_params)) { - return std::get( - self.global_decoder_params); + return std::get(self.global_decoder_params); } return std::nullopt; }, - [](trt_decoder_config &self, - std::optional value) { + [](trt_decoder_config &self, std::optional value) { if (value.has_value()) { self.global_decoder_params = value.value(); } else { diff --git a/libs/qec/python/bindings/type_casters.h b/libs/qec/python/bindings/type_casters.h index ee70d569..7de6d454 100644 --- a/libs/qec/python/bindings/type_casters.h +++ b/libs/qec/python/bindings/type_casters.h @@ -194,10 +194,6 @@ struct type_caster { cudaq::qec::decoding::config::single_error_lut_config>( &val)) { result[key.c_str()] = nb::cast(single_cfg->to_heterogeneous_map()); - } else if (auto *pm_cfg = std::any_cast< - cudaq::qec::decoding::config::pymatching_decoder_config>( - &val)) { - result[key.c_str()] = nb::cast(pm_cfg->to_heterogeneous_map()); } else if (auto *global_cfg = std::any_cast< cudaq::qec::decoding::config::global_decoder_config>( &val)) { @@ -205,8 +201,7 @@ struct type_caster { result[key.c_str()] = nb::none(); } else { result[key.c_str()] = nb::cast( - std::get< - cudaq::qec::decoding::config::pymatching_decoder_config>( + std::get( *global_cfg) .to_heterogeneous_map()); } diff --git a/libs/qec/unittests/test_decoders_yaml.cpp b/libs/qec/unittests/test_decoders_yaml.cpp index b1fd5152..ccc2b870 100644 --- a/libs/qec/unittests/test_decoders_yaml.cpp +++ b/libs/qec/unittests/test_decoders_yaml.cpp @@ -211,8 +211,7 @@ create_test_decoder_config_trt(int id) { trt_config.batch_size = 4; trt_config.use_cuda_graph = false; trt_config.global_decoder = "pymatching"; - auto pymatching_params = - cudaq::qec::decoding::config::pymatching_decoder_config(); + auto pymatching_params = cudaq::qec::decoding::config::pymatching_config(); pymatching_params.merge_strategy = "smallest_weight"; pymatching_params.error_rate_vec = std::vector(config.block_size, 0.1); @@ -229,9 +228,9 @@ TEST(DecoderYAMLTest, TrtDecoderConfigRoundTrip) { const auto &trt_config = std::get( multi_config.decoders[0].decoder_custom_args); - EXPECT_TRUE(std::holds_alternative< - cudaq::qec::decoding::config::pymatching_decoder_config>( - trt_config.global_decoder_params)); + EXPECT_TRUE( + std::holds_alternative( + trt_config.global_decoder_params)); } TEST(DecoderYAMLTest, TrtDecoderConfigToHeterogeneousMap) { @@ -309,8 +308,7 @@ TEST(DecoderYAMLTest, TrtDecoderParamsWithoutDecoderThrows) { cudaq::qec::decoding::config::trt_decoder_config trt_config; trt_config.onnx_load_path = "/tmp/predecoder.onnx"; - auto pymatching_params = - cudaq::qec::decoding::config::pymatching_decoder_config(); + auto pymatching_params = cudaq::qec::decoding::config::pymatching_config(); pymatching_params.merge_strategy = "smallest_weight"; trt_config.global_decoder_params = pymatching_params; EXPECT_THROW(trt_config.to_heterogeneous_map(), std::runtime_error);