Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1cd5d10
(qec): Extend trt_decoder with global decoder chaining + observables …
bmhowe23 Apr 29, 2026
6da6ff1
Merge branch 'main' into pr-composite-decoder
wsttiger Apr 30, 2026
84b56f5
Add composite TRT decoder realtime demo
wsttiger May 1, 2026
0418e45
Expand TRT decoder YAML config for composite decoding
wsttiger May 8, 2026
d3a387e
Merge remote-tracking branch 'upstream/main' into update_trt_decoder_…
wsttiger May 11, 2026
7d091ce
Refactor TRT global decoder params to variant config
wsttiger May 12, 2026
054f748
Merge remote-tracking branch 'upstream/main' into update_trt_decoder_…
wsttiger May 12, 2026
d60c4e2
Merge branch 'main' into update_trt_decoder_yaml
wsttiger May 15, 2026
5d1f2af
Merge branch 'main' into update_trt_decoder_yaml
wsttiger May 29, 2026
26be6b4
Restore optional None setter helper
wsttiger May 29, 2026
6348e8e
Merge remote-tracking branch 'upstream/main' into pr536-update-trt
melody-ren Jun 10, 2026
1dff1f1
Add method to set decoding result type
melody-ren Jun 9, 2026
c9b2c08
Mark composite trt_decoder as decode_to_obs
melody-ren Jun 12, 2026
1ff4932
enable logging for obs path in enqueue_syndrome
melody-ren Jun 12, 2026
626e40d
Fix trt_decoder_config monostate round-trip
melody-ren Jun 12, 2026
bf77915
format trt decoder config round-trip test
melody-ren Jun 12, 2026
7d3f09f
Restore no-O global decoder and reject params-without-decoder on write
melody-ren Jun 12, 2026
1cf7479
document realtime global decoder config support
melody-ren Jun 12, 2026
489175c
handle decoder result types explicitly
melody-ren Jun 12, 2026
7c003a8
consolidate trt decoder yaml tests
melody-ren Jun 12, 2026
62dd925
clean up enqueue_syndrome result handling bloat
melody-ren Jun 12, 2026
f70f2fc
Fix replay result type handling
melody-ren Jun 12, 2026
4f3e149
Fix replay script formatting
melody-ren Jun 12, 2026
e157661
Apply yapf to replay script
melody-ren Jun 12, 2026
c190ae0
Merge upstream/main into pr536-update-trt
melody-ren Jun 17, 2026
33057f5
Unify PyMatching realtime config type
melody-ren Jun 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions libs/qec/include/cudaq/qec/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ class decoder
std::unique_ptr<rt_impl, rt_impl_deleter> pimpl;

public:
/// @brief Indicates whether decode() returns a full error frame (length
/// block_size) or an already-projected observable frame (length
/// num_observables). Decoders that accept an "O" observable matrix in their
/// constructor params should call set_result_type(decode_to_obs); all others
/// default to decode_to_errs.
///
/// enqueue_syndrome() checks the length of the decoder result against the
/// size implied by this value and throws std::runtime_error on a mismatch.
///
/// Note: even in decode_to_obs mode, set_O_sparse() must still be called so
/// that enqueue_syndrome() knows num_observables and can size the corrections
/// buffer correctly.
enum decode_result_type {
decode_to_errs, ///< result.size() == block_size; enqueue_syndrome projects
///< via O_sparse
decode_to_obs, ///< result.size() == num_observables; enqueue_syndrome uses
///< result directly; set_O_sparse() still required
};

decoder() = delete;

/// @brief Constructor
Expand Down Expand Up @@ -234,6 +250,12 @@ class decoder
// Note: all of the current realtime decoding API is designed to be used with
// hard syndromes.

/// @brief Returns the type of result produced by decode().
/// Defaults to decode_to_errs. Decoders that project to observables
/// internally (i.e., constructed with an "O" param) should call
/// set_result_type(decode_to_obs) in their constructor.
/// @return The currently configured decode_result_type.
decode_result_type get_result_type() const { return result_type_; }

/// @brief Get the number of measurement syndromes per decode call. This
/// depends on D_sparse, so you must have called set_D_sparse() first.
uint32_t get_num_msyn_per_decode() const;
Expand Down Expand Up @@ -315,6 +337,11 @@ class decoder
virtual std::string get_version() const;

protected:
/// @brief Sets the result type. Call in the constructor when an "O"
/// observable matrix is detected in the decoder params. Must be called
/// before the first enqueue_syndrome().
/// @param type The result type that get_result_type() will report from now
/// on.
void set_result_type(decode_result_type type) { result_type_ = type; }

/// @brief For a classical `[n,k]` code, this is `n`.
std::size_t block_size = 0;

