diff --git a/docs/sphinx/api/qec/cpp_api.rst b/docs/sphinx/api/qec/cpp_api.rst index 60206629..fa26bdd9 100644 --- a/docs/sphinx/api/qec/cpp_api.rst +++ b/docs/sphinx/api/qec/cpp_api.rst @@ -33,7 +33,7 @@ Detector Error Model .. doxygenfunction:: cudaq::qec::dem_from_memory_circuit(const code &, operation, std::size_t, cudaq::noise_model &) .. doxygenfunction:: cudaq::qec::x_dem_from_memory_circuit(const code &, operation, std::size_t, cudaq::noise_model &) .. doxygenfunction:: cudaq::qec::z_dem_from_memory_circuit(const code &, operation, std::size_t, cudaq::noise_model &) -.. doxygenfunction:: cudaq::qec::dem_from_stim_text(const std::string &) +.. doxygenfunction:: cudaq::qec::dem_from_stim_text(const std::string &, bool) Decoder Interfaces ================== diff --git a/docs/sphinx/examples/qec/cpp/stim_dem_decoder.cpp b/docs/sphinx/examples/qec/cpp/stim_dem_decoder.cpp index 97f463df..206b3864 100644 --- a/docs/sphinx/examples/qec/cpp/stim_dem_decoder.cpp +++ b/docs/sphinx/examples/qec/cpp/stim_dem_decoder.cpp @@ -22,14 +22,19 @@ int main() { const std::string dem_text = R"(error(0.1) D0 L0 error(0.1) D1 L0 error(0.05) D0 D1 +error(0.02) D0 ^ D1 )"; auto decoder = cudaq::qec::get_decoder("single_error_lut", dem_text); auto dem = cudaq::qec::dem_from_stim_text(dem_text); + auto dem_decomposed = + cudaq::qec::dem_from_stim_text(dem_text, /*decompose_errors=*/true); std::cout << "detectors: " << dem.num_detectors() << "\n"; std::cout << "error mechanisms: " << dem.num_error_mechanisms() << "\n"; std::cout << "observables: " << dem.num_observables() << "\n"; + std::cout << "error mechanisms (decomposed): " + << dem_decomposed.num_error_mechanisms() << "\n"; const std::vector> syndromes = { {0.0, 0.0}, {1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}}; diff --git a/docs/sphinx/examples/qec/python/stim_dem_decoder.py b/docs/sphinx/examples/qec/python/stim_dem_decoder.py index 8189e65b..f43c8444 100644 --- a/docs/sphinx/examples/qec/python/stim_dem_decoder.py +++ b/docs/sphinx/examples/qec/python/stim_dem_decoder.py @@ -13,14 +13,17 @@ error(0.1) D0 L0 error(0.1) D1 L0 error(0.05) D0 D1 +error(0.02) D0 ^ D1 """ decoder = qec.get_decoder("single_error_lut", dem_text) dem = qec.dem_from_stim_text(dem_text) +dem_decomposed = qec.dem_from_stim_text(dem_text, decompose_errors=True) print("detectors:", dem.num_detectors()) print("error mechanisms:", dem.num_error_mechanisms()) print("observables:", dem.num_observables()) +print("error mechanisms (decomposed):", dem_decomposed.num_error_mechanisms()) syndromes = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=np.uint8) results = decoder.decode_batch(syndromes) diff --git a/libs/qec/include/cudaq/qec/detector_error_model.h b/libs/qec/include/cudaq/qec/detector_error_model.h index 6411f0f7..05369354 100644 --- a/libs/qec/include/cudaq/qec/detector_error_model.h +++ b/libs/qec/include/cudaq/qec/detector_error_model.h @@ -71,6 +71,11 @@ struct detector_error_model { /// Parse Stim DEM text into detector/observable flip matrices and error rates. /// DEM-native decoders should consume raw DEM text instead. -detector_error_model dem_from_stim_text(const std::string &dem_text); +/// If @p decompose_errors is true, error mechanisms that carry an explicit +/// graphlike decomposition (components separated by '^' in the DEM text) are +/// expanded into one column per component; otherwise the '^' separators are +/// ignored and each error instruction produces a single column. +detector_error_model dem_from_stim_text(const std::string &dem_text, + bool decompose_errors = false); } // namespace cudaq::qec diff --git a/libs/qec/lib/detector_error_model.cpp b/libs/qec/lib/detector_error_model.cpp index cb9e1a1d..9dd3dd4e 100644 --- a/libs/qec/lib/detector_error_model.cpp +++ b/libs/qec/lib/detector_error_model.cpp @@ -20,7 +20,8 @@ namespace cudaq::qec { -detector_error_model dem_from_stim_text(const std::string &dem_text) { +detector_error_model dem_from_stim_text(const std::string &dem_text, + bool decompose_errors) { auto dem = [&dem_text]() { try { return stim::DetectorErrorModel(dem_text); @@ -50,11 +51,10 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { std::to_string(prob) + " out of range [0, 1] at instruction index " + std::to_string(instruction_index)); + std::vector dets; std::vector obs; - for (const auto &target : inst.target_data) { - if (target.is_separator()) - continue; + auto push_target = [&](const stim::DemTarget &target) { if (target.is_relative_detector_id()) { dets.push_back(static_cast(target.val())); } else if (target.is_observable_id()) { @@ -66,10 +66,38 @@ detector_error_model dem_from_stim_text(const std::string &dem_text) { ") contains an unsupported target kind; only D* (detector) and " "L* (observable) targets are supported by the fallback parser"); } + }; + + if (decompose_errors) { + // Each segment delimited by '^' in the DEM text becomes its own column. + auto flush = [&]() { + if (!dets.empty() || !obs.empty()) { + detector_hits.push_back(dets); + observable_hits.push_back(obs); + rates.push_back(prob); + dets.clear(); + obs.clear(); + } + }; + for (const auto &target : inst.target_data) { + if (target.is_separator()) { + flush(); + } else { + push_target(target); + } + } + flush(); + } else { + // Ignore '^' separators; all targets become a single column. + for (const auto &target : inst.target_data) { + if (target.is_separator()) + continue; + push_target(target); + } + detector_hits.push_back(std::move(dets)); + observable_hits.push_back(std::move(obs)); + rates.push_back(prob); } - detector_hits.push_back(std::move(dets)); - observable_hits.push_back(std::move(obs)); - rates.push_back(prob); ++instruction_index; }); diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 57078a28..e8482d05 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -766,10 +766,15 @@ void bindDecoder(nb::module_ &mod) { )pbdoc", nb::arg("num_syndromes_per_round")); - qecmod.def( - "dem_from_stim_text", &dem_from_stim_text, - "Parse a Stim detector error model string into a DetectorErrorModel.", - nb::arg("dem_text")); + qecmod.def("dem_from_stim_text", &dem_from_stim_text, + R"pbdoc( + Parse a Stim detector error model string into a DetectorErrorModel. + + Args: + dem_text: A Stim detector error model string. + decompose_errors: If error mechanism separated by ``^`` are decomposed + )pbdoc", + nb::arg("dem_text"), nb::arg("decompose_errors") = false); // Expose decorator function that handles inheritance qecmod.def("decoder", [&](const std::string &name) { diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index 9f189619..d389f2a0 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -852,6 +852,84 @@ def test_dem_from_stim_text_explicit_parse_then_get_decoder(): assert decoder.get_block_size() == 3 +def test_dem_from_stim_text_decompose_errors(): + dem_text = ("error(0.05) D0 D1 L0\n" + "error(0.03) D2 L1\n" + "error(0.1) D0 D2 ^ D1 D3\n") + + # ── decompose_errors=False + dem_no = qec.dem_from_stim_text(dem_text, decompose_errors=False) + assert dem_no.num_detectors() == 4 + assert dem_no.num_observables() == 2 + assert dem_no.num_error_mechanisms() == 3 + + # Also confirm that the default matches explicit False. + assert qec.dem_from_stim_text(dem_text).num_error_mechanisms() == 3 + + explicit_H_no = np.array([[1, 0, 1], [1, 0, 1], [0, 1, 1], [0, 0, 1]], + dtype=np.uint8) + explicit_O_no = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.uint8) + + np.testing.assert_array_equal( + np.array(dem_no.detector_error_matrix, dtype=np.uint8), explicit_H_no) + np.testing.assert_array_equal( + np.array(dem_no.observables_flips_matrix, dtype=np.uint8), + explicit_O_no) + np.testing.assert_allclose(dem_no.error_rates, [0.05, 0.03, 0.1], + atol=1e-12) + + # ── decompose_errors=True + dem_yes = qec.dem_from_stim_text(dem_text, decompose_errors=True) + assert dem_yes.num_detectors() == 4 + assert dem_yes.num_observables() == 2 + assert dem_yes.num_error_mechanisms() == 4 # instruction 3 splits into 2 + + explicit_H_yes = np.array( + [[1, 0, 1, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 0, 1]], + dtype=np.uint8) + explicit_O_yes = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.uint8) + + np.testing.assert_array_equal( + np.array(dem_yes.detector_error_matrix, dtype=np.uint8), explicit_H_yes) + np.testing.assert_array_equal( + np.array(dem_yes.observables_flips_matrix, dtype=np.uint8), + explicit_O_yes) + np.testing.assert_allclose(dem_yes.error_rates, [0.05, 0.03, 0.1, 0.1], + atol=1e-12) + + +def test_dem_from_stim_text_decompose_errors_edge_cases(): + A = lambda d: np.array(d, dtype=np.uint8) + + # 1. No '^' in DEM — decompose_errors=True must be a no-op. + dem_text = "error(0.1) D0 D1 L0\nerror(0.2) D1 D2\n" + no = qec.dem_from_stim_text(dem_text, decompose_errors=False) + yes = qec.dem_from_stim_text(dem_text, decompose_errors=True) + np.testing.assert_array_equal(A(no.detector_error_matrix), + A(yes.detector_error_matrix)) + np.testing.assert_array_equal(A(no.observables_flips_matrix), + A(yes.observables_flips_matrix)) + np.testing.assert_allclose(no.error_rates, yes.error_rates, atol=1e-12) + + # 2. Observable flips split across components — each L stays with its '^' segment. + # error(0.1) D0 L0 ^ D1 L1 → col0: D0+L0, col1: D1+L1 + dem_text = "error(0.1) D0 L0 ^ D1 L1\n" + dem = qec.dem_from_stim_text(dem_text, decompose_errors=True) + assert dem.num_error_mechanisms() == 2 + np.testing.assert_array_equal(A(dem.detector_error_matrix), + A([[1, 0], [0, 1]])) + np.testing.assert_array_equal(A(dem.observables_flips_matrix), + A([[1, 0], [0, 1]])) + + # 3. Repeated detector within one component XOR-cancels to 0. + # error(0.1) D0 D0 ^ D1 → col0: D0 appears twice → cancels; col1: D1 + dem_text = "error(0.1) D0 D0 ^ D1\n" + dem = qec.dem_from_stim_text(dem_text, decompose_errors=True) + assert dem.num_error_mechanisms() == 2 + np.testing.assert_array_equal(A(dem.detector_error_matrix), + A([[0, 0], [0, 1]])) + + def test_get_decoder_rejects_malformed_stim_dem_text(): with pytest.raises(RuntimeError): qec.get_decoder("single_error_lut", "not a valid DEM")