From 7c13c6342d101d43fd56e2b4e7eb184537d7e9f0 Mon Sep 17 00:00:00 2001 From: jake massimo Date: Tue, 16 Jun 2026 17:45:15 +0000 Subject: [PATCH] x86_64: Replace rej_uniform_eta2/eta4 intrinsics with hand-written assembly Add hand-written x86_64 AVX2 assembly for rej_uniform_eta2 and rej_uniform_eta4 and remove the AVX2 intrinsics implementations they replace, following the rej_uniform approach in #1014: the table is passed as a parameter and all constants are built from immediates (no .rodata), enabling future HOL-Light verification. Both eta2 and eta4 are wired to the new asm in meta.h, with contracts in arith_native_x86_64.h, bytecode dump targets in autogen and the Makefile, and a poly_uniform_eta_4x component benchmark. The asm entry points are declared MLD_SYSV_ABI (like the other x86_64 asm routines) so they are called with the System V register convention on all platforms, including Windows/MinGW. The endbr64 is emitted via MLD_ASM_FN_SYMBOL (CET-gated) rather than a raw mnemonic, so older assemblers (e.g. clang-6) build cleanly. The eta2 vector path applies the centered mod-5 reduction to (2 - nibble) directly (matching the reference), rather than reducing the raw nibble and subtracting afterwards; the two are not equivalent because vpmulhrsw rounds to nearest. Verified against the ACVP keyGen vectors for all parameter sets. 
Signed-off-by: jake massimo --- BIBLIOGRAPHY.md | 10 +- dev/x86_64/meta.h | 6 +- dev/x86_64/src/arith_native_x86_64.h | 38 ++- dev/x86_64/src/rej_uniform_eta2_avx2.c | 157 ---------- dev/x86_64/src/rej_uniform_eta2_avx2_asm.S | 278 ++++++++++++++++++ dev/x86_64/src/rej_uniform_eta4_avx2.c | 141 --------- dev/x86_64/src/rej_uniform_eta4_avx2_asm.S | 232 +++++++++++++++ mldsa/mldsa_native.c | 6 +- mldsa/mldsa_native_asm.S | 6 +- mldsa/src/native/x86_64/meta.h | 6 +- .../native/x86_64/src/arith_native_x86_64.h | 38 ++- .../native/x86_64/src/rej_uniform_eta2_avx2.c | 157 ---------- .../x86_64/src/rej_uniform_eta2_avx2_asm.S | 174 +++++++++++ .../native/x86_64/src/rej_uniform_eta4_avx2.c | 141 --------- .../x86_64/src/rej_uniform_eta4_avx2_asm.S | 145 +++++++++ proofs/hol_light/x86_64/Makefile | 8 +- .../x86_64/mldsa/rej_uniform_eta2_avx2_asm.S | 173 +++++++++++ .../x86_64/mldsa/rej_uniform_eta4_avx2_asm.S | 144 +++++++++ scripts/autogen | 12 + test/bench/bench_components_mldsa.c | 12 + 20 files changed, 1255 insertions(+), 629 deletions(-) delete mode 100644 dev/x86_64/src/rej_uniform_eta2_avx2.c create mode 100644 dev/x86_64/src/rej_uniform_eta2_avx2_asm.S delete mode 100644 dev/x86_64/src/rej_uniform_eta4_avx2.c create mode 100644 dev/x86_64/src/rej_uniform_eta4_avx2_asm.S delete mode 100644 mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c create mode 100644 mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S delete mode 100644 mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c create mode 100644 mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S create mode 100644 proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S create mode 100644 proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S diff --git a/BIBLIOGRAPHY.md b/BIBLIOGRAPHY.md index 2ed637232..011ee5bd4 100644 --- a/BIBLIOGRAPHY.md +++ b/BIBLIOGRAPHY.md @@ -279,8 +279,8 @@ source code and documentation. 
- [dev/x86_64/src/polyz_unpack_17_avx2_asm.S](dev/x86_64/src/polyz_unpack_17_avx2_asm.S) - [dev/x86_64/src/polyz_unpack_19_avx2_asm.S](dev/x86_64/src/polyz_unpack_19_avx2_asm.S) - [dev/x86_64/src/rej_uniform_avx2.c](dev/x86_64/src/rej_uniform_avx2.c) - - [dev/x86_64/src/rej_uniform_eta2_avx2.c](dev/x86_64/src/rej_uniform_eta2_avx2.c) - - [dev/x86_64/src/rej_uniform_eta4_avx2.c](dev/x86_64/src/rej_uniform_eta4_avx2.c) + - [dev/x86_64/src/rej_uniform_eta2_avx2_asm.S](dev/x86_64/src/rej_uniform_eta2_avx2_asm.S) + - [dev/x86_64/src/rej_uniform_eta4_avx2_asm.S](dev/x86_64/src/rej_uniform_eta4_avx2_asm.S) - [mldsa/src/native/x86_64/src/intt_avx2_asm.S](mldsa/src/native/x86_64/src/intt_avx2_asm.S) - [mldsa/src/native/x86_64/src/ntt_avx2_asm.S](mldsa/src/native/x86_64/src/ntt_avx2_asm.S) - [mldsa/src/native/x86_64/src/nttunpack_avx2_asm.S](mldsa/src/native/x86_64/src/nttunpack_avx2_asm.S) @@ -297,8 +297,8 @@ source code and documentation. - [mldsa/src/native/x86_64/src/polyz_unpack_17_avx2_asm.S](mldsa/src/native/x86_64/src/polyz_unpack_17_avx2_asm.S) - [mldsa/src/native/x86_64/src/polyz_unpack_19_avx2_asm.S](mldsa/src/native/x86_64/src/polyz_unpack_19_avx2_asm.S) - [mldsa/src/native/x86_64/src/rej_uniform_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_avx2.c) - - [mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c) - - [mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c) + - [mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S](mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S) + - [mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S](mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/intt_avx2_asm.S](proofs/hol_light/x86_64/mldsa/intt_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/ntt_avx2_asm.S](proofs/hol_light/x86_64/mldsa/ntt_avx2_asm.S) - 
[proofs/hol_light/x86_64/mldsa/nttunpack_avx2_asm.S](proofs/hol_light/x86_64/mldsa/nttunpack_avx2_asm.S) @@ -310,6 +310,8 @@ source code and documentation. - [proofs/hol_light/x86_64/mldsa/poly_chknorm_avx2_asm.S](proofs/hol_light/x86_64/mldsa/poly_chknorm_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/polyz_unpack_17_avx2_asm.S](proofs/hol_light/x86_64/mldsa/polyz_unpack_17_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/polyz_unpack_19_avx2_asm.S](proofs/hol_light/x86_64/mldsa/polyz_unpack_19_avx2_asm.S) + - [proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S](proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S) + - [proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S](proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S) ### `Round3_Spec` diff --git a/dev/x86_64/meta.h b/dev/x86_64/meta.h index 55924ffec..c0124b3c2 100644 --- a/dev/x86_64/meta.h +++ b/dev/x86_64/meta.h @@ -106,7 +106,8 @@ static MLD_INLINE int mld_rej_uniform_eta2_native(int32_t *r, unsigned len, * We declassify prior the input data and mark the outputs as secret. */ MLD_CT_TESTING_DECLASSIFY(buf, buflen); - outlen = mld_rej_uniform_eta2_avx2(r, buf); + outlen = mld_rej_uniform_eta2_avx2_asm( + r, buf, (const uint8_t *)mld_rej_uniform_table); MLD_CT_TESTING_SECRET(r, sizeof(int32_t) * outlen); /* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */ return (int)outlen; @@ -135,7 +136,8 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len, * We declassify prior the input data and mark the outputs as secret. */ MLD_CT_TESTING_DECLASSIFY(buf, buflen); - outlen = mld_rej_uniform_eta4_avx2(r, buf); + outlen = mld_rej_uniform_eta4_avx2_asm( + r, buf, (const uint8_t *)mld_rej_uniform_table); MLD_CT_TESTING_SECRET(r, sizeof(int32_t) * outlen); /* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. 
*/ return (int)outlen; diff --git a/dev/x86_64/src/arith_native_x86_64.h b/dev/x86_64/src/arith_native_x86_64.h index 6ec3c1434..189b635d5 100644 --- a/dev/x86_64/src/arith_native_x86_64.h +++ b/dev/x86_64/src/arith_native_x86_64.h @@ -83,15 +83,37 @@ unsigned mld_rej_uniform_avx2(int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN]); #if !defined(MLD_CONFIG_NO_KEYPAIR_API) -#define mld_rej_uniform_eta2_avx2 MLD_NAMESPACE(mld_rej_uniform_eta2_avx2) -MLD_MUST_CHECK_RETURN_VALUE -unsigned mld_rej_uniform_eta2_avx2( - int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN]); +#define mld_rej_uniform_eta2_avx2_asm MLD_NAMESPACE(rej_uniform_eta2_avx2_asm) +/* This contract must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/rej_uniform_eta2_avx2_asm.ml */ +MLD_MUST_CHECK_RETURN_VALUE MLD_SYSV_ABI +unsigned mld_rej_uniform_eta2_avx2_asm( + int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN], + const uint8_t *table) +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN)) + requires(table == mld_rej_uniform_table) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(return_value <= MLDSA_N) + ensures(array_bound(r, 0, return_value, -2, 2)) +); -#define mld_rej_uniform_eta4_avx2 MLD_NAMESPACE(mld_rej_uniform_eta4_avx2) -MLD_MUST_CHECK_RETURN_VALUE -unsigned mld_rej_uniform_eta4_avx2( - int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]); +#define mld_rej_uniform_eta4_avx2_asm MLD_NAMESPACE(rej_uniform_eta4_avx2_asm) +/* This contract must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/rej_uniform_eta4_avx2_asm.ml */ +MLD_MUST_CHECK_RETURN_VALUE MLD_SYSV_ABI +unsigned mld_rej_uniform_eta4_avx2_asm( + int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN], + const uint8_t *table) +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + 
requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN)) + requires(table == mld_rej_uniform_table) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(return_value <= MLDSA_N) + ensures(array_bound(r, 0, return_value, -4, 4)) +); #endif /* !MLD_CONFIG_NO_KEYPAIR_API */ #if !defined(MLD_CONFIG_NO_SIGN_API) diff --git a/dev/x86_64/src/rej_uniform_eta2_avx2.c b/dev/x86_64/src/rej_uniform_eta2_avx2.c deleted file mode 100644 index dac73995b..000000000 --- a/dev/x86_64/src/rej_uniform_eta2_avx2.c +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_KEYPAIR_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 2) - -#include -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_AVX2_ETA2 2 - -/* - * Reference: In the pqcrystals implementation this function is called - * rej_eta_avx and supports multiple values for ETA via preprocessor - * conditionals. We move the conditionals to the frontend. 
- */ -unsigned int mld_rej_uniform_eta2_avx2( - int32_t *MLD_RESTRICT r, - const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN]) -{ - unsigned int ctr, pos; - uint32_t good; - __m256i f0, f1, f2; - __m128i g0, g1; - const __m256i mask = _mm256_set1_epi8(15); - const __m256i eta = _mm256_set1_epi8(MLD_AVX2_ETA2); - const __m256i bound = mask; - /* check-magic: -6560 == 32*round(-2**10 / 5) */ - const __m256i v = _mm256_set1_epi32(-6560); - const __m256i p = _mm256_set1_epi32(5); - - ctr = pos = 0; - while (ctr <= MLDSA_N - 8 && pos <= MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN - 16) - { - f0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&buf[pos])); - f1 = _mm256_slli_epi16(f0, 4); - f0 = _mm256_or_si256(f0, f1); - f0 = _mm256_and_si256(f0, mask); - - f1 = _mm256_sub_epi8(f0, bound); - f0 = _mm256_sub_epi8(eta, f0); - good = (uint32_t)_mm256_movemask_epi8(f1); - - g0 = _mm256_castsi256_si128(f0); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = _mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = _mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm256_extracti128_si256(f0, 1); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = 
_mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = _mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good); - pos += 4; - } - - while (ctr < MLDSA_N && pos < MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN) - { - uint32_t t0 = buf[pos] & 0x0F; - uint32_t t1 = buf[pos++] >> 4; - - if (t0 < 15) - { - t0 = t0 - (205 * t0 >> 10) * 5; - r[ctr++] = (int32_t)(2 - t0); - } - if (t1 < 15 && ctr < MLDSA_N) - { - t1 = t1 - (205 * t1 >> 10) * 5; - r[ctr++] = (int32_t)(2 - t1); - } - } - - return ctr; -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 2) */ - -MLD_EMPTY_CU(avx2_rej_uniform_eta2) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 2)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ -#undef MLD_AVX2_ETA2 diff --git a/dev/x86_64/src/rej_uniform_eta2_avx2_asm.S b/dev/x86_64/src/rej_uniform_eta2_avx2_asm.S new file mode 100644 index 000000000..e77cb675d --- /dev/null +++ b/dev/x86_64/src/rej_uniform_eta2_avx2_asm.S @@ -0,0 +1,278 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) +/* simpasm: header-end */ + +#define out %rdi +#define in %rsi +#define tab %rdx + +#define ctr %eax +#define pos %ecx + +#define good %r8d +#define cnt %r9d +#define tmp %r10 +#define val %r11d + +#define f0 %ymm0 +#define f1 %ymm1 +#define f2 %ymm2 +#define mask %ymm3 +#define eta %ymm4 +#define bound %ymm5 +#define v_const %ymm6 +#define p_const %ymm7 +#define g0 %xmm8 +#define g1 %xmm9 + + .text + +/* + * unsigned mld_rej_uniform_eta2_avx2_asm(int32_t *r, const uint8_t *buf, + * const uint8_t *table) + * + * Rejection sampling for ETA=2 polynomial coefficients. + * Extracts 4-bit nibbles from a byte buffer and accepts those < 15. + * Applies modulo-5 reduction: t = t - (205 * t >> 10) * 5 + * Output: coefficient = 2 - t, producing values in [-2, 2]. 
+ * + * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t) + * in (rsi): pointer to input byte buffer buf + * tab (rdx): pointer to rejection sampling lookup table (256x8 bytes) + * + * Returns: ctr (eax): number of valid coefficients written to r + */ + .balign 4 + .global MLD_ASM_NAMESPACE(rej_uniform_eta2_avx2_asm) +MLD_ASM_FN_SYMBOL(rej_uniform_eta2_avx2_asm) + +// Construct broadcast constants + movl $0x0F0F0F0F, good + vmovd good, %xmm3 + vpbroadcastd %xmm3, mask // mask: extract low 4 bits from each byte + + movl $0x02020202, good + vmovd good, %xmm4 + vpbroadcastd %xmm4, eta // eta: broadcast ETA=2 + + movl $0x0F0F0F0F, good + vmovd good, %xmm5 + vpbroadcastd %xmm5, bound // bound: rejection threshold (15) + + // Modulo-5 magic constants + // v = -6560 == 32*round(-2**10 / 5) for multiply-high-round-scale + movl $-6560, good + vpinsrw $0, good, %xmm6, %xmm6 + vpbroadcastw %xmm6, v_const // v_const: -6560 for mulhrs + + movl $5, good + vpinsrw $0, good, %xmm7, %xmm7 + vpbroadcastw %xmm7, p_const // p_const: 5 for mullo + +// Initialize counters + xorl ctr, ctr // ctr = 0 + xorl pos, pos // pos = 0 + +/* + * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to 32 + * coefficients per iteration (4 groups of up to 8 coefficients each). + * Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 16. 
+ */ +rej_uniform_eta2_avx2_asm_loop: + cmpl $248, ctr // MLDSA_N - 8 + ja rej_uniform_eta2_avx2_asm_scalar + cmpl $120, pos // MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN - 16 + ja rej_uniform_eta2_avx2_asm_scalar + + // Load 16 bytes and extract 32 nibbles into f0 + vpmovzxbw (in, %rcx), f0 // load 16 bytes, zero-extend to 16 words + vpsllw $4, f0, f1 // shift left by 4 to align high nibbles + vpor f1, f0, f0 // OR: each word now has nibble duplicated + vpand mask, f0, f0 // mask to get 4 bits per byte + + // Check which nibbles are < 15 and compute eta - nibble + vpsubb bound, f0, f1 // f1 = nibble - 15 (negative if valid) + vpsubb f0, eta, f0 // f0 = 2 - nibble + vpmovmskb f1, good // extract sign bits (valid = 1) + + // For valid nibbles, reduce (2 - nibble) mod 5 (centered, range [-2,2]): + // t = (2 - nibble) - (205 * (2 - nibble) >> 10) * 5 + // matching r = 2 - (nibble mod 5) from the reference. + + // Process first group of 8 nibbles (low 128 bits, low 64 bits of that) + vextracti128 $0, f0, g0 // extract low 128 bits + movzbl %r8b, %r10d // get low 8 bits of mask + vmovq (tab, tmp, 8), g1 // load shuffle indices from table + vpshufb g1, g0, g1 // compact valid nibbles + vpmovsxbd g1, f1 // sign-extend bytes to dwords + + // Apply modulo-5 reduction + vpmulhrsw v_const, f1, f2 // f2 = mulhrs(f1, -6560) + vpmullw p_const, f2, f2 // f2 = f2 * 5 + vpaddd f2, f1, f1 // f1 = (2 - nibble) reduced mod 5 + + vmovdqu f1, (out, %rax, 4) // store 8 dwords + popcntl %r10d, cnt // popcount of low 8 bits (in r10d) + addl cnt, ctr // ctr += popcount(low 8 bits) + shrl $8, good + addl $4, pos // consumed 4 input bytes + + cmpl $248, ctr + ja rej_uniform_eta2_avx2_asm_scalar + + // Process second group of 8 nibbles (low 128 bits, high 64 bits) + vpsrldq $8, g0, g0 // shift right to get next 8 nibbles + movzbl %r8b, %r10d + vmovq (tab, tmp, 8), g1 + vpshufb g1, g0, g1 + vpmovsxbd g1, f1 + vpmulhrsw v_const, f1, f2 + vpmullw p_const, f2, f2 + vpaddd f2, f1, f1 + vmovdqu f1, (out, 
%rax, 4) + popcntl %r10d, cnt + addl cnt, ctr + shrl $8, good + addl $4, pos + + cmpl $248, ctr + ja rej_uniform_eta2_avx2_asm_scalar + + // Process third group of 8 nibbles (high 128 bits, low 64 bits) + vextracti128 $1, f0, g0 // extract high 128 bits + movzbl %r8b, %r10d + vmovq (tab, tmp, 8), g1 + vpshufb g1, g0, g1 + vpmovsxbd g1, f1 + vpmulhrsw v_const, f1, f2 + vpmullw p_const, f2, f2 + vpaddd f2, f1, f1 + vmovdqu f1, (out, %rax, 4) + popcntl %r10d, cnt + addl cnt, ctr + shrl $8, good + addl $4, pos + + cmpl $248, ctr + ja rej_uniform_eta2_avx2_asm_scalar + + // Process fourth group of 8 nibbles (high 128 bits, high 64 bits) + vpsrldq $8, g0, g0 + movzbl %r8b, %r10d + vmovq (tab, tmp, 8), g1 + vpshufb g1, g0, g1 + vpmovsxbd g1, f1 + vpmulhrsw v_const, f1, f2 + vpmullw p_const, f2, f2 + vpaddd f2, f1, f1 + vmovdqu f1, (out, %rax, 4) + popcntl %r10d, cnt + addl cnt, ctr + addl $4, pos + + jmp rej_uniform_eta2_avx2_asm_loop + +/* + * Scalar fallback loop: process remaining bytes one nibble at a time. + * Each byte contains two 4-bit nibbles (low and high). 
+ * Apply modulo-5 reduction: t = t - (205 * t >> 10) * 5 + */ +rej_uniform_eta2_avx2_asm_scalar: + cmpl $256, ctr // MLDSA_N + jae rej_uniform_eta2_avx2_asm_done + cmpl $136, pos // MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN + jae rej_uniform_eta2_avx2_asm_done + + movzbl (in, %rcx), val // load 1 byte + incl pos + + // Process low nibble + movl val, %r10d + andl $0x0F, %r10d // extract low 4 bits + cmpl $15, %r10d + jae rej_uniform_eta2_avx2_asm_high_nibble + + // Apply modulo-5 reduction: t = t - (205 * t >> 10) * 5 + movl %r10d, val + imull $205, val + shrl $10, val + imull $5, val + subl val, %r10d // tmp = tmp - (205*tmp>>10)*5 + + movl $2, %r11d + subl %r10d, %r11d // 2 - tmp + movl %r11d, (out, %rax, 4) + incl ctr + + cmpl $256, ctr + jae rej_uniform_eta2_avx2_asm_done + +rej_uniform_eta2_avx2_asm_high_nibble: + // Reload original byte for high nibble + movzbl -1(in, %rcx), val // reload byte + shrl $4, val + andl $0x0F, val + cmpl $15, val + jae rej_uniform_eta2_avx2_asm_scalar + + // Apply modulo-5 reduction + movl val, %r10d + imull $205, %r10d + shrl $10, %r10d + imull $5, %r10d + subl %r10d, val // val = val - (205*val>>10)*5 + + movl $2, %r10d + subl val, %r10d // 2 - val + movl %r10d, (out, %rax, 4) + incl ctr + + jmp rej_uniform_eta2_avx2_asm_scalar + +rej_uniform_eta2_avx2_asm_done: + ret + +/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. + * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ +#undef out +#undef in +#undef tab +#undef ctr +#undef pos +#undef good +#undef cnt +#undef tmp +#undef val +#undef f0 +#undef f1 +#undef f2 +#undef mask +#undef eta +#undef bound +#undef v_const +#undef p_const +#undef g0 +#undef g1 + +/* simpasm: footer-start */ +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ diff --git a/dev/x86_64/src/rej_uniform_eta4_avx2.c b/dev/x86_64/src/rej_uniform_eta4_avx2.c deleted file mode 100644 index 6f41486d2..000000000 --- a/dev/x86_64/src/rej_uniform_eta4_avx2.c +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_KEYPAIR_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 4) - -#include -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_AVX2_ETA4 4 - -/* - * Reference: In the pqcrystals implementation this function is called - * rej_eta_avx and supports multiple values for ETA via preprocessor - * conditionals. We move the conditionals to the frontend. 
- */ - -unsigned int mld_rej_uniform_eta4_avx2( - int32_t *MLD_RESTRICT r, - const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]) -{ - unsigned int ctr, pos; - uint32_t good; - __m256i f0, f1; - __m128i g0, g1; - const __m256i mask = _mm256_set1_epi8(15); - const __m256i eta = _mm256_set1_epi8(MLD_AVX2_ETA4); - const __m256i bound = _mm256_set1_epi8(9); - - ctr = pos = 0; - while (ctr <= MLDSA_N - 8 && pos <= MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN - 16) - { - f0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&buf[pos])); - f1 = _mm256_slli_epi16(f0, 4); - f0 = _mm256_or_si256(f0, f1); - f0 = _mm256_and_si256(f0, mask); - - f1 = _mm256_sub_epi8(f0, bound); - f0 = _mm256_sub_epi8(eta, f0); - good = (uint32_t)_mm256_movemask_epi8(f1); - - g0 = _mm256_castsi256_si128(f0); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm256_extracti128_si256(f0, 1); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += 
(unsigned)_mm_popcnt_u32(good); - pos += 4; - } - - while (ctr < MLDSA_N && pos < MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN) - { - uint32_t t0 = buf[pos] & 0x0F; - uint32_t t1 = buf[pos++] >> 4; - - if (t0 < 9) - { - r[ctr++] = (int32_t)(4 - t0); - } - if (t1 < 9 && ctr < MLDSA_N) - { - r[ctr++] = (int32_t)(4 - t1); - } - } - - return ctr; -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 4) */ - -MLD_EMPTY_CU(avx2_rej_uniform_eta4) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 4)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. */ -#undef MLD_AVX2_ETA4 diff --git a/dev/x86_64/src/rej_uniform_eta4_avx2_asm.S b/dev/x86_64/src/rej_uniform_eta4_avx2_asm.S new file mode 100644 index 000000000..922df283a --- /dev/null +++ b/dev/x86_64/src/rej_uniform_eta4_avx2_asm.S @@ -0,0 +1,232 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. 
+ */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) +/* simpasm: header-end */ + +#define out %rdi +#define in %rsi +#define tab %rdx + +#define ctr %eax +#define pos %ecx + +#define good %r8d +#define cnt %r9d +#define tmp %r10 +#define val %r11d + +#define f0 %ymm0 +#define f1 %ymm1 +#define mask %ymm2 +#define eta %ymm3 +#define bound %ymm4 +#define g0 %xmm5 +#define g1 %xmm6 + + .text + +/* + * unsigned mld_rej_uniform_eta4_avx2_asm(int32_t *r, const uint8_t *buf, + * const uint8_t *table) + * + * Rejection sampling for ETA=4 polynomial coefficients. + * Extracts 4-bit nibbles from a byte buffer and accepts those < 9. + * Output: coefficient = 4 - nibble, producing values in [-4, 4]. + * + * Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t) + * in (rsi): pointer to input byte buffer buf (272 bytes) + * tab (rdx): pointer to rejection sampling lookup table (256x8 bytes) + * + * Returns: ctr (eax): number of valid coefficients written to r + */ + .balign 4 + .global MLD_ASM_NAMESPACE(rej_uniform_eta4_avx2_asm) +MLD_ASM_FN_SYMBOL(rej_uniform_eta4_avx2_asm) + +// Construct broadcast constants + movl $0x0F0F0F0F, good + vmovd good, %xmm2 + vpbroadcastd %xmm2, mask // mask: extract low 4 bits from each byte + + movl $0x04040404, good + vmovd good, %xmm3 + vpbroadcastd %xmm3, eta // eta: broadcast ETA=4 + + movl $0x09090909, good + vmovd good, %xmm4 + vpbroadcastd %xmm4, bound // bound: rejection threshold (9) + +// Initialize counters + xorl ctr, ctr // ctr = 0 + xorl pos, pos // pos = 0 + +/* + * Main SIMD loop: process 16 input bytes into 32 nibbles, producing up to + * 32 coefficients per iteration (4 sub-iterations of 8 nibbles each). + * + * Loop-head guards: ctr <= MLDSA_N - 8 = 248 and pos <= BUFLEN - 16 = 256. 
+ * + * Mid-iter early exits at ctr > 248 prevent buffer overshoot: each sub-iter + * stores 8 ints starting at r[ctr], and ctr advances by popcount(<= 8). With + * ctr <= 248 entering a sub-iter, the store touches r[248..256] — exactly + * fitting the 256-int output buffer. + */ +rej_uniform_eta4_avx2_asm_loop: + cmpl $248, ctr // MLDSA_N - 8 + ja rej_uniform_eta4_avx2_asm_scalar + cmpl $256, pos // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN - 16 + ja rej_uniform_eta4_avx2_asm_scalar + + // Load 16 bytes and extract 32 nibbles into f0 + vpmovzxbw (in, %rcx), f0 // load 16 bytes, zero-extend to 16 words + vpsllw $4, f0, f1 // shift left by 4 to align high nibbles + vpor f1, f0, f0 // OR: each word now has nibble duplicated + vpand mask, f0, f0 // mask to get 4 bits per byte + + // Check which nibbles are < 9 and compute eta - nibble + vpsubb bound, f0, f1 // f1 = nibble - 9 (negative if valid) + vpsubb f0, eta, f0 // f0 = 4 - nibble + vpmovmskb f1, good // extract sign bits (valid = 1) + + // Sub-iter 1: extract low 128 bits of f0; process bits 0..7 of mask + vextracti128 $0, f0, g0 + movzbl %r8b, %r10d // tmp = good & 0xFF + vmovq (tab, tmp, 8), g1 // g1 = table[good & 0xFF] (8 byte indices) + vpshufb g1, g0, g1 // compact accepted nibbles to front + vpmovsxbd g1, f1 // sign-extend 8 bytes -> 8 int32 + vmovdqu f1, (out, %rax, 4) // store 8 ints at r[ctr] + popcntl %r10d, cnt // cnt = popcount(good & 0xFF) + addl cnt, ctr // ctr += cnt + shrl $8, good // shift good for next sub-iter + addl $4, pos // 4 input bytes consumed + + cmpl $248, ctr // mid-iter exit if ctr > 248 + ja rej_uniform_eta4_avx2_asm_scalar + + // Sub-iter 2: shift xmm5 down by 8 bytes; process next 8 bits of mask + vpsrldq $8, g0, g0 + movzbl %r8b, %r10d + vmovq (tab, tmp, 8), g1 + vpshufb g1, g0, g1 + vpmovsxbd g1, f1 + vmovdqu f1, (out, %rax, 4) + popcntl %r10d, cnt + addl cnt, ctr + shrl $8, good + addl $4, pos + + cmpl $248, ctr + ja rej_uniform_eta4_avx2_asm_scalar + + // Sub-iter 3: extract high 128 
bits of f0 + vextracti128 $1, f0, g0 + movzbl %r8b, %r10d + vmovq (tab, tmp, 8), g1 + vpshufb g1, g0, g1 + vpmovsxbd g1, f1 + vmovdqu f1, (out, %rax, 4) + popcntl %r10d, cnt + addl cnt, ctr + shrl $8, good + addl $4, pos + + cmpl $248, ctr + ja rej_uniform_eta4_avx2_asm_scalar + + // Sub-iter 4: shift xmm5 down by 8 bytes; process final 8 bits of mask + vpsrldq $8, g0, g0 + movzbl %r8b, %r10d + vmovq (tab, tmp, 8), g1 + vpshufb g1, g0, g1 + vpmovsxbd g1, f1 + vmovdqu f1, (out, %rax, 4) + popcntl %r10d, cnt + addl cnt, ctr + addl $4, pos + + jmp rej_uniform_eta4_avx2_asm_loop + +/* + * Scalar fallback loop: process remaining bytes one nibble at a time. + * Each byte contains two 4-bit nibbles (low and high), each accepted iff < 9. + */ +rej_uniform_eta4_avx2_asm_scalar: + cmpl $256, ctr // MLDSA_N + jae rej_uniform_eta4_avx2_asm_done + cmpl $272, pos // MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN + jae rej_uniform_eta4_avx2_asm_done + + movzbl (in, %rcx), val // load 1 byte + incl pos + + // Process low nibble + movl val, %r10d + andl $0x0F, %r10d // extract low 4 bits + cmpl $9, %r10d + jae rej_uniform_eta4_avx2_asm_high_nibble + + movl $4, cnt + subl %r10d, cnt // 4 - nibble + movl cnt, (out, %rax, 4) + incl ctr + + cmpl $256, ctr + jae rej_uniform_eta4_avx2_asm_done + +rej_uniform_eta4_avx2_asm_high_nibble: + // Process high nibble + shrl $4, val + andl $0x0F, val + cmpl $9, val + jae rej_uniform_eta4_avx2_asm_scalar + + movl $4, %r10d + subl val, %r10d // 4 - nibble + movl %r10d, (out, %rax, 4) + incl ctr + + jmp rej_uniform_eta4_avx2_asm_scalar + +rej_uniform_eta4_avx2_asm_done: + ret + +/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. + * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ +#undef out +#undef in +#undef tab +#undef ctr +#undef pos +#undef good +#undef cnt +#undef tmp +#undef val +#undef f0 +#undef f1 +#undef mask +#undef eta +#undef bound +#undef g0 +#undef g1 + +/* simpasm: footer-start */ +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ diff --git a/mldsa/mldsa_native.c b/mldsa/mldsa_native.c index 9365ed369..4cf1bcd32 100644 --- a/mldsa/mldsa_native.c +++ b/mldsa/mldsa_native.c @@ -88,8 +88,6 @@ #include "src/native/x86_64/src/poly_use_hint_32_avx2.c" #include "src/native/x86_64/src/poly_use_hint_88_avx2.c" #include "src/native/x86_64/src/rej_uniform_avx2.c" -#include "src/native/x86_64/src/rej_uniform_eta2_avx2.c" -#include "src/native/x86_64/src/rej_uniform_eta4_avx2.c" #include "src/native/x86_64/src/rej_uniform_table.c" #endif /* MLD_SYS_X86_64 */ #endif /* MLD_CONFIG_USE_NATIVE_BACKEND_ARITH */ @@ -792,8 +790,8 @@ #undef mld_polyz_unpack_17_avx2_asm #undef mld_polyz_unpack_19_avx2_asm #undef mld_rej_uniform_avx2 -#undef mld_rej_uniform_eta2_avx2 -#undef mld_rej_uniform_eta4_avx2 +#undef mld_rej_uniform_eta2_avx2_asm +#undef mld_rej_uniform_eta4_avx2_asm #undef mld_rej_uniform_table /* mldsa/src/native/x86_64/src/consts.h */ #undef MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV diff --git a/mldsa/mldsa_native_asm.S b/mldsa/mldsa_native_asm.S index 4877d5156..8e1a7bea7 100644 --- a/mldsa/mldsa_native_asm.S +++ b/mldsa/mldsa_native_asm.S @@ -90,6 +90,8 @@ #include "src/native/x86_64/src/poly_chknorm_avx2_asm.S" #include "src/native/x86_64/src/polyz_unpack_17_avx2_asm.S" #include "src/native/x86_64/src/polyz_unpack_19_avx2_asm.S" +#include "src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S" +#include "src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S" #endif /* MLD_SYS_X86_64 */ #endif /* MLD_CONFIG_USE_NATIVE_BACKEND_ARITH */ @@ -805,8 +807,8 @@ #undef mld_polyz_unpack_17_avx2_asm #undef mld_polyz_unpack_19_avx2_asm #undef mld_rej_uniform_avx2 -#undef mld_rej_uniform_eta2_avx2 -#undef 
mld_rej_uniform_eta4_avx2 +#undef mld_rej_uniform_eta2_avx2_asm +#undef mld_rej_uniform_eta4_avx2_asm #undef mld_rej_uniform_table /* mldsa/src/native/x86_64/src/consts.h */ #undef MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV diff --git a/mldsa/src/native/x86_64/meta.h b/mldsa/src/native/x86_64/meta.h index 55924ffec..c0124b3c2 100644 --- a/mldsa/src/native/x86_64/meta.h +++ b/mldsa/src/native/x86_64/meta.h @@ -106,7 +106,8 @@ static MLD_INLINE int mld_rej_uniform_eta2_native(int32_t *r, unsigned len, * We declassify prior the input data and mark the outputs as secret. */ MLD_CT_TESTING_DECLASSIFY(buf, buflen); - outlen = mld_rej_uniform_eta2_avx2(r, buf); + outlen = mld_rej_uniform_eta2_avx2_asm( + r, buf, (const uint8_t *)mld_rej_uniform_table); MLD_CT_TESTING_SECRET(r, sizeof(int32_t) * outlen); /* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */ return (int)outlen; @@ -135,7 +136,8 @@ static MLD_INLINE int mld_rej_uniform_eta4_native(int32_t *r, unsigned len, * We declassify prior the input data and mark the outputs as secret. */ MLD_CT_TESTING_DECLASSIFY(buf, buflen); - outlen = mld_rej_uniform_eta4_avx2(r, buf); + outlen = mld_rej_uniform_eta4_avx2_asm( + r, buf, (const uint8_t *)mld_rej_uniform_table); MLD_CT_TESTING_SECRET(r, sizeof(int32_t) * outlen); /* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. 
*/ return (int)outlen; diff --git a/mldsa/src/native/x86_64/src/arith_native_x86_64.h b/mldsa/src/native/x86_64/src/arith_native_x86_64.h index 6ec3c1434..189b635d5 100644 --- a/mldsa/src/native/x86_64/src/arith_native_x86_64.h +++ b/mldsa/src/native/x86_64/src/arith_native_x86_64.h @@ -83,15 +83,37 @@ unsigned mld_rej_uniform_avx2(int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN]); #if !defined(MLD_CONFIG_NO_KEYPAIR_API) -#define mld_rej_uniform_eta2_avx2 MLD_NAMESPACE(mld_rej_uniform_eta2_avx2) -MLD_MUST_CHECK_RETURN_VALUE -unsigned mld_rej_uniform_eta2_avx2( - int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN]); +#define mld_rej_uniform_eta2_avx2_asm MLD_NAMESPACE(rej_uniform_eta2_avx2_asm) +/* This contract must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/rej_uniform_eta2_avx2_asm.ml */ +MLD_MUST_CHECK_RETURN_VALUE MLD_SYSV_ABI +unsigned mld_rej_uniform_eta2_avx2_asm( + int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN], + const uint8_t *table) +__contract__( + requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN)) + requires(table == mld_rej_uniform_table) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(return_value <= MLDSA_N) + ensures(array_bound(r, 0, return_value, -2, 2)) +); -#define mld_rej_uniform_eta4_avx2 MLD_NAMESPACE(mld_rej_uniform_eta4_avx2) -MLD_MUST_CHECK_RETURN_VALUE -unsigned mld_rej_uniform_eta4_avx2( - int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]); +#define mld_rej_uniform_eta4_avx2_asm MLD_NAMESPACE(rej_uniform_eta4_avx2_asm) +/* This contract must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/rej_uniform_eta4_avx2_asm.ml */ +MLD_MUST_CHECK_RETURN_VALUE MLD_SYSV_ABI +unsigned mld_rej_uniform_eta4_avx2_asm( + int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN], + const uint8_t *table) +__contract__( + 
requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(buf, MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN)) + requires(table == mld_rej_uniform_table) + assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N)) + ensures(return_value <= MLDSA_N) + ensures(array_bound(r, 0, return_value, -4, 4)) +); #endif /* !MLD_CONFIG_NO_KEYPAIR_API */ #if !defined(MLD_CONFIG_NO_SIGN_API) diff --git a/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c b/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c deleted file mode 100644 index dac73995b..000000000 --- a/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_KEYPAIR_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 2) - -#include -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_AVX2_ETA2 2 - -/* - * Reference: In the pqcrystals implementation this function is called - * rej_eta_avx and supports multiple values for ETA via preprocessor - * conditionals. We move the conditionals to the frontend. 
- */ -unsigned int mld_rej_uniform_eta2_avx2( - int32_t *MLD_RESTRICT r, - const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN]) -{ - unsigned int ctr, pos; - uint32_t good; - __m256i f0, f1, f2; - __m128i g0, g1; - const __m256i mask = _mm256_set1_epi8(15); - const __m256i eta = _mm256_set1_epi8(MLD_AVX2_ETA2); - const __m256i bound = mask; - /* check-magic: -6560 == 32*round(-2**10 / 5) */ - const __m256i v = _mm256_set1_epi32(-6560); - const __m256i p = _mm256_set1_epi32(5); - - ctr = pos = 0; - while (ctr <= MLDSA_N - 8 && pos <= MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN - 16) - { - f0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&buf[pos])); - f1 = _mm256_slli_epi16(f0, 4); - f0 = _mm256_or_si256(f0, f1); - f0 = _mm256_and_si256(f0, mask); - - f1 = _mm256_sub_epi8(f0, bound); - f0 = _mm256_sub_epi8(eta, f0); - good = (uint32_t)_mm256_movemask_epi8(f1); - - g0 = _mm256_castsi256_si128(f0); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = _mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = _mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm256_extracti128_si256(f0, 1); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = 
_mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - f2 = _mm256_mulhrs_epi16(f1, v); - f2 = _mm256_mullo_epi16(f2, p); - f1 = _mm256_add_epi32(f1, f2); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good); - pos += 4; - } - - while (ctr < MLDSA_N && pos < MLD_AVX2_REJ_UNIFORM_ETA2_BUFLEN) - { - uint32_t t0 = buf[pos] & 0x0F; - uint32_t t1 = buf[pos++] >> 4; - - if (t0 < 15) - { - t0 = t0 - (205 * t0 >> 10) * 5; - r[ctr++] = (int32_t)(2 - t0); - } - if (t1 < 15 && ctr < MLDSA_N) - { - t1 = t1 - (205 * t1 >> 10) * 5; - r[ctr++] = (int32_t)(2 - t1); - } - } - - return ctr; -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 2) */ - -MLD_EMPTY_CU(avx2_rej_uniform_eta2) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 2)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ -#undef MLD_AVX2_ETA2 diff --git a/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S b/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S new file mode 100644 index 000000000..f881d85d4 --- /dev/null +++ b/mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2_asm.S @@ -0,0 +1,174 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/rej_uniform_eta2_avx2_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +.text +.balign 4 +.global MLD_ASM_NAMESPACE(rej_uniform_eta2_avx2_asm) +MLD_ASM_FN_SYMBOL(rej_uniform_eta2_avx2_asm) + + .cfi_startproc + movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F + vmovd %r8d, %xmm3 + vpbroadcastd %xmm3, %ymm3 + movl $0x2020202, %r8d # imm = 0x2020202 + vmovd %r8d, %xmm4 + vpbroadcastd %xmm4, %ymm4 + movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F + vmovd %r8d, %xmm5 + vpbroadcastd %xmm5, %ymm5 + movl $0xffffe660, %r8d # imm = 0xFFFFE660 + vpinsrw $0x0, %r8d, %xmm6, %xmm6 + vpbroadcastw %xmm6, %ymm6 + movl $0x5, %r8d + vpinsrw $0x0, %r8d, %xmm7, %xmm7 + vpbroadcastw %xmm7, %ymm7 + xorl %eax, %eax + xorl %ecx, %ecx + +Lrej_uniform_eta2_avx2_asm_loop: + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + cmpl $0x78, %ecx + ja Lrej_uniform_eta2_avx2_asm_scalar + vpmovzxbw (%rsi,%rcx), %ymm0 + vpsllw $0x4, %ymm0, %ymm1 + vpor %ymm1, %ymm0, %ymm0 + vpand %ymm3, %ymm0, %ymm0 + vpsubb %ymm5, %ymm0, %ymm1 + vpsubb %ymm0, %ymm4, %ymm0 + vpmovmskb %ymm1, %r8d + vextracti128 $0x0, %ymm0, %xmm8 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + vpsrldq $0x8, %xmm8, %xmm8 # xmm8 = xmm8[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + vextracti128 $0x1, %ymm0, %xmm8 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw 
%ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + vpsrldq $0x8, %xmm8, %xmm8 # xmm8 = xmm8[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + addl $0x4, %ecx + jmp Lrej_uniform_eta2_avx2_asm_loop + +Lrej_uniform_eta2_avx2_asm_scalar: + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta2_avx2_asm_done + cmpl $0x88, %ecx + jae Lrej_uniform_eta2_avx2_asm_done + movzbl (%rsi,%rcx), %r11d + incl %ecx + movl %r11d, %r10d + andl $0xf, %r10d + cmpl $0xf, %r10d + jae Lrej_uniform_eta2_avx2_asm_high_nibble + movl %r10d, %r11d + imull $0xcd, %r11d, %r11d + shrl $0xa, %r11d + imull $0x5, %r11d, %r11d + subl %r11d, %r10d + movl $0x2, %r11d + subl %r10d, %r11d + movl %r11d, (%rdi,%rax,4) + incl %eax + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta2_avx2_asm_done + +Lrej_uniform_eta2_avx2_asm_high_nibble: + movzbl -0x1(%rsi,%rcx), %r11d + shrl $0x4, %r11d + andl $0xf, %r11d + cmpl $0xf, %r11d + jae Lrej_uniform_eta2_avx2_asm_scalar + movl %r11d, %r10d + imull $0xcd, %r10d, %r10d + shrl $0xa, %r10d + imull $0x5, %r10d, %r10d + subl %r10d, %r11d + movl $0x2, %r10d + subl %r11d, %r10d + movl %r10d, (%rdi,%rax,4) + incl %eax + jmp Lrej_uniform_eta2_avx2_asm_scalar + +Lrej_uniform_eta2_avx2_asm_done: + retq + .cfi_endproc + +MLD_ASM_FN_SIZE(rej_uniform_eta2_avx2_asm) + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c 
b/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c deleted file mode 100644 index 6f41486d2..000000000 --- a/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_KEYPAIR_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || MLDSA_ETA == 4) - -#include -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_AVX2_ETA4 4 - -/* - * Reference: In the pqcrystals implementation this function is called - * rej_eta_avx and supports multiple values for ETA via preprocessor - * conditionals. We move the conditionals to the frontend. 
- */ - -unsigned int mld_rej_uniform_eta4_avx2( - int32_t *MLD_RESTRICT r, - const uint8_t buf[MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN]) -{ - unsigned int ctr, pos; - uint32_t good; - __m256i f0, f1; - __m128i g0, g1; - const __m256i mask = _mm256_set1_epi8(15); - const __m256i eta = _mm256_set1_epi8(MLD_AVX2_ETA4); - const __m256i bound = _mm256_set1_epi8(9); - - ctr = pos = 0; - while (ctr <= MLDSA_N - 8 && pos <= MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN - 16) - { - f0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i *)&buf[pos])); - f1 = _mm256_slli_epi16(f0, 4); - f0 = _mm256_or_si256(f0, f1); - f0 = _mm256_and_si256(f0, mask); - - f1 = _mm256_sub_epi8(f0, bound); - f0 = _mm256_sub_epi8(eta, f0); - good = (uint32_t)_mm256_movemask_epi8(f1); - - g0 = _mm256_castsi256_si128(f0); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm256_extracti128_si256(f0, 1); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good & 0xFF]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += (unsigned)_mm_popcnt_u32(good & 0xFF); - good >>= 8; - pos += 4; - - if (ctr > MLDSA_N - 8) - { - break; - } - g0 = _mm_bsrli_si128(g0, 8); - g1 = _mm_loadl_epi64((__m128i *)&mld_rej_uniform_table[good]); - g1 = _mm_shuffle_epi8(g0, g1); - f1 = _mm256_cvtepi8_epi32(g1); - _mm256_storeu_si256((__m256i *)&r[ctr], f1); - ctr += 
(unsigned)_mm_popcnt_u32(good); - pos += 4; - } - - while (ctr < MLDSA_N && pos < MLD_AVX2_REJ_UNIFORM_ETA4_BUFLEN) - { - uint32_t t0 = buf[pos] & 0x0F; - uint32_t t1 = buf[pos++] >> 4; - - if (t0 < 9) - { - r[ctr++] = (int32_t)(4 - t0); - } - if (t1 < 9 && ctr < MLDSA_N) - { - r[ctr++] = (int32_t)(4 - t1); - } - } - - return ctr; -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 4) */ - -MLD_EMPTY_CU(avx2_rej_uniform_eta4) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_KEYPAIR_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLDSA_ETA == 4)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. */ -#undef MLD_AVX2_ETA4 diff --git a/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S b/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S new file mode 100644 index 000000000..1903029bd --- /dev/null +++ b/mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2_asm.S @@ -0,0 +1,145 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + +#include "../../../common.h" +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/rej_uniform_eta4_avx2_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +.text +.balign 4 +.global MLD_ASM_NAMESPACE(rej_uniform_eta4_avx2_asm) +MLD_ASM_FN_SYMBOL(rej_uniform_eta4_avx2_asm) + + .cfi_startproc + movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F + vmovd %r8d, %xmm2 + vpbroadcastd %xmm2, %ymm2 + movl $0x4040404, %r8d # imm = 0x4040404 + vmovd %r8d, %xmm3 + vpbroadcastd %xmm3, %ymm3 + movl $0x9090909, %r8d # imm = 0x9090909 + vmovd %r8d, %xmm4 + vpbroadcastd %xmm4, %ymm4 + xorl %eax, %eax + xorl %ecx, %ecx + +Lrej_uniform_eta4_avx2_asm_loop: + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + cmpl $0x100, %ecx # imm = 0x100 + ja Lrej_uniform_eta4_avx2_asm_scalar + vpmovzxbw (%rsi,%rcx), %ymm0 + vpsllw $0x4, %ymm0, %ymm1 + vpor %ymm1, %ymm0, %ymm0 + vpand %ymm2, %ymm0, %ymm0 + vpsubb %ymm4, %ymm0, %ymm1 + vpsubb %ymm0, %ymm3, %ymm0 + vpmovmskb %ymm1, %r8d + vextracti128 $0x0, %ymm0, %xmm5 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + vpsrldq $0x8, %xmm5, %xmm5 # xmm5 = xmm5[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + vextracti128 $0x1, %ymm0, %xmm5 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + vpsrldq $0x8, %xmm5, %xmm5 # xmm5 = xmm5[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd 
%xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + addl $0x4, %ecx + jmp Lrej_uniform_eta4_avx2_asm_loop + +Lrej_uniform_eta4_avx2_asm_scalar: + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta4_avx2_asm_done + cmpl $0x110, %ecx # imm = 0x110 + jae Lrej_uniform_eta4_avx2_asm_done + movzbl (%rsi,%rcx), %r11d + incl %ecx + movl %r11d, %r10d + andl $0xf, %r10d + cmpl $0x9, %r10d + jae Lrej_uniform_eta4_avx2_asm_high_nibble + movl $0x4, %r9d + subl %r10d, %r9d + movl %r9d, (%rdi,%rax,4) + incl %eax + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta4_avx2_asm_done + +Lrej_uniform_eta4_avx2_asm_high_nibble: + shrl $0x4, %r11d + andl $0xf, %r11d + cmpl $0x9, %r11d + jae Lrej_uniform_eta4_avx2_asm_scalar + movl $0x4, %r10d + subl %r11d, %r10d + movl %r10d, (%rdi,%rax,4) + incl %eax + jmp Lrej_uniform_eta4_avx2_asm_scalar + +Lrej_uniform_eta4_avx2_asm_done: + retq + .cfi_endproc + +MLD_ASM_FN_SIZE(rej_uniform_eta4_avx2_asm) + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/proofs/hol_light/x86_64/Makefile b/proofs/hol_light/x86_64/Makefile index 693078496..355f82b8c 100644 --- a/proofs/hol_light/x86_64/Makefile +++ b/proofs/hol_light/x86_64/Makefile @@ -40,9 +40,9 @@ OBJDUMP=$(CROSS_PREFIX)objdump -d # by single-quote characters in comments, so we eliminate // comments first. 
ifeq ($(OSTYPE_RESULT),Darwin) -PREPROCESS=sed -e 's/\/\/.*//' | $(CC) -E -xassembler-with-cpp - +PREPROCESS=sed -e 's/\/\/.*//' | $(CC) -E -xassembler-with-cpp -I$(BASE)/../../../mldsa -I$(BASE)/../../../mldsa/src -I$(BASE)/../../../common -DMLD_CONFIG_PARAMETER_SET=65 - else -PREPROCESS=$(CC) -E -xassembler-with-cpp - +PREPROCESS=$(CC) -E -xassembler-with-cpp -I$(BASE)/../../../mldsa -I$(BASE)/../../../mldsa/src -I$(BASE)/../../../common -DMLD_CONFIG_PARAMETER_SET=65 - endif # Generally GNU-type assemblers are happy with multiple instructions on @@ -61,7 +61,9 @@ OBJ = mldsa/ntt_avx2_asm.o \ mldsa/pointwise_acc_l4_avx2_asm.o \ mldsa/pointwise_acc_l5_avx2_asm.o \ mldsa/pointwise_acc_l7_avx2_asm.o \ - mldsa/keccak_f1600_x4_avx2_asm.o + mldsa/keccak_f1600_x4_avx2_asm.o \ + mldsa/rej_uniform_eta4_avx2_asm.o \ + mldsa/rej_uniform_eta2_avx2_asm.o # Build object files from assembly sources $(OBJ): %.o : %.S diff --git a/proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S new file mode 100644 index 000000000..5a7af0a61 --- /dev/null +++ b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta2_avx2_asm.S @@ -0,0 +1,173 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/rej_uniform_eta2_avx2_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +.text +.balign 4 +#ifdef __APPLE__ +.global _mld_rej_uniform_eta2_avx2_asm +_mld_rej_uniform_eta2_avx2_asm: +#else +.global mld_rej_uniform_eta2_avx2_asm +mld_rej_uniform_eta2_avx2_asm: +#endif + + .cfi_startproc + endbr64 + movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F + vmovd %r8d, %xmm3 + vpbroadcastd %xmm3, %ymm3 + movl $0x2020202, %r8d # imm = 0x2020202 + vmovd %r8d, %xmm4 + vpbroadcastd %xmm4, %ymm4 + movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F + vmovd %r8d, %xmm5 + vpbroadcastd %xmm5, %ymm5 + movl $0xffffe660, %r8d # imm = 0xFFFFE660 + vpinsrw $0x0, %r8d, %xmm6, %xmm6 + vpbroadcastw %xmm6, %ymm6 + movl $0x5, %r8d + vpinsrw $0x0, %r8d, %xmm7, %xmm7 + vpbroadcastw %xmm7, %ymm7 + xorl %eax, %eax + xorl %ecx, %ecx + +Lrej_uniform_eta2_avx2_asm_loop: + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + cmpl $0x78, %ecx + ja Lrej_uniform_eta2_avx2_asm_scalar + vpmovzxbw (%rsi,%rcx), %ymm0 + vpsllw $0x4, %ymm0, %ymm1 + vpor %ymm1, %ymm0, %ymm0 + vpand %ymm3, %ymm0, %ymm0 + vpsubb %ymm5, %ymm0, %ymm1 + vpsubb %ymm0, %ymm4, %ymm0 + vpmovmskb %ymm1, %r8d + vextracti128 $0x0, %ymm0, %xmm8 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + vpsrldq $0x8, %xmm8, %xmm8 # xmm8 = xmm8[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + vextracti128 $0x1, %ymm0, %xmm8 + movzbl %r8b, %r10d + vmovq 
(%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta2_avx2_asm_scalar + vpsrldq $0x8, %xmm8, %xmm8 # xmm8 = xmm8[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm9 + vpshufb %xmm9, %xmm8, %xmm9 + vpmovsxbd %xmm9, %ymm1 + vpmulhrsw %ymm6, %ymm1, %ymm2 + vpmullw %ymm7, %ymm2, %ymm2 + vpaddd %ymm2, %ymm1, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + addl $0x4, %ecx + jmp Lrej_uniform_eta2_avx2_asm_loop + +Lrej_uniform_eta2_avx2_asm_scalar: + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta2_avx2_asm_done + cmpl $0x88, %ecx + jae Lrej_uniform_eta2_avx2_asm_done + movzbl (%rsi,%rcx), %r11d + incl %ecx + movl %r11d, %r10d + andl $0xf, %r10d + cmpl $0xf, %r10d + jae Lrej_uniform_eta2_avx2_asm_high_nibble + movl %r10d, %r11d + imull $0xcd, %r11d, %r11d + shrl $0xa, %r11d + imull $0x5, %r11d, %r11d + subl %r11d, %r10d + movl $0x2, %r11d + subl %r10d, %r11d + movl %r11d, (%rdi,%rax,4) + incl %eax + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta2_avx2_asm_done + +Lrej_uniform_eta2_avx2_asm_high_nibble: + movzbl -0x1(%rsi,%rcx), %r11d + shrl $0x4, %r11d + andl $0xf, %r11d + cmpl $0xf, %r11d + jae Lrej_uniform_eta2_avx2_asm_scalar + movl %r11d, %r10d + imull $0xcd, %r10d, %r10d + shrl $0xa, %r10d + imull $0x5, %r10d, %r10d + subl %r10d, %r11d + movl $0x2, %r10d + subl %r11d, %r10d + movl %r10d, (%rdi,%rax,4) + incl %eax + jmp Lrej_uniform_eta2_avx2_asm_scalar + +Lrej_uniform_eta2_avx2_asm_done: + retq + .cfi_endproc + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S 
b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S new file mode 100644 index 000000000..40c73f10e --- /dev/null +++ b/proofs/hol_light/x86_64/mldsa/rej_uniform_eta4_avx2_asm.S @@ -0,0 +1,144 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/rej_uniform_eta4_avx2_asm.S using scripts/simpasm. Do not modify it directly. + */ + +.text +.balign 4 +#ifdef __APPLE__ +.global _mld_rej_uniform_eta4_avx2_asm +_mld_rej_uniform_eta4_avx2_asm: +#else +.global mld_rej_uniform_eta4_avx2_asm +mld_rej_uniform_eta4_avx2_asm: +#endif + + .cfi_startproc + endbr64 + movl $0xf0f0f0f, %r8d # imm = 0xF0F0F0F + vmovd %r8d, %xmm2 + vpbroadcastd %xmm2, %ymm2 + movl $0x4040404, %r8d # imm = 0x4040404 + vmovd %r8d, %xmm3 + vpbroadcastd %xmm3, %ymm3 + movl $0x9090909, %r8d # imm = 0x9090909 + vmovd %r8d, %xmm4 + vpbroadcastd %xmm4, %ymm4 + xorl %eax, %eax + xorl %ecx, %ecx + +Lrej_uniform_eta4_avx2_asm_loop: + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + cmpl $0x100, %ecx # imm = 0x100 + ja Lrej_uniform_eta4_avx2_asm_scalar + vpmovzxbw (%rsi,%rcx), %ymm0 + vpsllw $0x4, %ymm0, %ymm1 + vpor %ymm1, %ymm0, %ymm0 + vpand %ymm2, %ymm0, %ymm0 + vpsubb %ymm4, %ymm0, %ymm1 + vpsubb %ymm0, %ymm3, %ymm0 + vpmovmskb %ymm1, %r8d + vextracti128 $0x0, %ymm0, %xmm5 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl 
$0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + vpsrldq $0x8, %xmm5, %xmm5 # xmm5 = xmm5[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + vextracti128 $0x1, %ymm0, %xmm5 + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + shrl $0x8, %r8d + addl $0x4, %ecx + cmpl $0xf8, %eax + ja Lrej_uniform_eta4_avx2_asm_scalar + vpsrldq $0x8, %xmm5, %xmm5 # xmm5 = xmm5[8,9,10,11,12,13,14,15],zero,zero,zero,zero,zero,zero,zero,zero + movzbl %r8b, %r10d + vmovq (%rdx,%r10,8), %xmm6 + vpshufb %xmm6, %xmm5, %xmm6 + vpmovsxbd %xmm6, %ymm1 + vmovdqu %ymm1, (%rdi,%rax,4) + popcntl %r10d, %r9d + addl %r9d, %eax + addl $0x4, %ecx + jmp Lrej_uniform_eta4_avx2_asm_loop + +Lrej_uniform_eta4_avx2_asm_scalar: + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta4_avx2_asm_done + cmpl $0x110, %ecx # imm = 0x110 + jae Lrej_uniform_eta4_avx2_asm_done + movzbl (%rsi,%rcx), %r11d + incl %ecx + movl %r11d, %r10d + andl $0xf, %r10d + cmpl $0x9, %r10d + jae Lrej_uniform_eta4_avx2_asm_high_nibble + movl $0x4, %r9d + subl %r10d, %r9d + movl %r9d, (%rdi,%rax,4) + incl %eax + cmpl $0x100, %eax # imm = 0x100 + jae Lrej_uniform_eta4_avx2_asm_done + +Lrej_uniform_eta4_avx2_asm_high_nibble: + shrl $0x4, %r11d + andl $0xf, %r11d + cmpl $0x9, %r11d + jae Lrej_uniform_eta4_avx2_asm_scalar + movl $0x4, %r10d + subl %r11d, %r10d + movl %r10d, (%rdi,%rax,4) + incl %eax + jmp Lrej_uniform_eta4_avx2_asm_scalar + +Lrej_uniform_eta4_avx2_asm_done: + retq + .cfi_endproc + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/scripts/autogen b/scripts/autogen index 
3ce73146f..a7cf48147 100755 --- a/scripts/autogen +++ b/scripts/autogen @@ -2880,6 +2880,18 @@ def hol_light_asm_joblist(): f"-Idev/fips202/x86_64/src -Imldsa/src/fips202/native/x86_64/src {x86_64_flags}", "x86_64", ), + ( + "rej_uniform_eta4_avx2_asm.S", + "dev/x86_64/src", + f"-DMLD_ARITH_BACKEND_X86_64_DEFAULT -Imldsa/src/native/x86_64/src -Icommon {x86_64_flags}", + "x86_64", + ), + ( + "rej_uniform_eta2_avx2_asm.S", + "dev/x86_64/src", + f"-DMLD_ARITH_BACKEND_X86_64_DEFAULT -Imldsa/src/native/x86_64/src -Icommon {x86_64_flags}", + "x86_64", + ), ] return joblist_aarch64 + joblist_x86_64 diff --git a/test/bench/bench_components_mldsa.c b/test/bench/bench_components_mldsa.c index 948915dae..0654bc610 100644 --- a/test/bench/bench_components_mldsa.c +++ b/test/bench/bench_components_mldsa.c @@ -98,6 +98,18 @@ static int bench(void) BENCH("poly_caddq", mld_poly_caddq((mld_poly *)data0)); + /* poly_uniform_eta_4x — exercises rej_uniform_eta{2,4}_avx2_asm on x86 */ +#if !defined(MLD_CONFIG_SERIAL_FIPS202_ONLY) && \ + !defined(MLD_CONFIG_NO_KEYPAIR_API) + { + MLD_ALIGN mld_poly poly_eta0, poly_eta1, poly_eta2, poly_eta3; + BENCH( + "poly_uniform_eta_4x", + mld_poly_uniform_eta_4x(&poly_eta0, &poly_eta1, &poly_eta2, &poly_eta3, + (const uint8_t *)data0, 0, 1, 2, 3)) + } +#endif /* !MLD_CONFIG_SERIAL_FIPS202_ONLY && !MLD_CONFIG_NO_KEYPAIR_API */ + BENCH("poly_chknorm", chknorm_acc ^= mld_poly_chknorm((const mld_poly *)data0, MLDSA_GAMMA1 - MLDSA_BETA);)