Expand All @@ -329,6 +356,9 @@ class decoder

/// @brief The decoder's D matrix in sparse format
std::vector<std::vector<uint32_t>> D_sparse;

private:
decode_result_type result_type_ = decode_result_type::decode_to_errs;
};

/// @brief Convert a single soft probability to a hard 0/1 decision.
Expand Down
10 changes: 10 additions & 0 deletions libs/qec/include/cudaq/qec/realtime/decoding_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,22 @@ struct pymatching_config {
from_heterogeneous_map(const cudaqx::heterogeneous_map &map);
};

// The realtime trt_decoder config currently models only PyMatching as a global
// decoder. Other global decoder plugins may be constructed through lower-level
// APIs when their parameters are supplied directly, but they are not serialized
// by this config variant yet.
//
// Alternatives:
//   - std::monostate: no global decoder parameters are configured.
//   - pymatching_config: parameters for the PyMatching global decoder.
using global_decoder_config = std::variant<std::monostate, pymatching_config>;

struct trt_decoder_config {
std::optional<std::string> onnx_load_path;
std::optional<std::string> engine_load_path;
std::optional<std::string> engine_save_path;
std::optional<std::string> precision;
std::optional<std::size_t> memory_workspace;
std::optional<std::size_t> batch_size;
std::optional<bool> use_cuda_graph;
std::optional<std::string> global_decoder;
global_decoder_config global_decoder_params;

bool operator==(const trt_decoder_config &) const = default;

Expand Down
94 changes: 69 additions & 25 deletions libs/qec/lib/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
std::vector<uint32_t> log_msyn;
std::vector<uint32_t> log_detectors;
std::vector<uint32_t> log_errors;
std::vector<uint32_t> log_observables;
std::vector<uint8_t> log_observable_corrections;
// The four time points are used to measure the duration of each of 3 steps.
std::chrono::time_point<std::chrono::high_resolution_clock> log_t0, log_t1,
Expand All @@ -305,6 +306,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
if (should_log) {
log_t0 = std::chrono::high_resolution_clock::now();
log_errors.reserve(syndrome_length);
log_observables.reserve(O_sparse.size());
log_observable_corrections.resize(O_sparse.size());
}

Expand Down Expand Up @@ -362,36 +364,76 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
return false;
}
}
// Process the results.
// TODO - should this interrogate the decoded_result.converged flag?
const auto result_type = get_result_type();
const auto num_observables = get_num_observables();
const char *result_type_str = nullptr;
const char *result_type_name = nullptr;
std::size_t expected_result_size = 0;
switch (result_type) {
case decode_result_type::decode_to_errs:
result_type_str = "errs";
result_type_name = "decode_to_errs";
expected_result_size = block_size;
break;
case decode_result_type::decode_to_obs:
result_type_str = "obs";
result_type_name = "decode_to_obs";
expected_result_size = num_observables;
break;
}
if (!result_type_name)
throw std::runtime_error(
fmt::format("Unsupported decoder result type ({})",
static_cast<int>(result_type)));
if ((!pimpl->is_sliding_window &&
decoded_result.result.size() != block_size) ||
decoded_result.result.size() != expected_result_size) ||
(pimpl->is_sliding_window && !decoded_result.result.empty() &&
decoded_result.result.size() != block_size)) {
throw std::runtime_error(
fmt::format("Decoder result size ({}) does not match block_size ({})",
decoded_result.result.size(), block_size));
decoded_result.result.size() != expected_result_size)) {
throw std::runtime_error(fmt::format(
"Decoder result size ({}) does not match expected size ({}) for "
"result type {}",
decoded_result.result.size(), expected_result_size,
result_type_name));
}

