Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 124 additions & 9 deletions libs/qec/include/cudaq/qec/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,16 @@
#include "cuda-qx/core/heterogeneous_map.h"
#include "cuda-qx/core/tensor.h"
#include "sparse_binary_matrix.h"
#include "cudaq/qec/detector_error_model.h"
#include <algorithm>
#include <functional>
#include <future>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <variant>
#include <vector>

namespace cudaq::qec {
Expand All @@ -24,6 +32,10 @@ using float_t = CUDAQX_QEC_FLOAT_TYPE;
using float_t = double;
#endif

/// Decoder construction input: either a parity-check matrix or raw Stim DEM
/// text.
using decoder_init = std::variant<sparse_binary_matrix, std::string>;

/// @brief Validates that all keys in a heterogeneous map are found in a list of
/// acceptable types
/// @param config The heterogeneous map to validate
Expand Down Expand Up @@ -122,8 +134,7 @@ class async_decoder_result {
/// arbitrary constructor parameters that can be unique to each specific
/// decoder.
class decoder
: public cudaqx::extension_point<decoder,
const cudaq::qec::sparse_binary_matrix &,
: public cudaqx::extension_point<decoder, const decoder_init &,
const cudaqx::heterogeneous_map &> {
private:
struct rt_impl;
Expand All @@ -143,16 +154,16 @@ class decoder
/// @brief Decode a single syndrome
/// @param syndrome A vector of syndrome measurements where the floating point
/// value is the probability that the syndrome measurement is a |1>. The
/// length of the syndrome vector should be equal to \p syndrome_size.
/// @returns Vector of length \p block_size with soft probabilities of errors
/// length of the syndrome vector should be equal to `syndrome_size`.
/// @returns Vector of length `block_size` with soft probabilities of errors
/// in each index.
virtual decoder_result decode(const std::vector<float_t> &syndrome) = 0;

/// @brief Decode a single syndrome
/// @param syndrome An order-1 tensor of syndrome measurements where a 1 bit
/// represents that the syndrome measurement is a |1>. The
/// length of the syndrome vector should be equal to \p syndrome_size.
/// @returns Vector of length \p block_size of errors in each index.
/// length of the syndrome vector should be equal to `syndrome_size`.
/// @returns Vector of length `block_size` of errors in each index.
virtual decoder_result decode(const cudaqx::tensor<uint8_t> &syndrome);

/// @brief Decode a single syndrome
Expand All @@ -172,11 +183,49 @@ class decoder
virtual std::vector<decoder_result>
decode_batch(const std::vector<std::vector<float_t>> &syndrome);

/// @brief This `get` overload supports default values.
/// @brief Construct a registered decoder by name.
/// @param name The registered decoder name.
/// @param init A parity-check matrix or raw Stim DEM string.
/// @param param_map Optional decoder-specific parameters.
static std::unique_ptr<decoder>
get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
get(const std::string &name, const decoder_init &init,
const cudaqx::heterogeneous_map &param_map = cudaqx::heterogeneous_map());

static std::unique_ptr<decoder>
get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
const cudaqx::heterogeneous_map &param_map =
cudaqx::heterogeneous_map()) {
return get(name, decoder_init{H}, param_map);
}

static std::unique_ptr<decoder>
get(const std::string &name, const cudaqx::tensor<uint8_t> &H,
const cudaqx::heterogeneous_map &param_map =
cudaqx::heterogeneous_map()) {
return get(name, cudaq::qec::sparse_binary_matrix(H), param_map);
}

static std::unique_ptr<decoder>
get(const std::string &name, const std::string &stim_dem_text,
const cudaqx::heterogeneous_map &param_map =
cudaqx::heterogeneous_map()) {
return get(name, decoder_init{stim_dem_text}, param_map);
}

static std::unique_ptr<decoder>
get(const std::string &name, const char *stim_dem_text,
const cudaqx::heterogeneous_map &param_map =
cudaqx::heterogeneous_map()) {
return get(name, decoder_init{std::string{stim_dem_text}}, param_map);
}

static std::unique_ptr<decoder>
get(const std::string &name, std::string_view stim_dem_text,
const cudaqx::heterogeneous_map &param_map =
cudaqx::heterogeneous_map()) {
return get(name, decoder_init{std::string{stim_dem_text}}, param_map);
}

std::size_t get_block_size() { return block_size; }
std::size_t get_syndrome_size() { return syndrome_size; }

Expand Down Expand Up @@ -435,6 +484,72 @@ inline void convert_vec_hard_to_soft(const std::vector<std::vector<t_hard>> &in,
}

std::unique_ptr<decoder>
get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
get_decoder(const std::string &name, const decoder_init &init,
const cudaqx::heterogeneous_map options = {});

inline std::unique_ptr<decoder>
get_decoder(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
const cudaqx::heterogeneous_map options = {}) {
return get_decoder(name, decoder_init{H}, options);
}

inline std::unique_ptr<decoder>
get_decoder(const std::string &name, const cudaqx::tensor<uint8_t> &H,
const cudaqx::heterogeneous_map options = {}) {
return get_decoder(name, cudaq::qec::sparse_binary_matrix(H), options);
}

inline std::unique_ptr<decoder>
get_decoder(const std::string &name, const std::string &stim_dem_text,
const cudaqx::heterogeneous_map options = {}) {
return get_decoder(name, decoder_init{stim_dem_text}, options);
}

inline std::unique_ptr<decoder>
get_decoder(const std::string &name, const char *stim_dem_text,
const cudaqx::heterogeneous_map options = {}) {
return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options);
}

inline std::unique_ptr<decoder>
get_decoder(const std::string &name, std::string_view stim_dem_text,
const cudaqx::heterogeneous_map options = {}) {
return get_decoder(name, decoder_init{std::string{stim_dem_text}}, options);
}

namespace details {
// Declared here because `make_pcm_decoder` is a header-defined template.
/// DEM-derived defaults; pointers alias into the source `dem`.
struct dem_default_values {
const cudaqx::tensor<uint8_t> *O = nullptr;
Comment thread
vedika-saravanan marked this conversation as resolved.
const std::vector<double> *error_rate_vec = nullptr;
};

/// Return DEM defaults for keys not already supplied by the user.
dem_default_values dem_defaults_for_missing_keys(
const std::function<bool(const std::string &)> &contains_user_key,
const detector_error_model &dem);
} // namespace details

/// If `init` holds DEM text, parse it and inject `"O"` / `"error_rate_vec"`
/// defaults when absent.
template <typename DecoderT>
std::unique_ptr<decoder>
make_pcm_decoder(const decoder_init &init,
const cudaqx::heterogeneous_map &params) {
if (const auto *H = std::get_if<cudaq::qec::sparse_binary_matrix>(&init))
return std::make_unique<DecoderT>(*H, params);

const auto dem = dem_from_stim_text(std::get<std::string>(init));
cudaqx::heterogeneous_map merged = params;
const auto defaults = details::dem_defaults_for_missing_keys(
[&](const std::string &key) { return merged.contains(key); }, dem);
if (defaults.O)
merged.insert("O", *defaults.O);
if (defaults.error_rate_vec)
merged.insert("error_rate_vec", *defaults.error_rate_vec);
return std::make_unique<DecoderT>(
cudaq::qec::sparse_binary_matrix(dem.detector_error_matrix), merged);
}

} // namespace cudaq::qec
16 changes: 12 additions & 4 deletions libs/qec/include/cudaq/qec/detector_error_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
#pragma once

