diff --git a/lib/BUILD b/lib/BUILD index 60fd0e51a..8001c046b 100644 --- a/lib/BUILD +++ b/lib/BUILD @@ -55,11 +55,13 @@ cc_library( "simulator_avx.h", "simulator_avx512.h", "simulator_basic.h", + "simulator_neon.h", "simulator_sse.h", "statespace.h", "statespace_avx.h", "statespace_avx512.h", "statespace_basic.h", + "statespace_neon.h", "statespace_sse.h", "umux.h", "unitary_calculator_avx.h", @@ -130,6 +132,7 @@ cuda_library( "simulator_basic.h", "simulator_cuda.h", "simulator_cuda_kernels.h", + "simulator_neon.h", "simulator_sse.h", "statespace.h", "statespace_avx.h", @@ -137,6 +140,7 @@ cuda_library( "statespace_basic.h", "statespace_cuda.h", "statespace_cuda_kernels.h", + "statespace_neon.h", "statespace_sse.h", "umux.h", "unitary_calculator_avx.h", @@ -201,6 +205,7 @@ cuda_library( "simulator_basic.h", "simulator_custatevec.h", "simulator_custatevecex.h", + "simulator_neon.h", "simulator_sse.h", "statespace.h", "statespace_avx.h", @@ -208,6 +213,7 @@ cuda_library( "statespace_basic.h", "statespace_custatevec.h", "statespace_custatevecex.h", + "statespace_neon.h", "statespace_sse.h", "umux.h", "unitary_calculator_avx.h", @@ -263,11 +269,13 @@ cc_library( "simulator_avx.h", "simulator_avx512.h", "simulator_basic.h", + "simulator_neon.h", "simulator_sse.h", "statespace.h", "statespace_avx.h", "statespace_avx512.h", "statespace_basic.h", + "statespace_neon.h", "statespace_sse.h", "umux.h", "unitary_calculator_avx.h", @@ -310,11 +318,13 @@ cc_library( "simulator_avx.h", "simulator_avx512.h", "simulator_basic.h", + "simulator_neon.h", "simulator_sse.h", "statespace.h", "statespace_avx.h", "statespace_avx512.h", "statespace_basic.h", + "statespace_neon.h", "statespace_sse.h", "util.h", "util_cpu.h", @@ -602,6 +612,16 @@ cc_library( ], ) +cc_library( + name = "statespace_neon", + hdrs = ["statespace_neon.h"], + deps = [ + ":statespace", + ":util", + ":vectorspace", + ], +) + cc_library( name = "statespace_sse", hdrs = ["statespace_sse.h"], @@ -686,6 +706,17 @@ 
cc_library( ], ) +cc_library( + name = "simulator_neon", + hdrs = ["simulator_neon.h"], + deps = [ + ":simulator_base", + ":simulator_basic", + ":statespace_basic", + ":statespace_neon", + ], +) + cc_library( name = "simulator_sse", hdrs = ["simulator_sse.h"], @@ -741,6 +772,7 @@ cc_library( ":simulator_avx", ":simulator_avx512", ":simulator_basic", + ":simulator_neon", ":simulator_sse", ], ) diff --git a/lib/simmux.h b/lib/simmux.h index d3c4074ef..b56335d3b 100644 --- a/lib/simmux.h +++ b/lib/simmux.h @@ -15,7 +15,13 @@ #ifndef SIMMUX_H_ #define SIMMUX_H_ -#ifdef __AVX512F__ +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +# include "simulator_neon.h" + namespace qsim { + template + using Simulator = SimulatorNEON; + } +#elif defined(__AVX512F__) # include "simulator_avx512.h" namespace qsim { template diff --git a/lib/simulator_neon.h b/lib/simulator_neon.h new file mode 100644 index 000000000..4de4a47e4 --- /dev/null +++ b/lib/simulator_neon.h @@ -0,0 +1,931 @@ +// Copyright 2026 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SIMULATOR_NEON_H_ +#define SIMULATOR_NEON_H_ + +#include +#include +#include +#include + +#include "simulator.h" +#include "statespace_neon.h" + +namespace qsim { + +/** + * Quantum circuit simulator with NEON vectorization. 
+ */ +template +class SimulatorNEON final : public SimulatorBase { + public: + using StateSpace = StateSpaceNEON; + using State = typename StateSpace::State; + using fp_type = typename StateSpace::fp_type; + + template + explicit SimulatorNEON(ForArgs&&... args) : for_(args...) {} + + /** + * Applies a gate using NEON instructions. + * @param qs Indices of the qubits affected by this gate. + * @param matrix Matrix representation of the gate to be applied. + * @param state The state of the system, to be updated by this method. + */ + void ApplyGate(const std::vector& qs, + const fp_type* matrix, State& state) const { + // Assume qs[0] < qs[1] < qs[2] < ... . + + switch (qs.size()) { + case 0: + ApplyGateH<0>(qs, matrix, state); + return; + case 1: + if (qs[0] > 1) { + ApplyGateH<1>(qs, matrix, state); + } else { + ApplyGateL<0, 1>(qs, matrix, state); + } + return; + case 2: + if (qs[0] > 1) { + ApplyGateH<2>(qs, matrix, state); + } else if (qs[1] > 1) { + ApplyGateL<1, 1>(qs, matrix, state); + } else { + ApplyGateL<0, 2>(qs, matrix, state); + } + return; + case 3: + if (qs[0] > 1) { + ApplyGateH<3>(qs, matrix, state); + } else if (qs[1] > 1) { + ApplyGateL<2, 1>(qs, matrix, state); + } else { + ApplyGateL<1, 2>(qs, matrix, state); + } + return; + case 4: + if (qs[0] > 1) { + ApplyGateH<4>(qs, matrix, state); + } else if (qs[1] > 1) { + ApplyGateL<3, 1>(qs, matrix, state); + } else { + ApplyGateL<2, 2>(qs, matrix, state); + } + return; + case 5: + if (qs[0] > 1) { + ApplyGateH<5>(qs, matrix, state); + } else if (qs[1] > 1) { + ApplyGateL<4, 1>(qs, matrix, state); + } else { + ApplyGateL<3, 2>(qs, matrix, state); + } + return; + case 6: + if (qs[0] > 1) { + ApplyGateH<6>(qs, matrix, state); + } else if (qs[1] > 1) { + ApplyGateL<5, 1>(qs, matrix, state); + } else { + ApplyGateL<4, 2>(qs, matrix, state); + } + return; + default: + break; + } + } + + /** + * Applies a controlled gate using NEON instructions. + * @param qs Indices of the qubits affected by this gate. 
+ * @param cqs Indices of control qubits. + * @param cvals Bit mask of control qubit values. + * @param matrix Matrix representation of the gate to be applied. + * @param state The state of the system, to be updated by this method. + */ + void ApplyControlledGate(const std::vector& qs, + const std::vector& cqs, uint64_t cvals, + const fp_type* matrix, State& state) const { + // Assume qs[0] < qs[1] < qs[2] < ... . + // Assume cqs[0] < cqs[1] < cqs[2] < ... . + + if (cqs.empty()) { + ApplyGate(qs, matrix, state); + return; + } + + switch (qs.size()) { + case 0: + if (cqs[0] > 1) { + ApplyControlledGateHH<0>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<0>(qs, cqs, cvals, matrix, state); + } + return; + case 1: + if (qs[0] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateHH<1>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<1>(qs, cqs, cvals, matrix, state); + } + } else { + if (cqs[0] > 1) { + ApplyControlledGateL<0, 1, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<0, 1, false>(qs, cqs, cvals, matrix, state); + } + } + return; + case 2: + if (qs[0] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateHH<2>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<2>(qs, cqs, cvals, matrix, state); + } + } else if (qs[1] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateL<1, 1, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<1, 1, false>(qs, cqs, cvals, matrix, state); + } + } else { + if (cqs[0] > 1) { + ApplyControlledGateL<0, 2, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<0, 2, false>(qs, cqs, cvals, matrix, state); + } + } + return; + case 3: + if (qs[0] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateHH<3>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<3>(qs, cqs, cvals, matrix, state); + } + } else if (qs[1] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateL<2, 1, true>(qs, cqs, cvals, matrix, state); + } else { + 
ApplyControlledGateL<2, 1, false>(qs, cqs, cvals, matrix, state); + } + } else { + if (cqs[0] > 1) { + ApplyControlledGateL<1, 2, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<1, 2, false>(qs, cqs, cvals, matrix, state); + } + } + return; + case 4: + if (qs[0] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateHH<4>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<4>(qs, cqs, cvals, matrix, state); + } + } else if (qs[1] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateL<3, 1, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<3, 1, false>(qs, cqs, cvals, matrix, state); + } + } else { + if (cqs[0] > 1) { + ApplyControlledGateL<2, 2, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<2, 2, false>(qs, cqs, cvals, matrix, state); + } + } + return; + case 5: + if (qs[0] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateHH<5>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<5>(qs, cqs, cvals, matrix, state); + } + } else if (qs[1] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateL<4, 1, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<4, 1, false>(qs, cqs, cvals, matrix, state); + } + } else { + if (cqs[0] > 1) { + ApplyControlledGateL<3, 2, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<3, 2, false>(qs, cqs, cvals, matrix, state); + } + } + return; + case 6: + if (qs[0] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateHH<6>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateHL<6>(qs, cqs, cvals, matrix, state); + } + } else if (qs[1] > 1) { + if (cqs[0] > 1) { + ApplyControlledGateL<5, 1, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<5, 1, false>(qs, cqs, cvals, matrix, state); + } + } else { + if (cqs[0] > 1) { + ApplyControlledGateL<4, 2, true>(qs, cqs, cvals, matrix, state); + } else { + ApplyControlledGateL<4, 2, false>(qs, cqs, cvals, matrix, state); + } + } + return; + default: + 
break; + } + } + + std::complex ExpectationValue( + const std::vector& qs, const fp_type* matrix, + const State& state) const { + switch (qs.size()) { + case 1: + if (qs[0] > 1) { + return ExpectationValueH<1>(qs, matrix, state); + } else { + return ExpectationValueL<0, 1>(qs, matrix, state); + } + case 2: + if (qs[0] > 1) { + return ExpectationValueH<2>(qs, matrix, state); + } else if (qs[1] > 1) { + return ExpectationValueL<1, 1>(qs, matrix, state); + } else { + return ExpectationValueL<0, 2>(qs, matrix, state); + } + case 3: + if (qs[0] > 1) { + return ExpectationValueH<3>(qs, matrix, state); + } else if (qs[1] > 1) { + return ExpectationValueL<2, 1>(qs, matrix, state); + } else { + return ExpectationValueL<1, 2>(qs, matrix, state); + } + case 4: + if (qs[0] > 1) { + return ExpectationValueH<4>(qs, matrix, state); + } else if (qs[1] > 1) { + return ExpectationValueL<3, 1>(qs, matrix, state); + } else { + return ExpectationValueL<2, 2>(qs, matrix, state); + } + case 5: + if (qs[0] > 1) { + return ExpectationValueH<5>(qs, matrix, state); + } else if (qs[1] > 1) { + return ExpectationValueL<4, 1>(qs, matrix, state); + } else { + return ExpectationValueL<3, 2>(qs, matrix, state); + } + case 6: + if (qs[0] > 1) { + return ExpectationValueH<6>(qs, matrix, state); + } else if (qs[1] > 1) { + return ExpectationValueL<5, 1>(qs, matrix, state); + } else { + return ExpectationValueL<4, 2>(qs, matrix, state); + } + default: + break; + } + return 0; + } + + static constexpr unsigned SIMDRegisterSize() { + return sizeof(float32x4_t) / sizeof(float32_t); + } + + private: + struct Complex { + float32x4_t re; + float32x4_t im; + }; + + template + static Complex ApplyGateRow( + const float32x4_t* state_re, const float32x4_t* state_im, + GateCoeff gate_coeff) { + auto gate = gate_coeff(0); + auto re = vmulq_f32(state_re[0], gate.re); + auto im = vmulq_f32(state_re[0], gate.im); + re = vfmsq_f32(re, state_im[0], gate.im); + im = vfmaq_f32(im, state_im[0], gate.re); + + for 
(unsigned in = 1; in < Size; ++in) { + // Complex MAC: out += gate * state. + // re += state.re * gate.re - state.im * gate.im + // im += state.re * gate.im + state.im * gate.re + gate = gate_coeff(in); + re = vfmaq_f32(re, state_re[in], gate.re); + im = vfmaq_f32(im, state_re[in], gate.im); + re = vfmsq_f32(re, state_im[in], gate.im); + im = vfmaq_f32(im, state_im[in], gate.re); + } + + return Complex{re, im}; + } + + static void StoreStateAmplitudeRow( + fp_type* state_block, const uint64_t* state_offsets, + unsigned output_basis, const Complex& output_amplitudes) { + const auto addr_re = state_block + state_offsets[output_basis]; + const auto addr_im = addr_re + SIMDRegisterSize(); + vst1q_f32(addr_re, output_amplitudes.re); + vst1q_f32(addr_im, output_amplitudes.im); + } + + template + void ApplyGateH( + const std::vector& qs, const fp_type* matrix, + State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* gate_matrix, const uint64_t* masks, + const uint64_t* state_offsets, fp_type* state_data) { + constexpr unsigned hsize = 1 << H; + + float32x4_t state_re[hsize]; + float32x4_t state_im[hsize]; + + i *= 4; + + uint64_t ii = i & masks[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & masks[j]; + } + + auto block = state_data + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + state_re[k] = vld1q_f32(block + state_offsets[k]); + state_im[k] = vld1q_f32(block + state_offsets[k] + 4); + } + + auto load_gate_row = [](const fp_type* gate_row) { + return [gate_row](unsigned in) { + return Complex{ + vdupq_n_f32(gate_row[2 * in]), + vdupq_n_f32(gate_row[2 * in + 1]), + }; + }; + }; + + unsigned out = 0; + for (; out + 1 < hsize; out += 2) { + const fp_type* gate_row0 = gate_matrix + 2 * out * hsize; + const fp_type* gate_row1 = gate_row0 + 2 * hsize; + + auto out0 = ApplyGateRow(state_re, state_im, load_gate_row(gate_row0)); + auto out1 = ApplyGateRow(state_re, state_im, load_gate_row(gate_row1)); + + 
StoreStateAmplitudeRow(block, state_offsets, out, out0); + StoreStateAmplitudeRow(block, state_offsets, out + 1, out1); + } + + for (; out < hsize; ++out) { + const fp_type* gate_row = gate_matrix + 2 * out * hsize; + + auto out_row = ApplyGateRow(state_re, state_im, load_gate_row(gate_row)); + + StoreStateAmplitudeRow(block, state_offsets, out, out_row); + } + }; + + uint64_t ms[H + 1]; + uint64_t xss[1 << H]; + + FillIndices(state.num_qubits(), qs, ms, xss); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + for_.Run(size, f, matrix, ms, xss, state.get()); + } + + template + void ApplyGateL( + const std::vector& qs, const fp_type* matrix, + State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* w, + const uint64_t* ms, const uint64_t* xss, unsigned q0, + fp_type* rstate) { + constexpr unsigned gsize = 1 << (H + L); + constexpr unsigned hsize = 1 << H; + constexpr unsigned lsize = 1 << L; + + float32x4_t state_re[gsize]; + float32x4_t state_im[gsize]; + + i *= 4; + + uint64_t ii = i & ms[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & ms[j]; + } + + auto p0 = rstate + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + unsigned k2 = lsize * k; + + state_re[k2] = vld1q_f32(p0 + xss[k]); + state_im[k2] = vld1q_f32(p0 + xss[k] + 4); + + if (L == 1) { + state_re[k2 + 1] = + q0 == 0 ? vrev64q_f32(state_re[k2]) + : vextq_f32(state_re[k2], state_re[k2], 2); + state_im[k2 + 1] = + q0 == 0 ? 
vrev64q_f32(state_im[k2]) + : vextq_f32(state_im[k2], state_im[k2], 2); + } else if (L == 2) { + state_re[k2 + 1] = vextq_f32(state_re[k2], state_re[k2], 1); + state_im[k2 + 1] = vextq_f32(state_im[k2], state_im[k2], 1); + state_re[k2 + 2] = vextq_f32(state_re[k2], state_re[k2], 2); + state_im[k2 + 2] = vextq_f32(state_im[k2], state_im[k2], 2); + state_re[k2 + 3] = vextq_f32(state_re[k2], state_re[k2], 3); + state_im[k2 + 3] = vextq_f32(state_im[k2], state_im[k2], 3); + } + } + + auto load_gate_row = [](const fp_type* gate_row) { + return [gate_row](unsigned in) { + return Complex{ + vld1q_f32(gate_row + 8 * in), + vld1q_f32(gate_row + 8 * in + 4), + }; + }; + }; + + for (unsigned k = 0; k < hsize; ++k) { + const fp_type* gate_row = w + 8 * k * gsize; + auto out_row = ApplyGateRow( + state_re, state_im, load_gate_row(gate_row)); + + StoreStateAmplitudeRow(p0, xss, k, out_row); + } + }; + + uint64_t ms[H + 1]; + uint64_t xss[1 << H]; + alignas(16) fp_type w[1 << (3 + 2 * H + L)]; + + auto m = GetMasks11(qs); + + FillIndices(state.num_qubits(), qs, ms, xss); + FillMatrix(m.qmaskl, matrix, w); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? 
state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + for_.Run(size, f, w, ms, xss, qs[0], state.get()); + } + + template + void ApplyControlledGateHH( + const std::vector& qs, const std::vector& cqs, + uint64_t cvals, const fp_type* matrix, State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* v, + const uint64_t* ms, const uint64_t* xss, uint64_t cvalsh, + uint64_t cmaskh, fp_type* rstate) { + constexpr unsigned hsize = 1 << H; + + float32x4_t rs[hsize]; + float32x4_t is[hsize]; + + i *= 4; + + uint64_t ii = i & ms[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & ms[j]; + } + + if ((ii & cmaskh) != cvalsh) return; + + auto p0 = rstate + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + rs[k] = vld1q_f32(p0 + xss[k]); + is[k] = vld1q_f32(p0 + xss[k] + 4); + } + + uint64_t j = 0; + + for (unsigned k = 0; k < hsize; ++k) { + float32x4_t ru = vdupq_n_f32(v[j]); + float32x4_t iu = vdupq_n_f32(v[j + 1]); + float32x4_t rn = vmulq_f32(rs[0], ru); + float32x4_t in = vmulq_f32(rs[0], iu); + rn = vfmsq_f32(rn, is[0], iu); + in = vfmaq_f32(in, is[0], ru); + + j += 2; + + for (unsigned l = 1; l < hsize; ++l) { + ru = vdupq_n_f32(v[j]); + iu = vdupq_n_f32(v[j + 1]); + rn = vfmaq_f32(rn, rs[l], ru); + in = vfmaq_f32(in, rs[l], iu); + rn = vfmsq_f32(rn, is[l], iu); + in = vfmaq_f32(in, is[l], ru); + + j += 2; + } + + vst1q_f32(p0 + xss[k], rn); + vst1q_f32(p0 + xss[k] + 4, in); + } + }; + + uint64_t ms[H + 1]; + uint64_t xss[1 << H]; + + auto m = GetMasks7(state.num_qubits(), qs, cqs, cvals); + FillIndices(state.num_qubits(), qs, ms, xss); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? 
state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + for_.Run(size, f, matrix, ms, xss, m.cvalsh, m.cmaskh, state.get()); + } + + template + void ApplyControlledGateHL( + const std::vector& qs, const std::vector& cqs, + uint64_t cvals, const fp_type* matrix, State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* w, + const uint64_t* ms, const uint64_t* xss, uint64_t cvalsh, + uint64_t cmaskh, fp_type* rstate) { + constexpr unsigned hsize = 1 << H; + + float32x4_t rs[hsize]; + float32x4_t is[hsize]; + + i *= 4; + + uint64_t ii = i & ms[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & ms[j]; + } + + if ((ii & cmaskh) != cvalsh) return; + + auto p0 = rstate + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + rs[k] = vld1q_f32(p0 + xss[k]); + is[k] = vld1q_f32(p0 + xss[k] + 4); + } + + uint64_t j = 0; + + for (unsigned k = 0; k < hsize; ++k) { + float32x4_t wre = vld1q_f32(w + 4 * j); + float32x4_t wim = vld1q_f32(w + 4 * (j + 1)); + float32x4_t rn = vmulq_f32(rs[0], wre); + float32x4_t in = vmulq_f32(rs[0], wim); + rn = vfmsq_f32(rn, is[0], wim); + in = vfmaq_f32(in, is[0], wre); + + j += 2; + + for (unsigned l = 1; l < hsize; ++l) { + wre = vld1q_f32(w + 4 * j); + wim = vld1q_f32(w + 4 * (j + 1)); + rn = vfmaq_f32(rn, rs[l], wre); + in = vfmaq_f32(in, rs[l], wim); + rn = vfmsq_f32(rn, is[l], wim); + in = vfmaq_f32(in, is[l], wre); + + j += 2; + } + + vst1q_f32(p0 + xss[k], rn); + vst1q_f32(p0 + xss[k] + 4, in); + } + }; + + uint64_t ms[H + 1]; + uint64_t xss[1 << H]; + alignas(16) fp_type w[1 << (3 + 2 * H)]; + + auto m = GetMasks8<2>(state.num_qubits(), qs, cqs, cvals); + FillIndices(state.num_qubits(), qs, ms, xss); + FillControlledMatrixH(m.cvalsl, m.cmaskl, matrix, w); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? 
state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + for_.Run(size, f, w, ms, xss, m.cvalsh, m.cmaskh, state.get()); + } + + template + void ApplyControlledGateL( + const std::vector& qs, const std::vector& cqs, + uint64_t cvals, const fp_type* matrix, State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* w, + const uint64_t* ms, const uint64_t* xss, uint64_t cvalsh, + uint64_t cmaskh, unsigned q0, fp_type* rstate) { + constexpr unsigned gsize = 1 << (H + L); + constexpr unsigned hsize = 1 << H; + constexpr unsigned lsize = 1 << L; + + float32x4_t rs[gsize]; + float32x4_t is[gsize]; + + i *= 4; + + uint64_t ii = i & ms[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & ms[j]; + } + + if ((ii & cmaskh) != cvalsh) return; + + auto p0 = rstate + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + unsigned k2 = lsize * k; + + rs[k2] = vld1q_f32(p0 + xss[k]); + is[k2] = vld1q_f32(p0 + xss[k] + 4); + + if (L == 1) { + rs[k2 + 1] = + q0 == 0 ? vrev64q_f32(rs[k2]) : vextq_f32(rs[k2], rs[k2], 2); + is[k2 + 1] = + q0 == 0 ? 
vrev64q_f32(is[k2]) : vextq_f32(is[k2], is[k2], 2); + } else if (L == 2) { + rs[k2 + 1] = vextq_f32(rs[k2], rs[k2], 1); + is[k2 + 1] = vextq_f32(is[k2], is[k2], 1); + rs[k2 + 2] = vextq_f32(rs[k2], rs[k2], 2); + is[k2 + 2] = vextq_f32(is[k2], is[k2], 2); + rs[k2 + 3] = vextq_f32(rs[k2], rs[k2], 3); + is[k2 + 3] = vextq_f32(is[k2], is[k2], 3); + } + } + + uint64_t j = 0; + + for (unsigned k = 0; k < hsize; ++k) { + float32x4_t wre = vld1q_f32(w + 4 * j); + float32x4_t wim = vld1q_f32(w + 4 * (j + 1)); + float32x4_t rn = vmulq_f32(rs[0], wre); + float32x4_t in = vmulq_f32(rs[0], wim); + rn = vfmsq_f32(rn, is[0], wim); + in = vfmaq_f32(in, is[0], wre); + + j += 2; + + for (unsigned l = 1; l < gsize; ++l) { + wre = vld1q_f32(w + 4 * j); + wim = vld1q_f32(w + 4 * (j + 1)); + rn = vfmaq_f32(rn, rs[l], wre); + in = vfmaq_f32(in, rs[l], wim); + rn = vfmsq_f32(rn, is[l], wim); + in = vfmaq_f32(in, is[l], wre); + + j += 2; + } + + vst1q_f32(p0 + xss[k], rn); + vst1q_f32(p0 + xss[k] + 4, in); + } + }; + + uint64_t ms[H + 1]; + uint64_t xss[1 << H]; + alignas(16) fp_type w[1 << (3 + 2 * H + L)]; + + FillIndices(state.num_qubits(), qs, ms, xss); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? 
state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + if (CH) { + auto m = GetMasks9(state.num_qubits(), qs, cqs, cvals); + FillMatrix(m.qmaskl, matrix, w); + + for_.Run(size, f, w, ms, xss, m.cvalsh, m.cmaskh, qs[0], state.get()); + } else { + auto m = GetMasks10(state.num_qubits(), qs, cqs, cvals); + FillControlledMatrixL(m.cvalsl, m.cmaskl, m.qmaskl, matrix, w); + + for_.Run(size, f, w, ms, xss, m.cvalsh, m.cmaskh, qs[0], state.get()); + } + } + + template + std::complex ExpectationValueH( + const std::vector& qs, const fp_type* matrix, + const State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* v, + const uint64_t* ms, const uint64_t* xss, + const fp_type* rstate) -> std::complex { + constexpr unsigned hsize = 1 << H; + + float32x4_t rs[hsize]; + float32x4_t is[hsize]; + + i *= 4; + + uint64_t ii = i & ms[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & ms[j]; + } + + auto p0 = rstate + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + rs[k] = vld1q_f32(p0 + xss[k]); + is[k] = vld1q_f32(p0 + xss[k] + 4); + } + + double re = 0; + double im = 0; + uint64_t j = 0; + + for (unsigned k = 0; k < hsize; ++k) { + float32x4_t ru = vdupq_n_f32(v[j]); + float32x4_t iu = vdupq_n_f32(v[j + 1]); + float32x4_t rn = vmulq_f32(rs[0], ru); + float32x4_t in = vmulq_f32(rs[0], iu); + rn = vfmsq_f32(rn, is[0], iu); + in = vfmaq_f32(in, is[0], ru); + + j += 2; + + for (unsigned l = 1; l < hsize; ++l) { + ru = vdupq_n_f32(v[j]); + iu = vdupq_n_f32(v[j + 1]); + rn = vfmaq_f32(rn, rs[l], ru); + in = vfmaq_f32(in, rs[l], iu); + rn = vfmsq_f32(rn, is[l], iu); + in = vfmaq_f32(in, is[l], ru); + j += 2; + } + + float32x4_t v_re = vmulq_f32(rs[k], rn); + v_re = vfmaq_f32(v_re, is[k], in); + float32x4_t v_im = vmulq_f32(rs[k], in); + v_im = vfmsq_f32(v_im, is[k], rn); + + re += detail::HorizontalSumNEON(v_re); + im += detail::HorizontalSumNEON(v_im); + } + + return std::complex{re, im}; + }; + + uint64_t ms[H + 1]; + 
uint64_t xss[1 << H]; + + FillIndices(state.num_qubits(), qs, ms, xss); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + using Op = std::plus>; + return for_.RunReduce(size, f, Op(), matrix, ms, xss, state.get()); + } + + template + std::complex ExpectationValueL( + const std::vector& qs, const fp_type* matrix, + const State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, const fp_type* w, + const uint64_t* ms, const uint64_t* xss, unsigned q0, + const fp_type* rstate) -> std::complex { + constexpr unsigned gsize = 1 << (H + L); + constexpr unsigned hsize = 1 << H; + constexpr unsigned lsize = 1 << L; + + float32x4_t rs[gsize]; + float32x4_t is[gsize]; + + i *= 4; + + uint64_t ii = i & ms[0]; + for (unsigned j = 1; j <= H; ++j) { + i *= 2; + ii |= i & ms[j]; + } + + auto p0 = rstate + 2 * ii; + + for (unsigned k = 0; k < hsize; ++k) { + unsigned k2 = lsize * k; + + rs[k2] = vld1q_f32(p0 + xss[k]); + is[k2] = vld1q_f32(p0 + xss[k] + 4); + + if (L == 1) { + rs[k2 + 1] = + q0 == 0 ? vrev64q_f32(rs[k2]) : vextq_f32(rs[k2], rs[k2], 2); + is[k2 + 1] = + q0 == 0 ? 
vrev64q_f32(is[k2]) : vextq_f32(is[k2], is[k2], 2); + } else if (L == 2) { + rs[k2 + 1] = vextq_f32(rs[k2], rs[k2], 1); + is[k2 + 1] = vextq_f32(is[k2], is[k2], 1); + rs[k2 + 2] = vextq_f32(rs[k2], rs[k2], 2); + is[k2 + 2] = vextq_f32(is[k2], is[k2], 2); + rs[k2 + 3] = vextq_f32(rs[k2], rs[k2], 3); + is[k2 + 3] = vextq_f32(is[k2], is[k2], 3); + } + } + + double re = 0; + double im = 0; + uint64_t j = 0; + + for (unsigned k = 0; k < hsize; ++k) { + float32x4_t wre = vld1q_f32(w + 4 * j); + float32x4_t wim = vld1q_f32(w + 4 * (j + 1)); + float32x4_t rn = vmulq_f32(rs[0], wre); + float32x4_t in = vmulq_f32(rs[0], wim); + rn = vfmsq_f32(rn, is[0], wim); + in = vfmaq_f32(in, is[0], wre); + j += 2; + + for (unsigned l = 1; l < gsize; ++l) { + wre = vld1q_f32(w + 4 * j); + wim = vld1q_f32(w + 4 * (j + 1)); + rn = vfmaq_f32(rn, rs[l], wre); + in = vfmaq_f32(in, rs[l], wim); + rn = vfmsq_f32(rn, is[l], wim); + in = vfmaq_f32(in, is[l], wre); + j += 2; + } + + unsigned m = lsize * k; + float32x4_t v_re = vmulq_f32(rs[m], rn); + v_re = vfmaq_f32(v_re, is[m], in); + float32x4_t v_im = vmulq_f32(rs[m], in); + v_im = vfmsq_f32(v_im, is[m], rn); + + re += detail::HorizontalSumNEON(v_re); + im += detail::HorizontalSumNEON(v_im); + } + + return std::complex{re, im}; + }; + + uint64_t ms[H + 1]; + uint64_t xss[1 << H]; + alignas(16) fp_type w[1 << (3 + 2 * H + L)]; + + auto m = GetMasks11(qs); + + FillIndices(state.num_qubits(), qs, ms, xss); + FillMatrix(m.qmaskl, matrix, w); + + const unsigned k = 2 + H; + const unsigned n = state.num_qubits() > k ? 
state.num_qubits() - k : 0; + const uint64_t size = uint64_t{1} << n; + + using Op = std::plus>; + return for_.RunReduce(size, f, Op(), w, ms, xss, qs[0], state.get()); + } + For for_; +}; + +} // namespace qsim + +#endif // SIMULATOR_NEON_H_ diff --git a/lib/statespace_neon.h b/lib/statespace_neon.h new file mode 100644 index 000000000..14480cd0b --- /dev/null +++ b/lib/statespace_neon.h @@ -0,0 +1,433 @@ +// Copyright 2026 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef STATESPACE_NEON_H_ +#define STATESPACE_NEON_H_ + +#include +#include +#include +#include +#include + +#if !defined(__ARM_NEON__) && !defined(__ARM_NEON) +#error "statespace_neon.h requires __ARM_NEON__." +#endif +#include + +#include "statespace.h" +#include "util.h" +#include "vectorspace.h" + +namespace qsim { + +namespace detail { + +inline double HorizontalSumNEON(float32x4_t s) { + float32x2_t s2 = vadd_f32(vget_low_f32(s), vget_high_f32(s)); + s2 = vpadd_f32(s2, s2); + return vget_lane_f32(s2, 0); +} + +inline uint32x4_t GetZeroMaskNEON(uint64_t i, uint64_t mask, uint64_t bits) { + alignas(16) uint32_t lanes[4]; + for (unsigned j = 0; j < 4; ++j) { + lanes[j] = ((i + j) & mask) == bits ? ~uint32_t{0} : 0; + } + return vld1q_u32(lanes); +} + +} // namespace detail + +/** + * Object containing context and routines for NEON state-vector manipulations. 
* State is a vectorized sequence of four real components followed by four + * imaginary components. Four single-precision floating numbers can be loaded + * into a NEON register. + */ +template +class StateSpaceNEON : + public StateSpace, VectorSpace, For, float> { + private: + using Base = StateSpace, qsim::VectorSpace, For, float>; + + public: + using State = typename Base::State; + using fp_type = typename Base::fp_type; + + template + explicit StateSpaceNEON(ForArgs&&... args) : Base(args...) {} + + static uint64_t MinSize(unsigned num_qubits) { + return std::max(uint64_t{8}, 2 * (uint64_t{1} << num_qubits)); + } + + void InternalToNormalOrder(State& state) const { + if (state.num_qubits() == 1) { + auto s = state.get(); + + s[2] = s[1]; + s[1] = s[4]; + s[3] = s[5]; + + for (uint64_t i = 4; i < 8; ++i) { + s[i] = 0; + } + } else { + auto f = [](unsigned n, unsigned m, uint64_t i, fp_type* p) { + auto s = p + 8 * i; + + fp_type re[3]; + fp_type im[3]; + + for (uint64_t j = 0; j < 3; ++j) { + re[j] = s[j + 1]; + im[j] = s[j + 4]; + } + + for (uint64_t j = 0; j < 3; ++j) { + s[2 * j + 1] = im[j]; + s[2 * j + 2] = re[j]; + } + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, state.get()); + } + } + + void NormalToInternalOrder(State& state) const { + if (state.num_qubits() == 1) { + auto s = state.get(); + + s[4] = s[1]; + s[1] = s[2]; + s[5] = s[3]; + + s[2] = 0; + s[3] = 0; + s[6] = 0; + s[7] = 0; + } else { + auto f = [](unsigned n, unsigned m, uint64_t i, fp_type* p) { + auto s = p + 8 * i; + + fp_type re[3]; + fp_type im[3]; + + for (uint64_t j = 0; j < 3; ++j) { + im[j] = s[2 * j + 1]; + re[j] = s[2 * j + 2]; + } + + for (uint64_t j = 0; j < 3; ++j) { + s[j + 1] = re[j]; + s[j + 4] = im[j]; + } + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, state.get()); + } + } + + void SetAllZeros(State& state) const { + float32x4_t zero = vdupq_n_f32(0.0f); + + auto f = [](unsigned n, unsigned m, uint64_t i, + float32x4_t zero, fp_type* p) { + 
vst1q_f32(p + 8 * i, zero); + vst1q_f32(p + 8 * i + 4, zero); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, zero, state.get()); + } + + // Uniform superposition. + void SetStateUniform(State& state) const { + fp_type v = fp_type{1} / std::sqrt(uint64_t{1} << state.num_qubits()); + + float32x4_t zero = vdupq_n_f32(0.0f); + float32x4_t valu; + + if (state.num_qubits() == 1) { + alignas(16) float lanes[4] = {v, v, 0, 0}; + valu = vld1q_f32(lanes); + } else { + valu = vdupq_n_f32(v); + } + + auto f = [](unsigned n, unsigned m, uint64_t i, + float32x4_t zero, float32x4_t valu, fp_type* p) { + vst1q_f32(p + 8 * i, valu); + vst1q_f32(p + 8 * i + 4, zero); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, zero, valu, state.get()); + } + + // |0> state. + void SetStateZero(State& state) const { + SetAllZeros(state); + state.get()[0] = 1; + } + + static std::complex GetAmpl(const State& state, uint64_t i) { + uint64_t p = 8 * (i / 4) + (i % 4); + return std::complex(state.get()[p], state.get()[p + 4]); + } + + static void SetAmpl( + State& state, uint64_t i, const std::complex& ampl) { + uint64_t p = 8 * (i / 4) + (i % 4); + state.get()[p] = std::real(ampl); + state.get()[p + 4] = std::imag(ampl); + } + + static void SetAmpl(State& state, uint64_t i, fp_type re, fp_type im) { + uint64_t p = 8 * (i / 4) + (i % 4); + state.get()[p] = re; + state.get()[p + 4] = im; + } + + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, + const std::complex& val, + bool exclude = false) const { + BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val), exclude); + } + + void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re, + fp_type im, bool exclude = false) const { + float32x4_t re_n = vdupq_n_f32(re); + float32x4_t im_n = vdupq_n_f32(im); + uint32x4_t exclude_n = exclude ? 
vdupq_n_u32(~uint32_t{0}) : vdupq_n_u32(0); + + auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv, + uint64_t bitsv, float32x4_t re_n, float32x4_t im_n, + uint32x4_t exclude_n, fp_type* p) { + uint32x4_t ml = veorq_u32(detail::GetZeroMaskNEON(4 * i, maskv, bitsv), + exclude_n); + float32x4_t re = vld1q_f32(p + 8 * i); + float32x4_t im = vld1q_f32(p + 8 * i + 4); + + re = vbslq_f32(ml, re_n, re); + im = vbslq_f32(ml, im_n, im); + + vst1q_f32(p + 8 * i, re); + vst1q_f32(p + 8 * i + 4, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_n, + im_n, exclude_n, state.get()); + } + + bool Add(const State& src, State& dest) const { + if (src.num_qubits() != dest.num_qubits()) return false; + + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, fp_type* p2) { + float32x4_t re1 = vld1q_f32(p1 + 8 * i); + float32x4_t im1 = vld1q_f32(p1 + 8 * i + 4); + float32x4_t re2 = vld1q_f32(p2 + 8 * i); + float32x4_t im2 = vld1q_f32(p2 + 8 * i + 4); + + vst1q_f32(p2 + 8 * i, vaddq_f32(re1, re2)); + vst1q_f32(p2 + 8 * i + 4, vaddq_f32(im1, im2)); + }; + + Base::for_.Run(MinSize(src.num_qubits()) / 8, f, src.get(), dest.get()); + return true; + } + + void Multiply(fp_type a, State& state) const { + float32x4_t r = vdupq_n_f32(a); + + auto f = [](unsigned n, unsigned m, uint64_t i, float32x4_t r, fp_type* p) { + float32x4_t re = vld1q_f32(p + 8 * i); + float32x4_t im = vld1q_f32(p + 8 * i + 4); + + vst1q_f32(p + 8 * i, vmulq_f32(re, r)); + vst1q_f32(p + 8 * i + 4, vmulq_f32(im, r)); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f, r, state.get()); + } + + std::complex InnerProduct( + const State& state1, const State& state2) const { + if (state1.num_qubits() != state2.num_qubits()) { + return std::nan(""); + } + + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> std::complex { + float32x4_t re1 = vld1q_f32(p1 + 8 * i); + float32x4_t im1 = vld1q_f32(p1 + 8 * i + 4); + float32x4_t 
re2 = vld1q_f32(p2 + 8 * i); + float32x4_t im2 = vld1q_f32(p2 + 8 * i + 4); + + float32x4_t ip_re = vmulq_f32(re1, re2); + ip_re = vfmaq_f32(ip_re, im1, im2); + float32x4_t ip_im = vmulq_f32(re1, im2); + ip_im = vfmsq_f32(ip_im, im1, re2); + + return std::complex(detail::HorizontalSumNEON(ip_re), + detail::HorizontalSumNEON(ip_im)); + }; + + using Op = std::plus>; + return Base::for_.RunReduce( + MinSize(state1.num_qubits()) / 8, f, Op(), state1.get(), state2.get()); + } + + double RealInnerProduct(const State& state1, const State& state2) const { + if (state1.num_qubits() != state2.num_qubits()) { + return std::nan(""); + } + + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p1, const fp_type* p2) -> double { + float32x4_t re1 = vld1q_f32(p1 + 8 * i); + float32x4_t im1 = vld1q_f32(p1 + 8 * i + 4); + float32x4_t re2 = vld1q_f32(p2 + 8 * i); + float32x4_t im2 = vld1q_f32(p2 + 8 * i + 4); + + float32x4_t ip_re = vmulq_f32(re1, re2); + ip_re = vfmaq_f32(ip_re, im1, im2); + return detail::HorizontalSumNEON(ip_re); + }; + + using Op = std::plus; + return Base::for_.RunReduce( + MinSize(state1.num_qubits()) / 8, f, Op(), state1.get(), state2.get()); + } + + template + std::vector Sample( + const State& state, uint64_t num_samples, unsigned seed) const { + std::vector bitstrings; + + if (num_samples > 0) { + double norm = this->Norm(state); + uint64_t size = MinSize(state.num_qubits()) / 8; + const fp_type* p = state.get(); + + auto rs = GenerateRandomValues(num_samples, seed, norm); + + uint64_t m = 0; + double csum = 0; + bitstrings.reserve(num_samples); + + for (uint64_t k = 0; k < size; ++k) { + for (unsigned j = 0; j < 4; ++j) { + double re = p[8 * k + j]; + double im = p[8 * k + 4 + j]; + csum += re * re + im * im; + while (m < num_samples && rs[m] < csum) { + bitstrings.emplace_back(4 * k + j); + ++m; + } + } + } + + for (; m < num_samples; ++m) { + bitstrings.emplace_back((uint64_t{1} << state.num_qubits()) - 1); + } + } + + return bitstrings; + } 
+ + using MeasurementResult = typename Base::MeasurementResult; + + void Collapse(const MeasurementResult& mr, State& state) const { + float32x4_t zero = vdupq_n_f32(0.0f); + + auto f1 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, + uint64_t bits, float32x4_t zero, + const fp_type* p) -> double { + uint32x4_t ml = detail::GetZeroMaskNEON(4 * i, mask, bits); + float32x4_t re = vld1q_f32(p + 8 * i); + float32x4_t im = vld1q_f32(p + 8 * i + 4); + float32x4_t s1 = vmulq_f32(re, re); + s1 = vfmaq_f32(s1, im, im); + s1 = vbslq_f32(ml, s1, zero); + return detail::HorizontalSumNEON(s1); + }; + + using Op = std::plus; + double norm = Base::for_.RunReduce(MinSize(state.num_qubits()) / 8, f1, + Op(), mr.mask, mr.bits, zero, + state.get()); + + float32x4_t renorm = vdupq_n_f32(1.0 / std::sqrt(norm)); + + auto f2 = [](unsigned n, unsigned m, uint64_t i, uint64_t mask, + uint64_t bits, float32x4_t renorm, float32x4_t zero, + fp_type* p) { + uint32x4_t ml = detail::GetZeroMaskNEON(4 * i, mask, bits); + float32x4_t re = vld1q_f32(p + 8 * i); + float32x4_t im = vld1q_f32(p + 8 * i + 4); + + re = vbslq_f32(ml, vmulq_f32(re, renorm), zero); + im = vbslq_f32(ml, vmulq_f32(im, renorm), zero); + + vst1q_f32(p + 8 * i, re); + vst1q_f32(p + 8 * i + 4, im); + }; + + Base::for_.Run(MinSize(state.num_qubits()) / 8, f2, + mr.mask, mr.bits, renorm, zero, state.get()); + } + + std::vector PartialNorms(const State& state) const { + auto f = [](unsigned n, unsigned m, uint64_t i, + const fp_type* p) -> double { + float32x4_t re = vld1q_f32(p + 8 * i); + float32x4_t im = vld1q_f32(p + 8 * i + 4); + float32x4_t s1 = vmulq_f32(re, re); + s1 = vfmaq_f32(s1, im, im); + return detail::HorizontalSumNEON(s1); + }; + + using Op = std::plus; + return Base::for_.RunReduceP( + MinSize(state.num_qubits()) / 8, f, Op(), state.get()); + } + + uint64_t FindMeasuredBits( + unsigned m, double r, uint64_t mask, const State& state) const { + double csum = 0; + + uint64_t k0 = 
Base::for_.GetIndex0(MinSize(state.num_qubits()) / 8, m); + uint64_t k1 = Base::for_.GetIndex1(MinSize(state.num_qubits()) / 8, m); + + const fp_type* p = state.get(); + + for (uint64_t k = k0; k < k1; ++k) { + for (uint64_t j = 0; j < 4; ++j) { + auto re = p[8 * k + j]; + auto im = p[8 * k + 4 + j]; + csum += re * re + im * im; + if (r < csum) { + return (4 * k + j) & mask; + } + } + } + + return (4 * k1 - 1) & mask; + } +}; + +} // namespace qsim + +#endif // STATESPACE_NEON_H_