if (should_log) {
// Flip an observable correction and mirror it into the per-call log so the
// logged flips stay faithful to the applied corrections.
auto flip_correction = [&](std::size_t i) {
pimpl->corrections[i] ^= 1;
if (should_log)
log_observable_corrections[i] ^= 1;
};

if (should_log)
log_t2 = std::chrono::high_resolution_clock::now();
for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++)
if (decoded_result.result[e])
log_errors.push_back(e);
}
// Process the results.
// TODO - should this interrogate the decoded_result.converged flag?
auto num_observables = O_sparse.size();
// For each observable
for (std::size_t i = 0; i < num_observables; i++) {
// For each error that flips this observable
for (auto col : O_sparse[i]) {
// If the decoder predicted that this error occurred
if (decoded_result.result[col]) {
// Flip the correction for this observable
pimpl->corrections[i] ^= 1;

switch (result_type) {
case decode_result_type::decode_to_obs:
// Observable-frame path: decoder already projected to observables via its
// internal "O" matrix; use the result directly.
for (std::size_t i = 0; i < num_observables; i++)
if (decoded_result.result[i]) {
if (should_log)
log_observable_corrections[i] ^= 1;
log_observables.push_back(i);
flip_correction(i);
}
}
break;
case decode_result_type::decode_to_errs:
// Error-frame path: decoder returns a block-sized error vector; project
// to observables via O_sparse.
if (should_log)
for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++)
if (decoded_result.result[e])
log_errors.push_back(e);
// For each observable, flip its correction once for each predicted error
// that flips it (net parity over O_sparse[i]).
for (std::size_t i = 0; i < num_observables; i++)
for (auto col : O_sparse[i])
if (decoded_result.result[col])
flip_correction(i);
break;
}
if (should_log) {
log_t3 = std::chrono::high_resolution_clock::now();
Expand All @@ -401,13 +443,15 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
pimpl->log_counter++;
auto s = fmt::format(
"[DecoderStats][{}] Counter:{} DecoderId:{} InputMsyn:{} "
"InputDetectors:{} Converged:{} Errors:{} "
"InputDetectors:{} Converged:{} ResultType:{} Errors:{} "
"Observables:{} "
"ObservableCorrectionsThisCall:{} ObservableCorrectionsTotal:{} "
"Dur1:{:.1f}us Dur2:{:.1f}us Dur3:{:.1f}us",
static_cast<const void *>(this), pimpl->log_counter,
pimpl->decoder_id, fmt::join(log_msyn, ","),
fmt::join(log_detectors, ","), decoded_result.converged ? 1 : 0,
fmt::join(log_errors, ","),
result_type_str, fmt::join(log_errors, ","),
fmt::join(log_observables, ","),
fmt::join(log_observable_corrections, ","),
fmt::join(std::vector<uint8_t>(pimpl->corrections.begin(),
pimpl->corrections.end()),
Expand Down
1 change: 1 addition & 0 deletions libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ trt_decoder::trt_decoder(const cudaq::qec::sparse_binary_matrix &H,
}
decode_to_observables_ = true;
num_observables_ = O.shape()[0];
set_result_type(decode_result_type::decode_to_obs);

// The TRT model output must encode [pre_L (num_observables_ entries),
// residual_dets (rest)]. Validate sizing where we can.
Expand Down
116 changes: 116 additions & 0 deletions libs/qec/lib/realtime/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,31 @@ single_error_lut_config single_error_lut_config::from_heterogeneous_map(
return config;
}

/// @brief Serialize the global-decoder parameter variant to a heterogeneous
/// map.
///
/// A variant holding `pymatching_config` is serialized through that config's
/// own `to_heterogeneous_map()`. A variant holding `std::monostate` (no
/// parameters) yields an empty map.
///
/// @param global_decoder_params Variant holding the global decoder parameters.
/// @return The parameters expressed as a heterogeneous map.
/// @throws std::runtime_error if the variant holds neither alternative.
cudaqx::heterogeneous_map global_decoder_config_to_heterogeneous_map(
    const global_decoder_config &global_decoder_params) {
  if (const auto *pymatching_params =
          std::get_if<pymatching_config>(&global_decoder_params))
    return pymatching_params->to_heterogeneous_map();

  if (std::get_if<std::monostate>(&global_decoder_params) != nullptr)
    return cudaqx::heterogeneous_map();

  // Neither known alternative is active.
  throw std::runtime_error("Unsupported global decoder parameters.");
}

/// @brief Build the global-decoder parameter variant from a heterogeneous map.
///
/// Only PyMatching parameters are modeled, so the map is always parsed as a
/// `pymatching_config`. An unset `global_decoder` is accepted and parsed the
/// same way.
///
/// @param map Heterogeneous map holding the global decoder parameters.
/// @param global_decoder Optional name of the selected global decoder.
/// @return A variant holding the parsed `pymatching_config`.
/// @throws std::runtime_error if `global_decoder` is set to anything other
/// than "pymatching".
global_decoder_config global_decoder_config_from_heterogeneous_map(
    const cudaqx::heterogeneous_map &map,
    const std::optional<std::string> &global_decoder) {
  const bool names_unsupported_decoder =
      global_decoder && *global_decoder != "pymatching";
  if (names_unsupported_decoder)
    throw std::runtime_error(
        "global_decoder_params currently supports only pymatching.");

  return pymatching_config::from_heterogeneous_map(map);
}

// ------ pymatching_config ------
cudaqx::heterogeneous_map pymatching_config::to_heterogeneous_map() const {
cudaqx::heterogeneous_map config_map;
Expand Down Expand Up @@ -217,6 +242,27 @@ cudaqx::heterogeneous_map trt_decoder_config::to_heterogeneous_map() const {
INSERT_ARG(engine_save_path);
INSERT_ARG(precision);
INSERT_ARG(memory_workspace);
INSERT_ARG(batch_size);
INSERT_ARG(use_cuda_graph);
INSERT_ARG(global_decoder);
if (!std::holds_alternative<std::monostate>(global_decoder_params)) {
if (!global_decoder.has_value()) {
throw std::runtime_error(
"global_decoder_params present but global_decoder is not set.");
}
if (global_decoder.value() != "pymatching") {
throw std::runtime_error(
"global_decoder_params currently supports only pymatching.");
}
config_map.insert(
"global_decoder_params",
global_decoder_config_to_heterogeneous_map(global_decoder_params));
}
// Note: when global_decoder_params is monostate we intentionally emit
// nothing, even if global_decoder is set. Inventing an empty params map here
// would round-trip back as a default pymatching_config, mutating the config.
// Any runtime need for an empty params map is handled in
// prepare_decoder_params (realtime_decoding.cpp), not in serialization.

return config_map;
}
Expand All @@ -229,6 +275,32 @@ trt_decoder_config trt_decoder_config::from_heterogeneous_map(
GET_ARG(engine_save_path);
GET_ARG(precision);
GET_ARG(memory_workspace);
GET_ARG(batch_size);
GET_ARG(use_cuda_graph);
GET_ARG(global_decoder);
if (map.contains("global_decoder_params")) {
Comment thread
melody-ren marked this conversation as resolved.
if (!config.global_decoder.has_value())
throw std::runtime_error(
"global_decoder_params present but global_decoder is not set.");
if (config.global_decoder.value() != "pymatching")
throw std::runtime_error(
"global_decoder_params currently supports only pymatching.");
try {
config.global_decoder_params =
map.get<global_decoder_config>("global_decoder_params");
} catch (...) {
try {
config.global_decoder_params =
map.get<pymatching_config>("global_decoder_params");
} catch (...) {
auto nested_map =
map.get<cudaqx::heterogeneous_map>("global_decoder_params");
config.global_decoder_params =
global_decoder_config_from_heterogeneous_map(nested_map,
config.global_decoder);
}
}
}

return config;
}
Expand Down Expand Up @@ -363,6 +435,30 @@ struct MappingTraits<cudaq::qec::decoding::config::single_error_lut_config> {
cudaq::qec::decoding::config::single_error_lut_config &config) {}
};

template <>
struct MappingTraits<cudaq::qec::decoding::config::global_decoder_config> {
static void
mapping(IO &io, cudaq::qec::decoding::config::global_decoder_config &config) {
using namespace cudaq::qec::decoding::config;

if (io.outputting()) {
if (std::holds_alternative<std::monostate>(config)) {
return;
}

auto &params = std::get<pymatching_config>(config);
io.mapOptional("merge_strategy", params.merge_strategy);
io.mapOptional("error_rate_vec", params.error_rate_vec);
return;
}

pymatching_config params;
io.mapOptional("merge_strategy", params.merge_strategy);
io.mapOptional("error_rate_vec", params.error_rate_vec);
config = std::move(params);
}
};

template <>
struct MappingTraits<cudaq::qec::decoding::config::pymatching_config> {
static void mapping(IO &io,
Expand All @@ -381,6 +477,26 @@ struct MappingTraits<cudaq::qec::decoding::config::trt_decoder_config> {
io.mapOptional("engine_save_path", config.engine_save_path);
io.mapOptional("precision", config.precision);
io.mapOptional("memory_workspace", config.memory_workspace);
io.mapOptional("batch_size", config.batch_size);
io.mapOptional("use_cuda_graph", config.use_cuda_graph);
io.mapOptional("global_decoder", config.global_decoder);
// Emit global_decoder_params only when it actually holds params. Mapping it
// unconditionally on output writes an empty `global_decoder_params: {}` for
// the monostate case, which deserializes back into a default
// pymatching_config -- mutating a monostate config across a YAML
// round-trip. On input we always map it: an absent key leaves the variant
// at its monostate default (mapOptional skips the nested mapping), and a
// present key is parsed into pymatching_config.
if (!io.outputting() ||
!std::holds_alternative<std::monostate>(config.global_decoder_params))
io.mapOptional("global_decoder_params", config.global_decoder_params);

if (!std::holds_alternative<std::monostate>(config.global_decoder_params) &&
config.global_decoder.has_value() &&
config.global_decoder.value() != "pymatching") {
throw std::runtime_error(
"global_decoder_params currently supports only pymatching.");
}
}
};

Expand Down
Loading
Loading