#include "cuda-qx/core/tensor.h"
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

namespace cudaq::qec {

Expand All @@ -18,9 +22,9 @@ namespace cudaq::qec {
/// decoder to help make predictions about observables flips.
///
/// Shared size parameters among the matrix types.
/// - \p detector_error_matrix: num_detectors x num_error_mechanisms [d, e]
/// - \p error_rates: num_error_mechanisms
/// - \p observables_flips_matrix: num_observables x num_error_mechanisms [k, e]
/// - `detector_error_matrix`: num_detectors x num_error_mechanisms [d, e]
/// - `error_rates`: num_error_mechanisms
/// - `observables_flips_matrix`: num_observables x num_error_mechanisms [k, e]
///
/// @note The C++ API for this class may change in the future. The Python API is
/// more likely to be backwards compatible.
Expand All @@ -32,7 +36,7 @@ struct detector_error_model {
cudaqx::tensor<uint8_t> detector_error_matrix;

/// The list of weights has length equal to the number of columns of
/// \p detector_error_matrix, which assigns a likelihood to each error
/// `detector_error_matrix`, which assigns a likelihood to each error
/// mechanism.
std::vector<double> error_rates;

Expand Down Expand Up @@ -65,4 +69,8 @@ struct detector_error_model {
void canonicalize_for_rounds(uint32_t num_syndromes_per_round);
};

/// Parse Stim DEM text into detector/observable flip matrices and error rates.
/// This is lossy; DEM-native decoders should consume raw DEM text instead.
detector_error_model dem_from_stim_text(const std::string &dem_text);

} // namespace cudaq::qec
26 changes: 25 additions & 1 deletion libs/qec/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

set(LIBRARY_NAME cudaq-qec)

include(FetchContent)

add_compile_options(-Wno-attributes)

find_package(CUDAToolkit REQUIRED)
Expand Down Expand Up @@ -44,7 +46,7 @@ set(QEC_SOURCES
pcm_utils.cpp
plugin_loader.cpp
sparse_binary_matrix.cpp
stabilizer_utils.cpp
stabilizer_utils.cpp
decoders/lut.cpp
decoders/sliding_window.cpp
version.cpp
Expand All @@ -58,6 +60,28 @@ list(APPEND QEC_SOURCES
# FIXME?: This must be a shared library. Trying to build a static one will fail.
add_library(${LIBRARY_NAME} SHARED ${QEC_SOURCES})

if(NOT TARGET libstim)
FetchContent_Declare(
stim
GIT_REPOSITORY https://github.com/quantumlib/Stim.git
GIT_TAG v1.15.0
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(stim)
endif()

if(NOT TARGET libstim)
message(FATAL_ERROR
"Stim FetchContent did not provide the libstim target.")
endif()
set_target_properties(${LIBRARY_NAME} PROPERTIES
VISIBILITY_INLINES_HIDDEN ON
)
target_link_libraries(${LIBRARY_NAME} PRIVATE libstim)
target_link_options(${LIBRARY_NAME} PRIVATE
$<$<OR:$<CXX_COMPILER_ID:GNU>,$<CXX_COMPILER_ID:Clang>>:-Wl,--exclude-libs,libstim.a>
)

Comment thread
bmhowe23 marked this conversation as resolved.
add_subdirectory(decoders/plugins/example)
add_subdirectory(decoders/plugins/pymatching)

Expand Down
28 changes: 20 additions & 8 deletions libs/qec/lib/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
#include <filesystem>
#include <vector>

INSTANTIATE_REGISTRY(cudaq::qec::decoder,
const cudaq::qec::sparse_binary_matrix &)
INSTANTIATE_REGISTRY(cudaq::qec::decoder,
const cudaq::qec::sparse_binary_matrix &,
INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaq::qec::decoder_init &,
const cudaqx::heterogeneous_map &)

// Include decoder implementations AFTER registry instantiation
Expand Down Expand Up @@ -131,7 +128,7 @@ decoder::decode_async(const std::vector<float_t> &syndrome) {
}

std::unique_ptr<decoder>
decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
decoder::get(const std::string &name, const decoder_init &init,
const cudaqx::heterogeneous_map &param_map) {
auto [mutex, registry] = get_registry();
std::lock_guard<std::recursive_mutex> lock(mutex);
Expand All @@ -141,9 +138,24 @@ decoder::get(const std::string &name, const cudaq::qec::sparse_binary_matrix &H,
"invalid decoder requested: " + name +
". Run with CUDAQ_LOG_LEVEL=info (environment variable) to see "
"additional plugin diagnostics at startup.");
return iter->second(H, param_map);
return iter->second(init, param_map);
}

namespace details {

dem_default_values dem_defaults_for_missing_keys(
const std::function<bool(const std::string &)> &contains_user_key,
const detector_error_model &dem) {
dem_default_values out;
if (!contains_user_key("O") && dem.num_observables() > 0)
out.O = &dem.observables_flips_matrix;
if (!contains_user_key("error_rate_vec"))
out.error_rate_vec = &dem.error_rates;
return out;
}

} // namespace details

static uint32_t calculate_num_msyn_per_decode(
const std::vector<std::vector<uint32_t>> &D_sparse) {
uint32_t max_col = 0;
Expand Down Expand Up @@ -480,9 +492,9 @@ void decoder::reset_decoder() {
}

std::unique_ptr<decoder> get_decoder(const std::string &name,
const cudaq::qec::sparse_binary_matrix &H,
const decoder_init &init,
const cudaqx::heterogeneous_map options) {
return decoder::get(name, H, options);
return decoder::get(name, init, options);
}

// Constructor function for auto-loading plugins
Expand Down
8 changes: 4 additions & 4 deletions libs/qec/lib/decoders/lut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,9 @@ class multi_error_lut : public decoder {

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
multi_error_lut, static std::unique_ptr<decoder> create(
const cudaq::qec::sparse_binary_matrix &H,
const cudaq::qec::decoder_init &init,
const cudaqx::heterogeneous_map &params) {
return std::make_unique<multi_error_lut>(H, params);
return cudaq::qec::make_pcm_decoder<multi_error_lut>(init, params);
})
};

Expand All @@ -246,9 +246,9 @@ class single_error_lut : public multi_error_lut {

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
single_error_lut, static std::unique_ptr<decoder> create(
const cudaq::qec::sparse_binary_matrix &H,
const cudaq::qec::decoder_init &init,
const cudaqx::heterogeneous_map &params) {
return std::make_unique<single_error_lut>(H, params);
return cudaq::qec::make_pcm_decoder<single_error_lut>(init, params);
})
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ class single_error_lut_example : public decoder {

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
single_error_lut_example, static std::unique_ptr<decoder> create(
const cudaq::qec::sparse_binary_matrix &H,
const cudaq::qec::decoder_init &init,
const cudaqx::heterogeneous_map &params) {
return std::make_unique<single_error_lut_example>(H, params);
return cudaq::qec::make_pcm_decoder<single_error_lut_example>(init,
params);
})
};

Expand Down
4 changes: 2 additions & 2 deletions libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ class pymatching : public decoder {

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
pymatching, static std::unique_ptr<decoder> create(
const cudaq::qec::sparse_binary_matrix &H,
const cudaq::qec::decoder_init &init,
const cudaqx::heterogeneous_map &params) {
return std::make_unique<pymatching>(H, params);
return cudaq::qec::make_pcm_decoder<pymatching>(init, params);
})
};

Expand Down
4 changes: 2 additions & 2 deletions libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,9 @@ class trt_decoder : public decoder {

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
trt_decoder, static std::unique_ptr<decoder> create(
const cudaq::qec::sparse_binary_matrix &H,
const cudaq::qec::decoder_init &init,
const cudaqx::heterogeneous_map &params) {
return std::make_unique<trt_decoder>(H, params);
return cudaq::qec::make_pcm_decoder<trt_decoder>(init, params);
})

private:
Expand Down
Loading
Loading