diff --git a/.github/workflows/hol_light.yml b/.github/workflows/hol_light.yml index ef89ac54a..ae715b8f0 100644 --- a/.github/workflows/hol_light.yml +++ b/.github/workflows/hol_light.yml @@ -209,6 +209,10 @@ jobs: needs: ["mldsa_specs.ml", "mldsa_utils.ml", "subroutine_signatures.ml"] - name: polyz_unpack_19_avx2_asm needs: ["mldsa_specs.ml", "mldsa_utils.ml", "subroutine_signatures.ml"] + - name: poly_use_hint_32_avx2_asm + needs: ["mldsa_specs.ml", "mldsa_utils.ml", "subroutine_signatures.ml"] + - name: poly_use_hint_88_avx2_asm + needs: ["mldsa_specs.ml", "mldsa_utils.ml", "subroutine_signatures.ml"] - name: ntt_avx2_asm needs: ["mldsa_specs.ml", "mldsa_utils.ml", "mldsa_zetas.ml", "subroutine_signatures.ml"] - name: intt_avx2_asm diff --git a/BIBLIOGRAPHY.md b/BIBLIOGRAPHY.md index 2ed637232..21440e28c 100644 --- a/BIBLIOGRAPHY.md +++ b/BIBLIOGRAPHY.md @@ -274,8 +274,8 @@ source code and documentation. - [dev/x86_64/src/poly_chknorm_avx2_asm.S](dev/x86_64/src/poly_chknorm_avx2_asm.S) - [dev/x86_64/src/poly_decompose_32_avx2.c](dev/x86_64/src/poly_decompose_32_avx2.c) - [dev/x86_64/src/poly_decompose_88_avx2.c](dev/x86_64/src/poly_decompose_88_avx2.c) - - [dev/x86_64/src/poly_use_hint_32_avx2.c](dev/x86_64/src/poly_use_hint_32_avx2.c) - - [dev/x86_64/src/poly_use_hint_88_avx2.c](dev/x86_64/src/poly_use_hint_88_avx2.c) + - [dev/x86_64/src/poly_use_hint_32_avx2_asm.S](dev/x86_64/src/poly_use_hint_32_avx2_asm.S) + - [dev/x86_64/src/poly_use_hint_88_avx2_asm.S](dev/x86_64/src/poly_use_hint_88_avx2_asm.S) - [dev/x86_64/src/polyz_unpack_17_avx2_asm.S](dev/x86_64/src/polyz_unpack_17_avx2_asm.S) - [dev/x86_64/src/polyz_unpack_19_avx2_asm.S](dev/x86_64/src/polyz_unpack_19_avx2_asm.S) - [dev/x86_64/src/rej_uniform_avx2.c](dev/x86_64/src/rej_uniform_avx2.c) @@ -292,8 +292,8 @@ source code and documentation. 
- [mldsa/src/native/x86_64/src/poly_chknorm_avx2_asm.S](mldsa/src/native/x86_64/src/poly_chknorm_avx2_asm.S) - [mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c](mldsa/src/native/x86_64/src/poly_decompose_32_avx2.c) - [mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c](mldsa/src/native/x86_64/src/poly_decompose_88_avx2.c) - - [mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c](mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c) - - [mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c](mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c) + - [mldsa/src/native/x86_64/src/poly_use_hint_32_avx2_asm.S](mldsa/src/native/x86_64/src/poly_use_hint_32_avx2_asm.S) + - [mldsa/src/native/x86_64/src/poly_use_hint_88_avx2_asm.S](mldsa/src/native/x86_64/src/poly_use_hint_88_avx2_asm.S) - [mldsa/src/native/x86_64/src/polyz_unpack_17_avx2_asm.S](mldsa/src/native/x86_64/src/polyz_unpack_17_avx2_asm.S) - [mldsa/src/native/x86_64/src/polyz_unpack_19_avx2_asm.S](mldsa/src/native/x86_64/src/polyz_unpack_19_avx2_asm.S) - [mldsa/src/native/x86_64/src/rej_uniform_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_avx2.c) @@ -308,6 +308,8 @@ source code and documentation. 
- [proofs/hol_light/x86_64/mldsa/pointwise_avx2_asm.S](proofs/hol_light/x86_64/mldsa/pointwise_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/poly_caddq_avx2_asm.S](proofs/hol_light/x86_64/mldsa/poly_caddq_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/poly_chknorm_avx2_asm.S](proofs/hol_light/x86_64/mldsa/poly_chknorm_avx2_asm.S) + - [proofs/hol_light/x86_64/mldsa/poly_use_hint_32_avx2_asm.S](proofs/hol_light/x86_64/mldsa/poly_use_hint_32_avx2_asm.S) + - [proofs/hol_light/x86_64/mldsa/poly_use_hint_88_avx2_asm.S](proofs/hol_light/x86_64/mldsa/poly_use_hint_88_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/polyz_unpack_17_avx2_asm.S](proofs/hol_light/x86_64/mldsa/polyz_unpack_17_avx2_asm.S) - [proofs/hol_light/x86_64/mldsa/polyz_unpack_19_avx2_asm.S](proofs/hol_light/x86_64/mldsa/polyz_unpack_19_avx2_asm.S) diff --git a/dev/x86_64/meta.h b/dev/x86_64/meta.h index 55924ffec..67750c5eb 100644 --- a/dev/x86_64/meta.h +++ b/dev/x86_64/meta.h @@ -195,7 +195,7 @@ static MLD_INLINE int mld_poly_use_hint_32_native(int32_t *a, const int32_t *h) { return MLD_NATIVE_FUNC_FALLBACK; } - mld_poly_use_hint_32_avx2(a, h); + mld_poly_use_hint_32_avx2_asm(a, h); return MLD_NATIVE_FUNC_SUCCESS; } #endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \ @@ -209,7 +209,7 @@ static MLD_INLINE int mld_poly_use_hint_88_native(int32_t *a, const int32_t *h) { return MLD_NATIVE_FUNC_FALLBACK; } - mld_poly_use_hint_88_avx2(a, h); + mld_poly_use_hint_88_avx2_asm(a, h); return MLD_NATIVE_FUNC_SUCCESS; } #endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44 \ diff --git a/dev/x86_64/src/arith_native_x86_64.h b/dev/x86_64/src/arith_native_x86_64.h index 6ec3c1434..b7fcab7aa 100644 --- a/dev/x86_64/src/arith_native_x86_64.h +++ b/dev/x86_64/src/arith_native_x86_64.h @@ -115,11 +115,33 @@ __contract__( ); #if !defined(MLD_CONFIG_NO_VERIFY_API) -#define mld_poly_use_hint_32_avx2 MLD_NAMESPACE(mld_poly_use_hint_32_avx2) -void mld_poly_use_hint_32_avx2(int32_t 
*a, const int32_t *h); +#define mld_poly_use_hint_32_avx2_asm MLD_NAMESPACE(poly_use_hint_32_avx2_asm) +MLD_SYSV_ABI +void mld_poly_use_hint_32_avx2_asm(int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/poly_use_hint_32_avx2_asm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(a, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(a, 0, MLDSA_N, 0, (MLDSA_Q - 1) / (2 * MLDSA_GAMMA2))) +); -#define mld_poly_use_hint_88_avx2 MLD_NAMESPACE(mld_poly_use_hint_88_avx2) -void mld_poly_use_hint_88_avx2(int32_t *a, const int32_t *h); +#define mld_poly_use_hint_88_avx2_asm MLD_NAMESPACE(poly_use_hint_88_avx2_asm) +MLD_SYSV_ABI +void mld_poly_use_hint_88_avx2_asm(int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/poly_use_hint_88_avx2_asm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(a, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(a, 0, MLDSA_N, 0, (MLDSA_Q - 1) / (2 * MLDSA_GAMMA2))) +); #endif /* !MLD_CONFIG_NO_VERIFY_API */ #define mld_poly_chknorm_avx2_asm MLD_NAMESPACE(poly_chknorm_avx2_asm) diff --git a/dev/x86_64/src/poly_use_hint_32_avx2.c b/dev/x86_64/src/poly_use_hint_32_avx2.c deleted file mode 100644 index 9a9edfa53..000000000 --- a/dev/x86_64/src/poly_use_hint_32_avx2.c +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium 
optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_VERIFY_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ - (MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)) - -#include -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_MM256_BLENDV_EPI32(a, b, mask) \ - _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(a), \ - _mm256_castsi256_ps(b), \ - _mm256_castsi256_ps(mask))) - -void mld_poly_use_hint_32_avx2(int32_t *a, const int32_t *hint) -{ - unsigned int i; - __m256i f, f0, f1, h, t; - const __m256i q_bound = _mm256_set1_epi32(31 * ((MLDSA_Q - 1) / 32)); - /* check-magic: 1025 == floor(2**22 / 4092) */ - const __m256i v = _mm256_set1_epi32(1025); - const __m256i alpha = _mm256_set1_epi32(2 * ((MLDSA_Q - 1) / 32)); - const __m256i off = _mm256_set1_epi32(127); - const __m256i shift = _mm256_set1_epi32(512); - const __m256i mask = _mm256_set1_epi32(15); - const __m256i zero = _mm256_setzero_si256(); - - for (i = 0; i < MLDSA_N / 8; i++) - { - f = _mm256_load_si256((const __m256i *)&a[8 * i]); - h = _mm256_load_si256((const __m256i *)&hint[8 * i]); - - /* Reference: - * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. - * - Our implementation of decompose() is slightly different from that in - * @[REF_AVX2]. See poly_decompose_32_avx2.c for more information. 
- */ - /* f1, f2 = decompose(f) */ - f1 = _mm256_add_epi32(f, off); - f1 = _mm256_srli_epi32(f1, 7); - f1 = _mm256_mulhi_epu16(f1, v); - f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_cmpgt_epi32(f, q_bound); - f0 = _mm256_mullo_epi32(f1, alpha); - f0 = _mm256_sub_epi32(f, f0); - f1 = _mm256_andnot_si256(t, f1); - f0 = _mm256_add_epi32(f0, t); - - /* Reference: The reference avx2 implementation checks a0 >= 0, which is - * different from the specification and the reference C implementation. We - * follow the specification and check a0 > 0. - */ - /* t = (f0 > 0) ? h : -h */ - f0 = _mm256_cmpgt_epi32(f0, zero); - t = MLD_MM256_BLENDV_EPI32(h, zero, f0); - t = _mm256_slli_epi32(t, 1); - h = _mm256_sub_epi32(h, t); - - /* f1 = (f1 + t) % 16 */ - f1 = _mm256_add_epi32(f1, h); - f1 = _mm256_and_si256(f1, mask); - - _mm256_store_si256((__m256i *)&a[8 * i], f1); - } -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \ - || MLD_CONFIG_PARAMETER_SET == 87) */ - -MLD_EMPTY_CU(avx2_poly_use_hint_32) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \ - || MLD_CONFIG_PARAMETER_SET == 87)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ -#undef MLD_MM256_BLENDV_EPI32 diff --git a/dev/x86_64/src/poly_use_hint_32_avx2_asm.S b/dev/x86_64/src/poly_use_hint_32_avx2_asm.S new file mode 100644 index 000000000..f2569d950 --- /dev/null +++ b/dev/x86_64/src/poly_use_hint_32_avx2_asm.S @@ -0,0 +1,148 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/************************************************* + * Name: mld_poly_use_hint_32_avx2_asm + * + * Description: Use hint polynomial to correct the high bits of a polynomial. + * Variant for parameter sets ML-DSA-65 and ML-DSA-87 + * (GAMMA2 = (Q-1)/32). Operates in place. + * + * Arguments: - int32_t *a: pointer to input/output polynomial; the + * corrected high bits are written back here + * - const int32_t *h: pointer to input hint polynomial + **************************************************/ + +#include "../../../common.h" + +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ + (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ + (MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)) + +/* simpasm: header-end */ + +/* Reference: + * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. + * - Our implementation of decompose() is slightly different from that in + * @[REF_AVX2]. See poly_decompose_32_avx2.c for more information. + */ + +// a aliased with a0 +.macro decompose32_avx2 a1, a, temp1, temp2, temp3 +// Compute a1 = round-(a / 523776) ≈ round(a * 1074791425 / +// 2^49), where round-() denotes "round half down". This is +// exact for 0 <= a < Q. 
Note that half is rounded down since +// 1074791425 / 2^49 ≲ 1 / 523776. +vpaddd \a, %ymm5, \temp1 +vpsrld $7, \temp1, \temp1 +vpmulhuw %ymm8, \temp1, \temp1 +vpmulhrsw %ymm7, \temp1, \temp1 + +// If a1 = 16, i.e. a > 31*GAMMA2, proceed as if a' = a - Q was +// given instead. (For a = 31*GAMMA2 + 1 thus a' = -GAMMA2, we +// still round it to 0 like other "wrapped around" cases.) + +// Check for wrap-around +vpcmpgtd %ymm4, \a, \temp2 +vpandn \temp1, \temp2, \a1 + +// Compute remainder a0 = a - a1 * 2 * GAMMA2 = a - a1 * 523776 +vpslld $10, \temp1, \temp3 +vpsubd \temp1, \temp3, \temp1 +vpslld $9, \temp1, \temp1 +vpsubd \temp1, \a, \a + +// If wrap-around is required, adjust a0 by -1 +vpaddd \temp2, \a, \a +.endm + +/* Reference: The reference avx2 implementation checks a0 >= 0, which is + * different from the specification and the reference C implementation. We + * follow the specification and check a0 > 0. + */ + +// a aliased with delta +.macro use_hint32_avx2 b, a, h, a1, temp1, temp2, temp3 +decompose32_avx2 \a1, \a, \temp1, \temp2, \temp3 + +// delta = (a0 <= 0) ? 
-1 : 1 +vpcmpgtd %ymm6, \a, \a +vpandn \h, \a, \a +vpslld $1, \a, \a +vpsubd \a, \h, \h + +// b = (a1 + delta * h) % 16 +vpaddd \a1, \h, \b +vpand %ymm3, \b, \b +.endm + +.text +.balign 16 +.global MLD_ASM_NAMESPACE(poly_use_hint_32_avx2_asm) +MLD_ASM_FN_SYMBOL(poly_use_hint_32_avx2_asm) + +// Initialize constants +movl $127, %ecx + +/* check-magic: 1025 == floor(2^22 / 4092) */ +movl $1025, %r8d +vmovd %r8d, %xmm8 +vpbroadcastd %xmm8, %ymm8 + +xorl %eax, %eax +vpxor %xmm6, %xmm6, %xmm6 +vmovd %ecx, %xmm5 + +/* 31 * ((Q-1) / 32) == 31 * GAMMA2, wrap-around threshold */ +movl $8118528, %ecx + +/* round(x * 2^9 / 2^15) => round(x / 2^6), for f1 = round(f1'' / 2^6) */ +movl $512, %r9d +vmovd %r9d, %xmm7 +vpbroadcastd %xmm7, %ymm7 + +vmovd %ecx, %xmm4 +movl $15, %ecx +vpbroadcastd %xmm5, %ymm5 +vmovd %ecx, %xmm3 +vpbroadcastd %xmm4, %ymm4 +vpbroadcastd %xmm3, %ymm3 + + +poly_use_hint_32_avx2_asm_loop: +vmovdqa (%rdi), %ymm0 +vmovdqa (%rsi), %ymm2 + +use_hint32_avx2 %ymm2, %ymm0, %ymm2, %ymm9, %ymm1, %ymm11, %ymm10 + +vmovdqa %ymm2, (%rdi) +addq $32, %rdi +addq $32, %rsi +addq $32, %rax +cmpq $1024, %rax +jne poly_use_hint_32_avx2_asm_loop +ret + +/* simpasm: footer-start */ + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + && (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == \ + 65 || MLD_CONFIG_PARAMETER_SET == 87) */ diff --git a/dev/x86_64/src/poly_use_hint_88_avx2.c b/dev/x86_64/src/poly_use_hint_88_avx2.c deleted file mode 100644 index 41112a2a3..000000000 --- a/dev/x86_64/src/poly_use_hint_88_avx2.c +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived 
from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_VERIFY_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ - MLD_CONFIG_PARAMETER_SET == 44) - -#include -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_MM256_BLENDV_EPI32(a, b, mask) \ - _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(a), \ - _mm256_castsi256_ps(b), \ - _mm256_castsi256_ps(mask))) - -void mld_poly_use_hint_88_avx2(int32_t *a, const int32_t *hint) -{ - unsigned int i; - __m256i f, f0, f1, h, t; - const __m256i q_bound = _mm256_set1_epi32(87 * ((MLDSA_Q - 1) / 88)); - /* check-magic: 11275 == floor(2**24 / 1488) */ - const __m256i v = _mm256_set1_epi32(11275); - const __m256i alpha = _mm256_set1_epi32(2 * ((MLDSA_Q - 1) / 88)); - const __m256i off = _mm256_set1_epi32(127); - const __m256i shift = _mm256_set1_epi32(128); - const __m256i max = _mm256_set1_epi32(43); - const __m256i zero = _mm256_setzero_si256(); - - for (i = 0; i < MLDSA_N / 8; i++) - { - f = _mm256_load_si256((const __m256i *)&a[8 * i]); - h = _mm256_load_si256((const __m256i *)&hint[8 * i]); - - /* Reference: - * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. - * - Our implementation of decompose() is slightly different from that in - * @[REF_AVX2]. See poly_decompose_88_avx2.c for more information. 
- */ - /* f1, f2 = decompose(f) */ - f1 = _mm256_add_epi32(f, off); - f1 = _mm256_srli_epi32(f1, 7); - f1 = _mm256_mulhi_epu16(f1, v); - f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_cmpgt_epi32(f, q_bound); - f0 = _mm256_mullo_epi32(f1, alpha); - f0 = _mm256_sub_epi32(f, f0); - f1 = _mm256_andnot_si256(t, f1); - f0 = _mm256_add_epi32(f0, t); - - /* Reference: The reference avx2 implementation checks a0 >= 0, which is - * different from the specification and the reference C implementation. We - * follow the specification and check a0 > 0. - */ - /* t = (f0 > 0) ? h : -h */ - f0 = _mm256_cmpgt_epi32(f0, zero); - t = MLD_MM256_BLENDV_EPI32(h, zero, f0); - t = _mm256_slli_epi32(t, 1); - h = _mm256_sub_epi32(h, t); - - /* f1 = (f1 + t) % 44 */ - f1 = _mm256_add_epi32(f1, h); - f1 = MLD_MM256_BLENDV_EPI32(f1, max, f1); - f = _mm256_cmpgt_epi32(f1, max); - f1 = MLD_MM256_BLENDV_EPI32(f1, zero, f); - - _mm256_store_si256((__m256i *)&a[8 * i], f1); - } -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44) \ - */ - -MLD_EMPTY_CU(avx2_poly_use_hint_88) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == \ - 44)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ -#undef MLD_MM256_BLENDV_EPI32 diff --git a/dev/x86_64/src/poly_use_hint_88_avx2_asm.S b/dev/x86_64/src/poly_use_hint_88_avx2_asm.S new file mode 100644 index 000000000..5bbd946fe --- /dev/null +++ b/dev/x86_64/src/poly_use_hint_88_avx2_asm.S @@ -0,0 +1,157 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/************************************************* + * Name: mld_poly_use_hint_88_avx2_asm + * + * Description: Use hint polynomial to correct the high bits of a polynomial. + * Variant for parameter set ML-DSA-44 (GAMMA2 = (Q-1)/88). + * Operates in place. + * + * Arguments: - int32_t *a: pointer to input/output polynomial; the + * corrected high bits are written back here + * - const int32_t *h: pointer to input hint polynomial + **************************************************/ + +#include "../../../common.h" + +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ + (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ + MLD_CONFIG_PARAMETER_SET == 44) +/* simpasm: header-end */ + +/* Reference: + * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. + * - Our implementation of decompose() is slightly different from that in + * @[REF_AVX2]. See poly_decompose_88_avx2.c for more information. + */ + +// a aliased with a0 +.macro decompose88_avx2 a1, a, temp1, temp2, temp3 +// Compute a1 = round-(a / 190464) ≈ round(a * 1477838209 / +// 2^48), where round-() denotes "round half down". This is +// exact for 0 <= a < Q. 
Note that half is rounded down since +// 1477838209 / 2^48 ≲ 1 / 190464. +vpaddd \a, %ymm4, \temp1 +vpsrld $7, \temp1, \temp1 +vpmulhuw %ymm8, \temp1, \temp1 +vpmulhrsw %ymm7, \temp1, \temp1 + +// If a1 = 44, i.e. a > 87*GAMMA2, proceed as if a' = a - Q was +// given instead. (For a = 87*GAMMA2 + 1 thus a' = -GAMMA2, we +// still round it to 0 like other "wrapped around" cases.) + +// Check for wrap-around +vpcmpgtd %ymm3, \a, \temp2 +vpandn \temp1, \temp2, \a1 + +// Compute remainder a0 = a - a1 * 2 * GAMMA2 = a - a1 * 190464 +vpslld $1, \temp1, \temp3 +vpaddd \temp1, \temp3, \temp3 +vpslld $5, \temp3, \temp1 +vpsubd \temp3, \temp1, \temp1 +vpslld $11, \temp1, \temp1 +vpsubd \temp1, \a, \a + +// If wrap-around is required, adjust a0 by -1 +vpaddd \temp2, \a, \a +.endm + +/* Reference: The reference avx2 implementation checks a0 >= 0, which is + * different from the specification and the reference C implementation. We + * follow the specification and check a0 > 0. + */ + +// a aliased with delta +.macro use_hint88_avx2 b, a, h, a1, temp1, temp2, temp3 +decompose88_avx2 \a1, \a, \temp1, \temp2, \temp3 + +// delta = (a0 <= 0) ? -1 : 1 +vpcmpgtd %ymm5, \a, \a +vpandn \h, \a, \a +vpslld $1, \a, \a +vpsubd \a, \h, \a + +// b = (a1 + delta * h) % 44 +vpaddd \a1, \a, \b +// If b wrapped below 0 (a1 == 0, delta == -1), set b = 43. +// b is in [-1, 44], so the per-dword sign bit and the per-byte +// blend mask of vpblendvb coincide. +vpblendvb \b, %ymm6, \b, \b +// If b overflowed above 43 (a1 == 43, delta == +1), set b = 0. 
+vpcmpgtd %ymm6, \b, \h +vpandn \b, \h, \b +.endm + +.text +.balign 16 +.global MLD_ASM_NAMESPACE(poly_use_hint_88_avx2_asm) +MLD_ASM_FN_SYMBOL(poly_use_hint_88_avx2_asm) + +// Initialize constants +movl $127, %ecx +xorl %eax, %eax +vpxor %xmm5, %xmm5, %xmm5 + +/* check-magic: 11275 == floor(2^24 / 1488) */ +movl $11275, %r8d +vmovd %r8d, %xmm8 +vpbroadcastd %xmm8, %ymm8 + +vmovd %ecx, %xmm4 + +/* 87 * ((Q-1) / 88) == 87 * GAMMA2, wrap-around threshold */ +movl $8285184, %ecx + +/* round(x * 2^7 / 2^15) => round(x / 2^8), for f1 = round(f1'' / 2^8) */ +movl $128, %r9d +vmovd %r9d, %xmm7 +vpbroadcastd %xmm7, %ymm7 + +/* max a1 value */ +movl $43, %r10d +vmovd %r10d, %xmm6 +vpbroadcastd %xmm6, %ymm6 + +vmovd %ecx, %xmm3 +vpbroadcastd %xmm4, %ymm4 +vpbroadcastd %xmm3, %ymm3 + +poly_use_hint_88_avx2_asm_loop: +vmovdqa (%rdi), %ymm0 +vmovdqa (%rsi), %ymm1 + +use_hint88_avx2 %ymm0, %ymm0, %ymm1, %ymm9, %ymm10, %ymm11, %ymm12 + +vmovdqa %ymm0, (%rdi) +addq $32, %rdi +addq $32, %rsi +addq $32, %rax +cmpq $1024, %rax +jne poly_use_hint_88_avx2_asm_loop + +ret + +/* simpasm: footer-start */ + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + && (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == \ + 44) */ diff --git a/mldsa/mldsa_native.c b/mldsa/mldsa_native.c index 9365ed369..ba7f1cf8f 100644 --- a/mldsa/mldsa_native.c +++ b/mldsa/mldsa_native.c @@ -85,8 +85,6 @@ #include "src/native/x86_64/src/consts.c" #include "src/native/x86_64/src/poly_decompose_32_avx2.c" #include "src/native/x86_64/src/poly_decompose_88_avx2.c" -#include "src/native/x86_64/src/poly_use_hint_32_avx2.c" -#include "src/native/x86_64/src/poly_use_hint_88_avx2.c" #include "src/native/x86_64/src/rej_uniform_avx2.c" #include "src/native/x86_64/src/rej_uniform_eta2_avx2.c" #include "src/native/x86_64/src/rej_uniform_eta4_avx2.c" @@ -787,8 +785,8 @@ #undef mld_poly_chknorm_avx2_asm #undef mld_poly_decompose_32_avx2 #undef mld_poly_decompose_88_avx2 -#undef 
mld_poly_use_hint_32_avx2 -#undef mld_poly_use_hint_88_avx2 +#undef mld_poly_use_hint_32_avx2_asm +#undef mld_poly_use_hint_88_avx2_asm #undef mld_polyz_unpack_17_avx2_asm #undef mld_polyz_unpack_19_avx2_asm #undef mld_rej_uniform_avx2 diff --git a/mldsa/mldsa_native_asm.S b/mldsa/mldsa_native_asm.S index 4877d5156..555bcf1c9 100644 --- a/mldsa/mldsa_native_asm.S +++ b/mldsa/mldsa_native_asm.S @@ -88,6 +88,8 @@ #include "src/native/x86_64/src/pointwise_avx2_asm.S" #include "src/native/x86_64/src/poly_caddq_avx2_asm.S" #include "src/native/x86_64/src/poly_chknorm_avx2_asm.S" +#include "src/native/x86_64/src/poly_use_hint_32_avx2_asm.S" +#include "src/native/x86_64/src/poly_use_hint_88_avx2_asm.S" #include "src/native/x86_64/src/polyz_unpack_17_avx2_asm.S" #include "src/native/x86_64/src/polyz_unpack_19_avx2_asm.S" #endif /* MLD_SYS_X86_64 */ @@ -800,8 +802,8 @@ #undef mld_poly_chknorm_avx2_asm #undef mld_poly_decompose_32_avx2 #undef mld_poly_decompose_88_avx2 -#undef mld_poly_use_hint_32_avx2 -#undef mld_poly_use_hint_88_avx2 +#undef mld_poly_use_hint_32_avx2_asm +#undef mld_poly_use_hint_88_avx2_asm #undef mld_polyz_unpack_17_avx2_asm #undef mld_polyz_unpack_19_avx2_asm #undef mld_rej_uniform_avx2 diff --git a/mldsa/src/native/x86_64/meta.h b/mldsa/src/native/x86_64/meta.h index 55924ffec..67750c5eb 100644 --- a/mldsa/src/native/x86_64/meta.h +++ b/mldsa/src/native/x86_64/meta.h @@ -195,7 +195,7 @@ static MLD_INLINE int mld_poly_use_hint_32_native(int32_t *a, const int32_t *h) { return MLD_NATIVE_FUNC_FALLBACK; } - mld_poly_use_hint_32_avx2(a, h); + mld_poly_use_hint_32_avx2_asm(a, h); return MLD_NATIVE_FUNC_SUCCESS; } #endif /* MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \ @@ -209,7 +209,7 @@ static MLD_INLINE int mld_poly_use_hint_88_native(int32_t *a, const int32_t *h) { return MLD_NATIVE_FUNC_FALLBACK; } - mld_poly_use_hint_88_avx2(a, h); + mld_poly_use_hint_88_avx2_asm(a, h); return MLD_NATIVE_FUNC_SUCCESS; } #endif /* 
MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44 \ diff --git a/mldsa/src/native/x86_64/src/arith_native_x86_64.h b/mldsa/src/native/x86_64/src/arith_native_x86_64.h index 6ec3c1434..b7fcab7aa 100644 --- a/mldsa/src/native/x86_64/src/arith_native_x86_64.h +++ b/mldsa/src/native/x86_64/src/arith_native_x86_64.h @@ -115,11 +115,33 @@ __contract__( ); #if !defined(MLD_CONFIG_NO_VERIFY_API) -#define mld_poly_use_hint_32_avx2 MLD_NAMESPACE(mld_poly_use_hint_32_avx2) -void mld_poly_use_hint_32_avx2(int32_t *a, const int32_t *h); +#define mld_poly_use_hint_32_avx2_asm MLD_NAMESPACE(poly_use_hint_32_avx2_asm) +MLD_SYSV_ABI +void mld_poly_use_hint_32_avx2_asm(int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/poly_use_hint_32_avx2_asm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(a, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(a, 0, MLDSA_N, 0, (MLDSA_Q - 1) / (2 * MLDSA_GAMMA2))) +); -#define mld_poly_use_hint_88_avx2 MLD_NAMESPACE(mld_poly_use_hint_88_avx2) -void mld_poly_use_hint_88_avx2(int32_t *a, const int32_t *h); +#define mld_poly_use_hint_88_avx2_asm MLD_NAMESPACE(poly_use_hint_88_avx2_asm) +MLD_SYSV_ABI +void mld_poly_use_hint_88_avx2_asm(int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/x86_64/proofs/poly_use_hint_88_avx2_asm.ml */ +__contract__( + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(a, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(a, 0, MLDSA_N, 0, (MLDSA_Q - 1) / (2 * MLDSA_GAMMA2))) +); 
#endif /* !MLD_CONFIG_NO_VERIFY_API */ #define mld_poly_chknorm_avx2_asm MLD_NAMESPACE(poly_chknorm_avx2_asm) diff --git a/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c b/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c deleted file mode 100644 index 9a9edfa53..000000000 --- a/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2.c +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_VERIFY_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ - (MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)) - -#include <immintrin.h> -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_MM256_BLENDV_EPI32(a, b, mask) \ - _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(a), \ - _mm256_castsi256_ps(b), \ - _mm256_castsi256_ps(mask))) - -void mld_poly_use_hint_32_avx2(int32_t *a, const int32_t *hint) -{ - unsigned int i; - __m256i f, f0, f1, h, t; - const __m256i q_bound = _mm256_set1_epi32(31 * ((MLDSA_Q - 1) / 32)); - /* check-magic: 1025 == floor(2**22 / 4092) */ - const __m256i v = _mm256_set1_epi32(1025); - const __m256i alpha = _mm256_set1_epi32(2 * ((MLDSA_Q - 1) / 32)); - const __m256i off = _mm256_set1_epi32(127); - const __m256i shift = _mm256_set1_epi32(512); - const __m256i mask = _mm256_set1_epi32(15); - const __m256i zero = _mm256_setzero_si256(); - - for (i = 0; i < MLDSA_N / 8; i++) - { - f = _mm256_load_si256((const __m256i *)&a[8 * 
i]); - h = _mm256_load_si256((const __m256i *)&hint[8 * i]); - - /* Reference: - * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. - * - Our implementation of decompose() is slightly different from that in - * @[REF_AVX2]. See poly_decompose_32_avx2.c for more information. - */ - /* f1, f2 = decompose(f) */ - f1 = _mm256_add_epi32(f, off); - f1 = _mm256_srli_epi32(f1, 7); - f1 = _mm256_mulhi_epu16(f1, v); - f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_cmpgt_epi32(f, q_bound); - f0 = _mm256_mullo_epi32(f1, alpha); - f0 = _mm256_sub_epi32(f, f0); - f1 = _mm256_andnot_si256(t, f1); - f0 = _mm256_add_epi32(f0, t); - - /* Reference: The reference avx2 implementation checks a0 >= 0, which is - * different from the specification and the reference C implementation. We - * follow the specification and check a0 > 0. - */ - /* t = (f0 > 0) ? h : -h */ - f0 = _mm256_cmpgt_epi32(f0, zero); - t = MLD_MM256_BLENDV_EPI32(h, zero, f0); - t = _mm256_slli_epi32(t, 1); - h = _mm256_sub_epi32(h, t); - - /* f1 = (f1 + t) % 16 */ - f1 = _mm256_add_epi32(f1, h); - f1 = _mm256_and_si256(f1, mask); - - _mm256_store_si256((__m256i *)&a[8 * i], f1); - } -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \ - || MLD_CONFIG_PARAMETER_SET == 87) */ - -MLD_EMPTY_CU(avx2_poly_use_hint_32) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 65 \ - || MLD_CONFIG_PARAMETER_SET == 87)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ -#undef MLD_MM256_BLENDV_EPI32 diff --git a/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2_asm.S b/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2_asm.S new file mode 100644 index 000000000..686fbe863 --- /dev/null +++ b/mldsa/src/native/x86_64/src/poly_use_hint_32_avx2_asm.S @@ -0,0 +1,108 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/************************************************* + * Name: mld_poly_use_hint_32_avx2_asm + * + * Description: Use hint polynomial to correct the high bits of a polynomial. + * Variant for parameter sets ML-DSA-65 and ML-DSA-87 + * (GAMMA2 = (Q-1)/32). Operates in place. + * + * Arguments: - int32_t *a: pointer to input/output polynomial; the + * corrected high bits are written back here + * - const int32_t *h: pointer to input hint polynomial + **************************************************/ + +#include "../../../common.h" + +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ + (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ + (MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)) + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/poly_use_hint_32_avx2_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +.text +.balign 4 +.global MLD_ASM_NAMESPACE(poly_use_hint_32_avx2_asm) +MLD_ASM_FN_SYMBOL(poly_use_hint_32_avx2_asm) + + .cfi_startproc + movl $0x7f, %ecx + movl $0x401, %r8d # imm = 0x401 + vmovd %r8d, %xmm8 + vpbroadcastd %xmm8, %ymm8 + xorl %eax, %eax + vpxor %xmm6, %xmm6, %xmm6 + vmovd %ecx, %xmm5 + movl $0x7be100, %ecx # imm = 0x7BE100 + movl $0x200, %r9d # imm = 0x200 + vmovd %r9d, %xmm7 + vpbroadcastd %xmm7, %ymm7 + vmovd %ecx, %xmm4 + movl $0xf, %ecx + vpbroadcastd %xmm5, %ymm5 + vmovd %ecx, %xmm3 + vpbroadcastd %xmm4, %ymm4 + vpbroadcastd %xmm3, %ymm3 + +Lpoly_use_hint_32_avx2_asm_loop: + vmovdqa (%rdi), %ymm0 + vmovdqa (%rsi), %ymm2 + vpaddd %ymm0, %ymm5, %ymm1 + vpsrld $0x7, %ymm1, %ymm1 + vpmulhuw %ymm8, %ymm1, %ymm1 + vpmulhrsw %ymm7, %ymm1, %ymm1 + vpcmpgtd %ymm4, %ymm0, %ymm11 + vpandn %ymm1, %ymm11, %ymm9 + vpslld $0xa, %ymm1, %ymm10 + vpsubd %ymm1, %ymm10, %ymm1 + vpslld $0x9, %ymm1, %ymm1 + vpsubd %ymm1, %ymm0, %ymm0 + vpaddd %ymm11, %ymm0, %ymm0 + vpcmpgtd %ymm6, %ymm0, %ymm0 + vpandn %ymm2, %ymm0, %ymm0 + vpslld $0x1, %ymm0, %ymm0 + vpsubd %ymm0, %ymm2, %ymm2 + vpaddd %ymm9, %ymm2, %ymm2 + vpand %ymm3, %ymm2, %ymm2 + vmovdqa %ymm2, (%rdi) + addq $0x20, %rdi + addq $0x20, %rsi + addq $0x20, %rax + cmpq $0x400, %rax # imm = 0x400 + jne Lpoly_use_hint_32_avx2_asm_loop + retq + .cfi_endproc + +MLD_ASM_FN_SIZE(poly_use_hint_32_avx2_asm) + + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + && (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == \ + 65 || MLD_CONFIG_PARAMETER_SET == 87) */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c b/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c deleted file mode 100644 index 41112a2a3..000000000 --- a/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) The mldsa-native project authors - * SPDX-License-Identifier: 
Apache-2.0 OR ISC OR MIT - */ - -/* References - * ========== - * - * - [REF_AVX2] - * CRYSTALS-Dilithium optimized AVX2 implementation - * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé - * https://github.com/pq-crystals/dilithium/tree/master/avx2 - */ - -/* - * This file is derived from the public domain - * AVX2 Dilithium implementation @[REF_AVX2]. - */ - -#include "../../../common.h" - -#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ - !defined(MLD_CONFIG_NO_VERIFY_API) && \ - !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ - (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ - MLD_CONFIG_PARAMETER_SET == 44) - -#include <immintrin.h> -#include "arith_native_x86_64.h" -#include "consts.h" - -#define MLD_MM256_BLENDV_EPI32(a, b, mask) \ - _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(a), \ - _mm256_castsi256_ps(b), \ - _mm256_castsi256_ps(mask))) - -void mld_poly_use_hint_88_avx2(int32_t *a, const int32_t *hint) -{ - unsigned int i; - __m256i f, f0, f1, h, t; - const __m256i q_bound = _mm256_set1_epi32(87 * ((MLDSA_Q - 1) / 88)); - /* check-magic: 11275 == floor(2**24 / 1488) */ - const __m256i v = _mm256_set1_epi32(11275); - const __m256i alpha = _mm256_set1_epi32(2 * ((MLDSA_Q - 1) / 88)); - const __m256i off = _mm256_set1_epi32(127); - const __m256i shift = _mm256_set1_epi32(128); - const __m256i max = _mm256_set1_epi32(43); - const __m256i zero = _mm256_setzero_si256(); - - for (i = 0; i < MLDSA_N / 8; i++) - { - f = _mm256_load_si256((const __m256i *)&a[8 * i]); - h = _mm256_load_si256((const __m256i *)&hint[8 * i]); - - /* Reference: - * - @[REF_AVX2] calls poly_decompose to compute all a1, a0 before the loop. - * - Our implementation of decompose() is slightly different from that in - * @[REF_AVX2]. See poly_decompose_88_avx2.c for more information. 
- */ - /* f1, f2 = decompose(f) */ - f1 = _mm256_add_epi32(f, off); - f1 = _mm256_srli_epi32(f1, 7); - f1 = _mm256_mulhi_epu16(f1, v); - f1 = _mm256_mulhrs_epi16(f1, shift); - t = _mm256_cmpgt_epi32(f, q_bound); - f0 = _mm256_mullo_epi32(f1, alpha); - f0 = _mm256_sub_epi32(f, f0); - f1 = _mm256_andnot_si256(t, f1); - f0 = _mm256_add_epi32(f0, t); - - /* Reference: The reference avx2 implementation checks a0 >= 0, which is - * different from the specification and the reference C implementation. We - * follow the specification and check a0 > 0. - */ - /* t = (f0 > 0) ? h : -h */ - f0 = _mm256_cmpgt_epi32(f0, zero); - t = MLD_MM256_BLENDV_EPI32(h, zero, f0); - t = _mm256_slli_epi32(t, 1); - h = _mm256_sub_epi32(h, t); - - /* f1 = (f1 + t) % 44 */ - f1 = _mm256_add_epi32(f1, h); - f1 = MLD_MM256_BLENDV_EPI32(f1, max, f1); - f = _mm256_cmpgt_epi32(f1, max); - f1 = MLD_MM256_BLENDV_EPI32(f1, zero, f); - - _mm256_store_si256((__m256i *)&a[8 * i], f1); - } -} - -#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == 44) \ - */ - -MLD_EMPTY_CU(avx2_poly_use_hint_88) - -#endif /* !(MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_NO_VERIFY_API && \ - !MLD_CONFIG_MULTILEVEL_NO_SHARED && \ - (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == \ - 44)) */ - -/* To facilitate single-compilation-unit (SCU) builds, undefine all macros. - * Don't modify by hand -- this is auto-generated by scripts/autogen. 
*/ -#undef MLD_MM256_BLENDV_EPI32 diff --git a/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2_asm.S b/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2_asm.S new file mode 100644 index 000000000..caf9277b1 --- /dev/null +++ b/mldsa/src/native/x86_64/src/poly_use_hint_88_avx2_asm.S @@ -0,0 +1,111 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/************************************************* + * Name: mld_poly_use_hint_88_avx2_asm + * + * Description: Use hint polynomial to correct the high bits of a polynomial. + * Variant for parameter set ML-DSA-44 (GAMMA2 = (Q-1)/88). + * Operates in place. + * + * Arguments: - int32_t *a: pointer to input/output polynomial; the + * corrected high bits are written back here + * - const int32_t *h: pointer to input hint polynomial + **************************************************/ + +#include "../../../common.h" + +#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \ + !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED) && \ + (defined(MLD_CONFIG_MULTILEVEL_WITH_SHARED) || \ + MLD_CONFIG_PARAMETER_SET == 44) + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/poly_use_hint_88_avx2_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +.text +.balign 4 +.global MLD_ASM_NAMESPACE(poly_use_hint_88_avx2_asm) +MLD_ASM_FN_SYMBOL(poly_use_hint_88_avx2_asm) + + .cfi_startproc + movl $0x7f, %ecx + xorl %eax, %eax + vpxor %xmm5, %xmm5, %xmm5 + movl $0x2c0b, %r8d # imm = 0x2C0B + vmovd %r8d, %xmm8 + vpbroadcastd %xmm8, %ymm8 + vmovd %ecx, %xmm4 + movl $0x7e6c00, %ecx # imm = 0x7E6C00 + movl $0x80, %r9d + vmovd %r9d, %xmm7 + vpbroadcastd %xmm7, %ymm7 + movl $0x2b, %r10d + vmovd %r10d, %xmm6 + vpbroadcastd %xmm6, %ymm6 + vmovd %ecx, %xmm3 + vpbroadcastd %xmm4, %ymm4 + vpbroadcastd %xmm3, %ymm3 + +Lpoly_use_hint_88_avx2_asm_loop: + vmovdqa (%rdi), %ymm0 + vmovdqa (%rsi), %ymm1 + vpaddd %ymm0, %ymm4, %ymm10 + vpsrld $0x7, %ymm10, %ymm10 + vpmulhuw %ymm8, %ymm10, %ymm10 + vpmulhrsw %ymm7, %ymm10, %ymm10 + vpcmpgtd %ymm3, %ymm0, %ymm11 + vpandn %ymm10, %ymm11, %ymm9 + vpslld $0x1, %ymm10, %ymm12 + vpaddd %ymm10, %ymm12, %ymm12 + vpslld $0x5, %ymm12, %ymm10 + vpsubd %ymm12, %ymm10, %ymm10 + vpslld $0xb, %ymm10, %ymm10 + vpsubd %ymm10, %ymm0, %ymm0 + vpaddd %ymm11, %ymm0, %ymm0 + vpcmpgtd %ymm5, %ymm0, %ymm0 + vpandn %ymm1, %ymm0, %ymm0 + vpslld $0x1, %ymm0, %ymm0 + vpsubd %ymm0, %ymm1, %ymm0 + vpaddd %ymm9, %ymm0, %ymm0 + vpblendvb %ymm0, %ymm6, %ymm0, %ymm0 + vpcmpgtd %ymm6, %ymm0, %ymm1 + vpandn %ymm0, %ymm1, %ymm0 + vmovdqa %ymm0, (%rdi) + addq $0x20, %rdi + addq $0x20, %rsi + addq $0x20, %rax + cmpq $0x400, %rax # imm = 0x400 + jne Lpoly_use_hint_88_avx2_asm_loop + retq + .cfi_endproc + +MLD_ASM_FN_SIZE(poly_use_hint_88_avx2_asm) + + +#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \ + && (MLD_CONFIG_MULTILEVEL_WITH_SHARED || MLD_CONFIG_PARAMETER_SET == \ + 44) */ + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/nix/s2n_bignum/default.nix b/nix/s2n_bignum/default.nix index 77a58d11d..75bcb6483 100644 --- a/nix/s2n_bignum/default.nix +++ b/nix/s2n_bignum/default.nix @@ -4,12 +4,12 @@ { stdenv, fetchFromGitHub, writeText, ... 
}: stdenv.mkDerivation rec { pname = "s2n_bignum"; - version = "2f8b8d8562ef001508d497f6b31a4bdd2add0c8e"; + version = "9061e8b76522beafa5ca020f3c8d99b23eba4fbc"; src = fetchFromGitHub { owner = "awslabs"; repo = "s2n-bignum"; rev = "${version}"; - hash = "sha256-rz6qzDMUapxOtu0lsj9uWhPnURMNcCCVC79Zs7SdrZA="; + hash = "sha256-NvtrVfiz5yxfdNvD0P1wSQrn37znuWMbNWxys4jZlU4="; }; setupHook = writeText "setup-hook.sh" '' export S2N_BIGNUM_DIR="$1" diff --git a/proofs/cbmc/poly_use_hint_native_x86_64/Makefile b/proofs/cbmc/poly_use_hint_native_x86_64/Makefile new file mode 100644 index 000000000..755d62f95 --- /dev/null +++ b/proofs/cbmc/poly_use_hint_native_x86_64/Makefile @@ -0,0 +1,49 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE = poly_use_hint_native_x86_64_harness + +# This should be a unique identifier for this proof, and will appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. +PROOF_UID = poly_use_hint_native_x86_64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications. +DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/x86_64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly_kl.c + +# poly_use_hint_88 is used with ML-DSA-44 (GAMMA2 = (Q-1)/88); +# poly_use_hint_32 with ML-DSA-65 and ML-DSA-87 (GAMMA2 = (Q-1)/32). 
+ifeq ($(MLD_CONFIG_PARAMETER_SET),44) + CHECK_FUNCTION_CONTRACTS=mld_poly_use_hint_88_native + USE_FUNCTION_CONTRACTS=mld_poly_use_hint_88_avx2_asm +else ifeq ($(MLD_CONFIG_PARAMETER_SET),65) + CHECK_FUNCTION_CONTRACTS=mld_poly_use_hint_32_native + USE_FUNCTION_CONTRACTS=mld_poly_use_hint_32_avx2_asm +else ifeq ($(MLD_CONFIG_PARAMETER_SET),87) + CHECK_FUNCTION_CONTRACTS=mld_poly_use_hint_32_native + USE_FUNCTION_CONTRACTS=mld_poly_use_hint_32_avx2_asm +endif +USE_FUNCTION_CONTRACTS+=mld_sys_check_capability +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = poly_use_hint_native_x86_64 + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +include ../Makefile.common diff --git a/proofs/cbmc/poly_use_hint_native_x86_64/poly_use_hint_native_x86_64_harness.c b/proofs/cbmc/poly_use_hint_native_x86_64/poly_use_hint_native_x86_64_harness.c new file mode 100644 index 000000000..bae3059b6 --- /dev/null +++ b/proofs/cbmc/poly_use_hint_native_x86_64/poly_use_hint_native_x86_64_harness.c @@ -0,0 +1,24 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include <stdint.h> +#include "cbmc.h" +#include "params.h" + +#if MLDSA_GAMMA2 == ((MLDSA_Q - 1) / 88) +int mld_poly_use_hint_88_native(int32_t *a, const int32_t *h); +#else +int mld_poly_use_hint_32_native(int32_t *a, const int32_t *h); +#endif + +void harness(void) +{ + int32_t *a, *h; + int t; + +#if MLDSA_GAMMA2 == ((MLDSA_Q - 1) / 88) + t = mld_poly_use_hint_88_native(a, h); +#else + t = mld_poly_use_hint_32_native(a, h); +#endif +} diff --git a/proofs/hol_light/README.md b/proofs/hol_light/README.md index 606c4e657..25a789f78 100644 --- a/proofs/hol_light/README.md +++ b/proofs/hol_light/README.md @@ -170,6 +170,8 @@ All routines listed below have been proven correct, memory-safe, and secret-inde * x86_64 poly_chknorm: 
[poly_chknorm_avx2_asm.S](x86_64/mldsa/poly_chknorm_avx2_asm.S) * x86_64 polyz_unpack (l=4): [polyz_unpack_17_avx2_asm.S](x86_64/mldsa/polyz_unpack_17_avx2_asm.S) * x86_64 polyz_unpack (l=5,7): [polyz_unpack_19_avx2_asm.S](x86_64/mldsa/polyz_unpack_19_avx2_asm.S) + * x86_64 poly_use_hint (l=5,7): [poly_use_hint_32_avx2_asm.S](x86_64/mldsa/poly_use_hint_32_avx2_asm.S) + * x86_64 poly_use_hint (l=4): [poly_use_hint_88_avx2_asm.S](x86_64/mldsa/poly_use_hint_88_avx2_asm.S) - FIPS202: * 4-fold Keccak-F1600 using AVX2: [keccak_f1600_x4_avx2_asm.S](x86_64/mldsa/keccak_f1600_x4_avx2_asm.S) diff --git a/proofs/hol_light/x86_64/Makefile b/proofs/hol_light/x86_64/Makefile index 693078496..238859b56 100644 --- a/proofs/hol_light/x86_64/Makefile +++ b/proofs/hol_light/x86_64/Makefile @@ -57,6 +57,8 @@ OBJ = mldsa/ntt_avx2_asm.o \ mldsa/poly_chknorm_avx2_asm.o \ mldsa/polyz_unpack_17_avx2_asm.o \ mldsa/polyz_unpack_19_avx2_asm.o \ + mldsa/poly_use_hint_32_avx2_asm.o \ + mldsa/poly_use_hint_88_avx2_asm.o \ mldsa/pointwise_avx2_asm.o \ mldsa/pointwise_acc_l4_avx2_asm.o \ mldsa/pointwise_acc_l5_avx2_asm.o \ diff --git a/proofs/hol_light/x86_64/mldsa/poly_use_hint_32_avx2_asm.S b/proofs/hol_light/x86_64/mldsa/poly_use_hint_32_avx2_asm.S new file mode 100644 index 000000000..f06e757b0 --- /dev/null +++ b/proofs/hol_light/x86_64/mldsa/poly_use_hint_32_avx2_asm.S @@ -0,0 +1,102 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. 
+ */ + + +/************************************************* + * Name: mld_poly_use_hint_32_avx2_asm + * + * Description: Use hint polynomial to correct the high bits of a polynomial. + * Variant for parameter sets ML-DSA-65 and ML-DSA-87 + * (GAMMA2 = (Q-1)/32). Operates in place. + * + * Arguments: - int32_t *a: pointer to input/output polynomial; the + * corrected high bits are written back here + * - const int32_t *h: pointer to input hint polynomial + **************************************************/ + + + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/poly_use_hint_32_avx2_asm.S using scripts/simpasm. Do not modify it directly. + */ + +.text +.balign 4 +#ifdef __APPLE__ +.global _mld_poly_use_hint_32_avx2_asm +_mld_poly_use_hint_32_avx2_asm: +#else +.global mld_poly_use_hint_32_avx2_asm +mld_poly_use_hint_32_avx2_asm: +#endif + + .cfi_startproc + endbr64 + movl $0x7f, %ecx + movl $0x401, %r8d # imm = 0x401 + vmovd %r8d, %xmm8 + vpbroadcastd %xmm8, %ymm8 + xorl %eax, %eax + vpxor %xmm6, %xmm6, %xmm6 + vmovd %ecx, %xmm5 + movl $0x7be100, %ecx # imm = 0x7BE100 + movl $0x200, %r9d # imm = 0x200 + vmovd %r9d, %xmm7 + vpbroadcastd %xmm7, %ymm7 + vmovd %ecx, %xmm4 + movl $0xf, %ecx + vpbroadcastd %xmm5, %ymm5 + vmovd %ecx, %xmm3 + vpbroadcastd %xmm4, %ymm4 + vpbroadcastd %xmm3, %ymm3 + +Lpoly_use_hint_32_avx2_asm_loop: + vmovdqa (%rdi), %ymm0 + vmovdqa (%rsi), %ymm2 + vpaddd %ymm0, %ymm5, %ymm1 + vpsrld $0x7, %ymm1, %ymm1 + vpmulhuw %ymm8, %ymm1, %ymm1 + vpmulhrsw %ymm7, %ymm1, %ymm1 + vpcmpgtd %ymm4, %ymm0, %ymm11 + vpandn %ymm1, %ymm11, %ymm9 + vpslld $0xa, %ymm1, %ymm10 + vpsubd %ymm1, %ymm10, %ymm1 + vpslld $0x9, %ymm1, %ymm1 + vpsubd %ymm1, %ymm0, %ymm0 + vpaddd %ymm11, %ymm0, %ymm0 + vpcmpgtd %ymm6, %ymm0, %ymm0 + vpandn %ymm2, %ymm0, %ymm0 + vpslld $0x1, %ymm0, %ymm0 + vpsubd %ymm0, %ymm2, %ymm2 + vpaddd %ymm9, %ymm2, %ymm2 + vpand %ymm3, %ymm2, %ymm2 + vmovdqa %ymm2, (%rdi) + addq $0x20, %rdi + addq $0x20, 
%rsi + addq $0x20, %rax + cmpq $0x400, %rax # imm = 0x400 + jne Lpoly_use_hint_32_avx2_asm_loop + retq + .cfi_endproc + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/proofs/hol_light/x86_64/mldsa/poly_use_hint_88_avx2_asm.S b/proofs/hol_light/x86_64/mldsa/poly_use_hint_88_avx2_asm.S new file mode 100644 index 000000000..9ca78af82 --- /dev/null +++ b/proofs/hol_light/x86_64/mldsa/poly_use_hint_88_avx2_asm.S @@ -0,0 +1,105 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + +/* References + * ========== + * + * - [REF_AVX2] + * CRYSTALS-Dilithium optimized AVX2 implementation + * Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé + * https://github.com/pq-crystals/dilithium/tree/master/avx2 + */ + +/* + * This file is derived from the public domain + * AVX2 Dilithium implementation @[REF_AVX2]. + */ + + +/************************************************* + * Name: mld_poly_use_hint_88_avx2_asm + * + * Description: Use hint polynomial to correct the high bits of a polynomial. + * Variant for parameter set ML-DSA-44 (GAMMA2 = (Q-1)/88). + * Operates in place. + * + * Arguments: - int32_t *a: pointer to input/output polynomial; the + * corrected high bits are written back here + * - const int32_t *h: pointer to input hint polynomial + **************************************************/ + + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/x86_64/src/poly_use_hint_88_avx2_asm.S using scripts/simpasm. Do not modify it directly. 
+ */ + +.text +.balign 4 +#ifdef __APPLE__ +.global _mld_poly_use_hint_88_avx2_asm +_mld_poly_use_hint_88_avx2_asm: +#else +.global mld_poly_use_hint_88_avx2_asm +mld_poly_use_hint_88_avx2_asm: +#endif + + .cfi_startproc + endbr64 + movl $0x7f, %ecx + xorl %eax, %eax + vpxor %xmm5, %xmm5, %xmm5 + movl $0x2c0b, %r8d # imm = 0x2C0B + vmovd %r8d, %xmm8 + vpbroadcastd %xmm8, %ymm8 + vmovd %ecx, %xmm4 + movl $0x7e6c00, %ecx # imm = 0x7E6C00 + movl $0x80, %r9d + vmovd %r9d, %xmm7 + vpbroadcastd %xmm7, %ymm7 + movl $0x2b, %r10d + vmovd %r10d, %xmm6 + vpbroadcastd %xmm6, %ymm6 + vmovd %ecx, %xmm3 + vpbroadcastd %xmm4, %ymm4 + vpbroadcastd %xmm3, %ymm3 + +Lpoly_use_hint_88_avx2_asm_loop: + vmovdqa (%rdi), %ymm0 + vmovdqa (%rsi), %ymm1 + vpaddd %ymm0, %ymm4, %ymm10 + vpsrld $0x7, %ymm10, %ymm10 + vpmulhuw %ymm8, %ymm10, %ymm10 + vpmulhrsw %ymm7, %ymm10, %ymm10 + vpcmpgtd %ymm3, %ymm0, %ymm11 + vpandn %ymm10, %ymm11, %ymm9 + vpslld $0x1, %ymm10, %ymm12 + vpaddd %ymm10, %ymm12, %ymm12 + vpslld $0x5, %ymm12, %ymm10 + vpsubd %ymm12, %ymm10, %ymm10 + vpslld $0xb, %ymm10, %ymm10 + vpsubd %ymm10, %ymm0, %ymm0 + vpaddd %ymm11, %ymm0, %ymm0 + vpcmpgtd %ymm5, %ymm0, %ymm0 + vpandn %ymm1, %ymm0, %ymm0 + vpslld $0x1, %ymm0, %ymm0 + vpsubd %ymm0, %ymm1, %ymm0 + vpaddd %ymm9, %ymm0, %ymm0 + vpblendvb %ymm0, %ymm6, %ymm0, %ymm0 + vpcmpgtd %ymm6, %ymm0, %ymm1 + vpandn %ymm0, %ymm1, %ymm0 + vmovdqa %ymm0, (%rdi) + addq $0x20, %rdi + addq $0x20, %rsi + addq $0x20, %rax + cmpq $0x400, %rax # imm = 0x400 + jne Lpoly_use_hint_88_avx2_asm_loop + retq + .cfi_endproc + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/proofs/hol_light/x86_64/proofs/dump_bytecode.ml b/proofs/hol_light/x86_64/proofs/dump_bytecode.ml index d5b7ad12f..504e26cc1 100644 --- a/proofs/hol_light/x86_64/proofs/dump_bytecode.ml +++ b/proofs/hol_light/x86_64/proofs/dump_bytecode.ml @@ -52,3 +52,11 @@ print_string "==== bytecode end =====================================\n\n";; print_string "=== 
bytecode start: x86_64/mldsa/polyz_unpack_19_avx2_asm.o ================\n";; print_literal_from_elf "x86_64/mldsa/polyz_unpack_19_avx2_asm.o";; print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: x86_64/mldsa/poly_use_hint_32_avx2_asm.o ================\n";; +print_literal_from_elf "x86_64/mldsa/poly_use_hint_32_avx2_asm.o";; +print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: x86_64/mldsa/poly_use_hint_88_avx2_asm.o ================\n";; +print_literal_from_elf "x86_64/mldsa/poly_use_hint_88_avx2_asm.o";; +print_string "==== bytecode end =====================================\n\n";; diff --git a/proofs/hol_light/x86_64/proofs/mldsa_utils.ml b/proofs/hol_light/x86_64/proofs/mldsa_utils.ml index 88b6bf6ac..8b53ef860 100644 --- a/proofs/hol_light/x86_64/proofs/mldsa_utils.ml +++ b/proofs/hol_light/x86_64/proofs/mldsa_utils.ml @@ -163,3 +163,236 @@ let MAP_SUB_LIST = prove MATCH_MP_TAC num_INDUCTION THEN ASM_REWRITE_TAC[SUB_LIST_CLAUSES; MAP] THEN REPEAT STRIP_TAC THEN SPEC_TAC(`q:num`,`q:num`) THEN MATCH_MP_TAC num_INDUCTION THEN ASM_REWRITE_TAC[SUB_LIST_CLAUSES; MAP]);; + +(* ------------------------------------------------------------------------- *) +(* Shared 256-bit-block / 8-lane framework for the in-place poly routines *) +(* (poly_use_hint_32/88 etc.) that loop over 32 blocks of eight int32 lanes. *) +(* These are arch-independent of the per-coefficient model. *) +(* ------------------------------------------------------------------------- *) + +(* Eight consecutive int32 coefficients packed into one 256-bit word. 
*) +let pack8 = new_definition + `pack8 (f:num->int32) (b:num) : int256 = + word_join + (word_join (word_join (f (8*b+7)) (f (8*b+6)):int64) + (word_join (f (8*b+5)) (f (8*b+4)):int64):int128) + (word_join (word_join (f (8*b+3)) (f (8*b+2)):int64) + (word_join (f (8*b+1)) (f (8*b+0)):int64):int128)`;; + +(* Lane k (k<8) of a packed block is coefficient 8b+k. *) +let PACK8_LANE = prove( + `!f b. !k. k < 8 ==> word_subword (pack8 f b) (32*k,32):int32 = f(8*b+k)`, + GEN_TAC THEN GEN_TAC THEN + CONV_TAC EXPAND_CASES_CONV THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[pack8] THEN REPEAT CONJ_TAC THEN CONV_TAC WORD_BLAST);; + +(* Lane k of a SIMD8 map is the scalar map applied to the corresponding lanes. *) +let SIMD8_LANE = prove( + `!(g:int32->int32->int32) av hv. !k. k < 8 ==> + word_subword (simd8 g av hv) (32*k,32):int32 = + g (word_subword av (32*k,32)) (word_subword hv (32*k,32))`, + GEN_TAC THEN GEN_TAC THEN GEN_TAC THEN + REWRITE_TAC[simd8;simd4;simd2;DIMINDEX_32] THEN + CONV_TAC EXPAND_CASES_CONV THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT CONJ_TAC THEN CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN REFL_TAC);; + +(* Coefficient address 4*(8b+k) split into block base 32*b plus lane offset. *) +let ADDR_SPLIT = prove( + `!p:int64 b k. word_add p (word(4*(8*b+k))) = + word_add (word_add p (word(32*b))) (word(4*k))`, + REPEAT GEN_TAC THEN REWRITE_TAC[ARITH_RULE `4*(8*b+k) = 32*b+4*k`] THEN + CONV_TAC WORD_RULE);; + +(* A coefficient (bytes32) read is the matching lane of the block (bytes256) read. *) +let BLOCK_SPLIT = prove( + `!p:int64 s:x86state b. !k. 
k < 8 ==> + read (memory :> bytes32 (word_add p (word(4*(8*b+k))))) s = + word_subword (read (memory :> bytes256 (word_add p (word(32*b)))) s) (32*k,32):int32`, + GEN_TAC THEN GEN_TAC THEN GEN_TAC THEN + CONV_TAC(RAND_CONV(ONCE_DEPTH_CONV(READ_MEMORY_MERGE_CONV 3))) THEN + GEN_REWRITE_TAC (BINDER_CONV o RAND_CONV o LAND_CONV o ONCE_DEPTH_CONV) [ADDR_SPLIT] THEN + CONV_TAC EXPAND_CASES_CONV THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[WORD_RULE `word_add q (word 0) = q`] THEN + REPEAT CONJ_TAC THEN CONV_TAC WORD_BLAST);; + +(* The block (bytes256) read assembles from its eight coefficient reads. *) +let PACK8_MERGE = prove( + `!(x:num->int32) p:int64 s:x86state b. + b < 32 /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add p (word(4*i)))) s = x i) + ==> read (memory :> bytes256 (word_add p (word(32*b)))) s = pack8 x b`, + REPEAT GEN_TAC THEN STRIP_TAC THEN REWRITE_TAC[pack8] THEN + CONV_TAC(LAND_CONV(READ_MEMORY_MERGE_CONV 3)) THEN + SUBGOAL_THEN + `!k. k < 8 ==> read (memory :> bytes32 (word_add (word_add p (word(32*b))) (word(4*k)))) (s:x86state) = x(8*b+k)` + (fun th -> + MP_TAC(SPEC `0` th) THEN MP_TAC(SPEC `1` th) THEN MP_TAC(SPEC `2` th) THEN MP_TAC(SPEC `3` th) THEN + MP_TAC(SPEC `4` th) THEN MP_TAC(SPEC `5` th) THEN MP_TAC(SPEC `6` th) THEN MP_TAC(SPEC `7` th)) THENL + [GEN_TAC THEN DISCH_TAC THEN REWRITE_TAC[GSYM ADDR_SPLIT] THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `8*b+k` th)) THEN + ANTS_TAC THENL [UNDISCH_TAC `b:num<32` THEN UNDISCH_TAC `k:num<8` THEN ARITH_TAC; SIMP_TAC[]]; + CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[WORD_RULE `word_add q (word 0) = q`] THEN + REPEAT(DISCH_THEN SUBST1_TAC) THEN REFL_TAC]);; + +(* Two int256 words agree if all eight 32-bit lanes agree. *) +let LANES8_EQ = prove + (`!x y:int256. (!k.
k < 8 ==> word_subword x (32*k,32):int32 = word_subword y (32*k,32)) ==> x = y`, + REPEAT GEN_TAC THEN + DISCH_THEN(fun th -> MP_TAC(CONV_RULE(EXPAND_CASES_CONV THENC NUM_REDUCE_CONV) th)) THEN + CONV_TAC WORD_BLAST);; + +(* The byte offset of any 32-byte block, word(32*i), is 32-aligned as an int64. *) +let ALIGNED_32I = prove + (`!i. aligned 32 (word(32*i):int64)`, + GEN_TAC THEN REWRITE_TAC[aligned; DIMINDEX_64; VAL_WORD; DIMINDEX_64] THEN + CONJ_TAC THENL + [REWRITE_TAC[DIVIDES_MOD] THEN CONV_TAC NUM_REDUCE_CONV; + MP_TAC(SPECL [`32`; `32 * i`; `2 EXP 64`] DIVIDES_MOD2) THEN + ANTS_TAC THENL + [REWRITE_TAC[DIVIDES_MOD] THEN CONV_TAC NUM_REDUCE_CONV; ALL_TAC] THEN + DISCH_THEN(SUBST1_TAC o SYM) THEN NUMBER_TAC]);; +(* Adding a block offset word(32*i) to a 32-aligned address keeps it 32-aligned (via ALIGNED_WORD_ADD and ALIGNED_32I). *) +let ALIGNED_BLOCK = prove + (`!a:int64 i. aligned 32 a ==> aligned 32 (word_add a (word(32*i)))`, + REPEAT STRIP_TAC THEN MATCH_MP_TAC ALIGNED_WORD_ADD THEN + ASM_REWRITE_TAC[ALIGNED_32I]);; + +(* word_join of a zero high half is just zero-extension of the low half. *) +let JOIN_ZERO_ZX = prove + (`!lo:(16)word. word_join (word 0:(16)word) lo :int32 = word_zx lo`, + GEN_TAC THEN CONV_TAC WORD_BLAST);; + +(* word_ile/word_igt against 0 are complementary on int32. *) +let ILE_IGT = BITBLAST_RULE + `!a0:int32. word_ile a0 (word 0) <=> ~(word_igt a0 (word 0))`;; +(* ANDing with the all-ones mask word_not (word 0) is the identity, at any word size. *) +let WORD_NOT_0 = WORD_RULE `!x:N word. word_and x (word_not (word 0)) = x`;; + +(* Per-step state compaction during SIMD body simulation: abbreviate every large + int256 register value to a fresh atom so it propagates compactly (essential + for instructions like VPBLENDVB whose byte-mux otherwise duplicates the value).
*) +let ABBREV_BIG_TAC : tactic = fun (asl,w) -> + MAP_EVERY (fun (_,th) -> AUTO_ABBREV_TAC (rand(concl th))) + (filter (fun (_,th) -> let c=concl th in is_eq c && + (try is_comb(lhs c) && fst(dest_const(fst(strip_comb(lhs c))))="read" + && type_of(lhs c)=`:int256` && not(is_var(rand c)) with _->false) + && String.length(string_of_term(rand c)) > 1500) asl) (asl,w);; + +(* ------------------------------------------------------------------------- *) +(* Shared scalar UseHint lemmas (poly_use_hint_32/88). Arch- and *) +(* parameter-independent: the per-coefficient Barrett rounding, lane *) +(* value/sign-extension facts and the +/-1 delta-encoding bridge. *) +(* ------------------------------------------------------------------------- *) + +(* Rounding division: ((q DIV n) + 1) DIV 2 = (q + n) DIV (2 * n). *) +let ROUND_DIV = prove(`!q n. ~(n = 0) ==> (q DIV n + 1) DIV 2 = (q + n) DIV (2 * n)`, + REPEAT STRIP_TAC THEN + SUBGOAL_THEN `(q + n) DIV (2 * n) = (q + n) DIV n DIV 2` SUBST1_TAC THENL + [REWRITE_TAC[DIV_DIV] THEN AP_TERM_TAC THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(q + n) DIV n = q DIV n + 1` (fun th -> REWRITE_TAC[th]) THEN + ASM_SIMP_TAC[DIV_ADD; DIVIDES_REFL] THEN ASM_SIMP_TAC[DIV_REFL]);; + +(* The pre-shift t = (a + 127) >>u 7 has value (val a + 127) DIV 128 (no overflow + since val a < Q < 2^31). This is the f1' input to the Barrett step. *) +let VAL_T = prove(`!x:int32.
val x < 8380417 + ==> val(word_ushr (word_add (word 127) x) 7 :int32) = (val x + 127) DIV 128`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[VAL_WORD_USHR] THEN + SUBGOAL_THEN `val(word_add (word 127:int32) x) = val x + 127` SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `(127 + val(x:int32)) MOD 4294967296 = 127 + val x` + (fun th -> REWRITE_TAC[th] THEN ARITH_TAC) THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `x:int32` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN ASM_ARITH_TAC; + REWRITE_TAC[ARITH_RULE `2 EXP 7 = 128`]]);; + +(* Bounded 16->32 sign-extension equals zero-extension on value: for a 16-bit + lane below 2^15 (top bit clear) word_sx agrees with the numeric value. Used to + evaluate the signed 16x16 multiply in the vpmulhrsw lane. *) +let VAL_WORD_SX_SMALL = prove(`!u:16 word. val u < 32768 + ==> val((word_sx u):int32) = val u`, + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `(word_sx (u:16 word)):int32 = word_zx u` SUBST1_TAC THENL + [REWRITE_TAC[WORD_SX_ZX_GEN; DIMINDEX_16] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + SUBGOAL_THEN `bit 15 (u:16 word) = F` SUBST1_TAC THENL + [REWRITE_TAC[BIT_VAL] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[BITVAL_CLAUSES; + WORD_REDUCE_CONV `word_shl (word_neg (word 0:int32)) 16`; + WORD_OR_0]; ALL_TAC] THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(u:16 word) MOD 4294967296 = val u` (fun th->REWRITE_TAC[th]) THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `u:16 word` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_16] THEN ARITH_TAC);; + +(* delta encoding bridge (needs the hint bound val h <= 1): + the assembly computes delta*h as h - (andnot(dlt,h))<<1 with + dlt = (a0 >s 0); the model uses word_mul of the +1/-1 delta. *) +let DELTA_EQ_BOUNDED = prove + (`!a0:int32 h:int32.
val h <= 1 ==> + word_sub h (word_shl (word_and (word_not + (if word_igt a0 (word 0) then word 4294967295 else word 0)) h) 1) = + word_mul (word_or (word_neg (word (bitval (word_ile a0 (word 0))))) (word 1)) h`, + REPEAT GEN_TAC THEN DISCH_TAC THEN REWRITE_TAC[ILE_IGT] THEN + SUBGOAL_THEN `h:int32 = word 0 \/ h = word 1` STRIP_ASSUME_TAC THENL + [POP_ASSUM MP_TAC THEN SPEC_TAC(`h:int32`,`h:int32`) THEN + REWRITE_TAC[GSYM VAL_EQ_0; GSYM VAL_EQ_1] THEN ARITH_TAC; + ASM_REWRITE_TAC[] THEN COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC WORD_BLAST; + ASM_REWRITE_TAC[] THEN COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC WORD_BLAST]);; + +(* ------------------------------------------------------------------------- *) +(* Barrett-quotient DIV/MOD tactics over the per-variant divisor 2*GAMMA2 *) +(* (gg below): 523776 for poly_use_hint_32, 190464 for poly_use_hint_88. *) +(* Each proof aliases these at its concrete gg. *) +(* ------------------------------------------------------------------------- *) + +(* Eliminate `r MOD gg` / `r DIV gg` from the assumptions and abstract them, + leaving an arithmetic goal solvable by ASM_ARITH_TAC. *) +let LINEARIZE_DIV_MOD_BY_TAC gg = + let s = subst [mk_small_numeral gg, `gg:num`] in + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in (s `r MOD gg`) (concl th) || free_in (s `r DIV gg`) (concl th)))) THEN + MP_TAC(SPECL [`r:num`; mk_small_numeral gg] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(s `r MOD gg`, `m:num`) THEN + SPEC_TAC(s `r DIV gg`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +(* Replace `(r - r MOD gg) DIV gg` with `r DIV gg`.
*) +let DIV_MOD_TO_DIV_BY_TAC gg = + let s = subst [mk_small_numeral gg, `gg:num`] in + SUBGOAL_THEN (s `(r - r MOD gg) DIV gg = r DIV gg`) SUBST1_TAC THENL + [MP_TAC(SPECL [`r:num`; mk_small_numeral gg] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN (s `r - r MOD gg = gg * r DIV gg`) SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [mk_small_numeral gg; s `r DIV gg`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV; ALL_TAC];; + +(* Prove `r DIV gg = k` via DIV_SANDWICH + LE_MULT_RCANCEL. *) +let DIV_EQ_K_BY_TAC gg k = + let s = subst [mk_small_numeral gg, `gg:num`] in + let k_num = mk_small_numeral k and km1 = mk_small_numeral (k-1) + and kp1 = mk_small_numeral (k+1) + and q = mk_var("q",`:num`) and le = `(<=):num->num->bool` + and lt = `(<):num->num->bool` and c = mk_small_numeral gg in + let mk_mul a b = mk_binop (rator(rator `0*0`)) a b in + MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in (s `r MOD gg`) (concl th) || free_in (s `r DIV gg`) (concl th)))) THEN + MP_TAC(SPECL [`r:num`; c] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(s `r MOD gg`, `m:num`) THEN + SPEC_TAC(s `r DIV gg`, q) THEN + REPEAT GEN_TAC THEN STRIP_TAC THEN + ASM_CASES_TAC(mk_comb(mk_comb(le, q), km1)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul q c), mk_mul km1 c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul k_num c), mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC(mk_comb(mk_comb(lt, k_num), q)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul kp1 c), mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC]];; diff --git
a/proofs/hol_light/x86_64/proofs/poly_use_hint_32_avx2_asm.ml b/proofs/hol_light/x86_64/proofs/poly_use_hint_32_avx2_asm.ml new file mode 100644 index 000000000..eba63707c --- /dev/null +++ b/proofs/hol_light/x86_64/proofs/poly_use_hint_32_avx2_asm.ml @@ -0,0 +1,1326 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Use hint to correct high bits of decomposition (ML-DSA, param 65/87). *) +(* x86_64 AVX2 variant (GAMMA2 = (Q-1)/32). *) +(* ========================================================================= *) + +needs "s2n_bignum/x86/proofs/base.ml";; +needs "mldsa_native/common/mldsa_specs.ml";; +needs "mldsa_native/x86_64/proofs/mldsa_utils.ml";; + +(**** print_literal_from_elf "x86_64/mldsa/poly_use_hint_32_avx2_asm.o";; + ****) + +let poly_use_hint_32_avx2_asm_mc = define_assert_from_elf + "poly_use_hint_32_avx2_asm_mc" "x86_64/mldsa/poly_use_hint_32_avx2_asm.o" +(*** BYTECODE START ***) +[ + 0xf3; 0x0f; 0x1e; 0xfa; (* ENDBR64 *) + 0xb9; 0x7f; 0x00; 0x00; 0x00; + (* MOV (% ecx) (Imm32 (word 127)) *) + 0x41; 0xb8; 0x01; 0x04; 0x00; 0x00; + (* MOV (% r8d) (Imm32 (word 1025)) *) + 0xc4; 0x41; 0x79; 0x6e; 0xc0; + (* VMOVD (%_% xmm8) (% r8d) *) + 0xc4; 0x42; 0x7d; 0x58; 0xc0; + (* VPBROADCASTD (%_% ymm8) (%_% xmm8) *) + 0x31; 0xc0; (* XOR (% eax) (% eax) *) + 0xc5; 0xc9; 0xef; 0xf6; (* VPXOR (%_% xmm6) (%_% xmm6) (%_% xmm6) *) + 0xc5; 0xf9; 0x6e; 0xe9; (* VMOVD (%_% xmm5) (% ecx) *) + 0xb9; 0x00; 0xe1; 0x7b; 0x00; + (* MOV (% ecx) (Imm32 (word 8118528)) *) + 0x41; 0xb9; 0x00; 0x02; 0x00; 0x00; + (* MOV (% r9d) (Imm32 (word 512)) *) + 0xc4; 0xc1; 0x79; 0x6e; 0xf9; + (* VMOVD (%_% xmm7) (% r9d) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xff; + (* VPBROADCASTD (%_% ymm7) (%_% xmm7) *) + 0xc5; 0xf9; 0x6e; 0xe1; (* VMOVD (%_% xmm4) (% ecx) *) + 0xb9; 0x0f;
0x00; 0x00; 0x00; + (* MOV (% ecx) (Imm32 (word 15)) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xed; + (* VPBROADCASTD (%_% ymm5) (%_% xmm5) *) + 0xc5; 0xf9; 0x6e; 0xd9; (* VMOVD (%_% xmm3) (% ecx) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xe4; + (* VPBROADCASTD (%_% ymm4) (%_% xmm4) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xdb; + (* VPBROADCASTD (%_% ymm3) (%_% xmm3) *) + 0xc5; 0xfd; 0x6f; 0x07; (* VMOVDQA (%_% ymm0) (Memop Word256 (%% (rdi,0))) *) + 0xc5; 0xfd; 0x6f; 0x16; (* VMOVDQA (%_% ymm2) (Memop Word256 (%% (rsi,0))) *) + 0xc5; 0xd5; 0xfe; 0xc8; (* VPADDD (%_% ymm1) (%_% ymm5) (%_% ymm0) *) + 0xc5; 0xf5; 0x72; 0xd1; 0x07; + (* VPSRLD (%_% ymm1) (%_% ymm1) (Imm8 (word 7)) *) + 0xc4; 0xc1; 0x75; 0xe4; 0xc8; + (* VPMULHUW (%_% ymm1) (%_% ymm1) (%_% ymm8) *) + 0xc4; 0xe2; 0x75; 0x0b; 0xcf; + (* VPMULHRSW (%_% ymm1) (%_% ymm1) (%_% ymm7) *) + 0xc5; 0x7d; 0x66; 0xdc; (* VPCMPGTD (%_% ymm11) (%_% ymm0) (%_% ymm4) *) + 0xc5; 0x25; 0xdf; 0xc9; (* VPANDN (%_% ymm9) (%_% ymm11) (%_% ymm1) *) + 0xc5; 0xad; 0x72; 0xf1; 0x0a; + (* VPSLLD (%_% ymm10) (%_% ymm1) (Imm8 (word 10)) *) + 0xc5; 0xad; 0xfa; 0xc9; (* VPSUBD (%_% ymm1) (%_% ymm10) (%_% ymm1) *) + 0xc5; 0xf5; 0x72; 0xf1; 0x09; + (* VPSLLD (%_% ymm1) (%_% ymm1) (Imm8 (word 9)) *) + 0xc5; 0xfd; 0xfa; 0xc1; (* VPSUBD (%_% ymm0) (%_% ymm0) (%_% ymm1) *) + 0xc4; 0xc1; 0x7d; 0xfe; 0xc3; + (* VPADDD (%_% ymm0) (%_% ymm0) (%_% ymm11) *) + 0xc5; 0xfd; 0x66; 0xc6; (* VPCMPGTD (%_% ymm0) (%_% ymm0) (%_% ymm6) *) + 0xc5; 0xfd; 0xdf; 0xc2; (* VPANDN (%_% ymm0) (%_% ymm0) (%_% ymm2) *) + 0xc5; 0xfd; 0x72; 0xf0; 0x01; + (* VPSLLD (%_% ymm0) (%_% ymm0) (Imm8 (word 1)) *) + 0xc5; 0xed; 0xfa; 0xd0; (* VPSUBD (%_% ymm2) (%_% ymm2) (%_% ymm0) *) + 0xc4; 0xc1; 0x6d; 0xfe; 0xd1; + (* VPADDD (%_% ymm2) (%_% ymm2) (%_% ymm9) *) + 0xc5; 0xed; 0xdb; 0xd3; (* VPAND (%_% ymm2) (%_% ymm2) (%_% ymm3) *) + 0xc5; 0xfd; 0x7f; 0x17; (* VMOVDQA (Memop Word256 (%% (rdi,0))) (%_% ymm2) *) + 0x48; 0x83; 0xc7; 0x20; (* ADD (% rdi) (Imm8 (word 32)) *) + 0x48; 0x83; 0xc6; 0x20; (* ADD (%
rsi) (Imm8 (word 32)) *) + 0x48; 0x83; 0xc0; 0x20; (* ADD (% rax) (Imm8 (word 32)) *) + 0x48; 0x3d; 0x00; 0x04; 0x00; 0x00; + (* CMP (% rax) (Imm32 (word 1024)) *) + 0x75; 0x94; (* JNE (Imm8 (word 148)) *) + 0xc3 (* RET *) +];; +(*** BYTECODE END ***) +(* Second code constant, derived from the _mc constant above by define_trimmed. *) +let poly_use_hint_32_avx2_asm_tmc = + define_trimmed "poly_use_hint_32_avx2_asm_tmc" poly_use_hint_32_avx2_asm_mc;; +(* Result of X86_MK_CORE_EXEC_RULE applied to the trimmed code constant. *) +let POLY_USE_HINT_32_AVX2_ASM_EXEC = + X86_MK_CORE_EXEC_RULE poly_use_hint_32_avx2_asm_tmc;; + +(* ------------------------------------------------------------------------- *) +(* Numeric (code-aligned) form of UseHint, matching the Barrett computation *) +(* performed by the assembly. Bridged to the FIPS 204 spec *) +(* (mldsa_use_hint_32 in mldsa_specs.ml) by MLDSA_USE_HINT_32_EQUIV below. *) +(* The decompose helper lemmas (A1_BOUND, A1_WRAP, BARRETT_INTERVAL_32) and *) +(* the DIV/MOD equivalence tactics are arch-independent and shared with the *) +(* AArch64 proof of the same routine. *) +(* ------------------------------------------------------------------------- *) + +let mldsa_use_hint_32_code = new_definition + `mldsa_use_hint_32_code (a:num) (h:num) = + let a1 = ((((a + 127) DIV 128) * 1025 + 2097152) DIV 4194304) MOD 16 in + let a0:int = &a - &a1 * &523776 in + let a0' = if a0 > &4190208 then a0 - &8380417 else a0 in + if h = 0 then a1 + else if a0' > &0 then (a1 + 1) MOD 16 + else (a1 + 15) MOD 16`;; +(* For a < 8380417 the Barrett quotient, before the MOD 16 reduction, is at most 16. *) +let A1_BOUND = prove( + `!a.
a < 8380417 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 <= 16`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8380416 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `69205952` (SPEC `d * 1025 + 2097152` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 1025 <= 65472 * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; +(* In the wrap zone 8118528 < a < 8380417 that quotient is exactly 16, so it reduces to 0 under MOD 16. *) +let A1_WRAP = prove( + `!a. 8118528 < a /\ a < 8380417 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 = 16`, + GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `16 <= ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304` + ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `a + 127` (SPEC `8118529 + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `d * 1025 + 2097152` (SPEC `67108977` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `63427 * 1025 <= d * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + MP_TAC(SPEC `a:num` A1_BOUND) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_ARITH_TAC);; +(* Interval lemma: for lo <= a <= hi, endpoint bounds at lo and hi that pin both numerators between k and k+1 multiples of the divisor force both quotients to equal k. *) +let BARRETT_INTERVAL_32 = prove( + `!a lo hi k.
+ lo <= a /\ a <= hi /\ + k * 262144 <= (2 * lo * 1074791425) DIV 4294967296 + 131072 /\ + (2 * hi * 1074791425) DIV 4294967296 + 131072 < (k + 1) * 262144 /\ + k * 4194304 <= (lo + 127) DIV 128 * 1025 + 2097152 /\ + (hi + 127) DIV 128 * 1025 + 2097152 < (k + 1) * 4194304 + ==> ((2 * a * 1074791425) DIV 4294967296 + 131072) DIV 262144 = k /\ + ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + CONJ_TAC THEN MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THENL + [CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(2 * lo * 1074791425) DIV 4294967296 + 131072` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 131072 <= y + 131072 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(2 * hi * 1074791425) DIV 4294967296 + 131072` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 131072 <= y + 131072 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]; + CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(lo + 127) DIV 128 * 1025 + 2097152` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 2097152 <= y + 2097152 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(hi + 127) DIV 128 * 1025 + 2097152` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 2097152 <= y + 2097152 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]]);; + +(* The per-variant divisor is 2*GAMMA2 = 523776. The generic Barrett DIV/MOD + tactics are shared from mldsa_utils.ml.
*) +let LINEARIZE_DIV_MOD_TAC = LINEARIZE_DIV_MOD_BY_TAC 523776;; +let DIV_523776_TAC k = DIV_EQ_K_BY_TAC 523776 k;; +let DIV_MOD_TO_DIV_TAC = DIV_MOD_TO_DIV_BY_TAC 523776;; + +(* Lower half nowrap: dismiss wrap cond, reduce, prove r DIV 523776 = k *) +let DECOMPOSE_R1_LOWER_TAC = + SUBGOAL_THEN `~((&r:int) - &(r MOD 523776) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN LINEARIZE_DIV_MOD_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + DIV_MOD_TO_DIV_TAC THEN + CONV_TAC SYM_CONV THEN + LINEARIZE_DIV_MOD_TAC;; + +(* Upper half nowrap: dismiss wrap cond, reduce, prove r DIV 523776 + 1 = k *) +let DECOMPOSE_R1_UPPER_TAC = + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~((&r:int) - (&(r MOD 523776) - &523776) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [REWRITE_TAC[INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`] THEN + ASM_SIMP_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + LINEARIZE_DIV_MOD_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 + 523776 = 523776 * (r DIV 523776 + 1)` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776 + 1`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN SUBST1_TAC THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1
DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; +(* Nowrap dispatcher: case-split on r MOD 523776 * 2 <= 523776, then TRY the lower-half and upper-half tactics. *) +let DECOMPOSE_R1_NOWRAP_TAC = + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THEN + TRY DECOMPOSE_R1_LOWER_TAC THEN TRY DECOMPOSE_R1_UPPER_TAC;; +(* For r < 8380417 the code-form a1 (Barrett quotient MOD 16) equals decompose_32_r1 r. *) +let DECOMPOSE_32_R1_EQUIV = time prove( + `!r. r < 8380417 ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = + decompose_32_r1 r`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `r <= 8118528` THENL + [ALL_TAC; + (* Wrap zone *) + SUBGOAL_THEN `8118528 < r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `decompose_32_r1 r = 0` SUBST1_TAC THENL + [REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THENL + [(* Lower wrap: r DIV 523776 = 16 *) + SUBGOAL_THEN `r DIV 523776 = 16` ASSUME_TAC THENL + [DIV_523776_TAC 16; ALL_TAC] THEN + SUBGOAL_THEN `16 * 523776 + r MOD 523776 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + ASM_ARITH_TAC; + (* Upper wrap: r DIV 523776 = 15 *) + SUBGOAL_THEN `r DIV 523776 = 15` ASSUME_TAC THENL + [DIV_523776_TAC 15; ALL_TAC] THEN + SUBGOAL_THEN `15 * 523776 + r MOD 523776 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_EQ]
THEN ASM_ARITH_TAC]; + ALL_TAC] THEN + MP_TAC(SPEC `r:num` A1_WRAP) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONV_TAC NUM_REDUCE_CONV] THEN + (* Nowrap zone: unfold and do interval cascade *) + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + let intervals = [ + (0, 261888); (261889, 785664); (785665, 1309440); + (1309441, 1833216); (1833217, 2356992); (2356993, 2880768); + (2880769, 3404544); (3404545, 3928320); (3928321, 4452096); + (4452097, 4975872); (4975873, 5499648); (5499649, 6023424); + (6023425, 6547200); (6547201, 7070976); (7070977, 7594752); + (7594753, 8118528)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("r",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`r:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_32 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + DECOMPOSE_R1_NOWRAP_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; +(* Lower half, no wrap: the code-form a1 equals r DIV 523776. *) +let R1_IS_DIV_LOWER = prove( + `!r.
r < 8380417 /\ r MOD 523776 * 2 <= 523776 /\ + ~((&r:int) - &(r MOD 523776) = &8380416) ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = r DIV 523776`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + MP_TAC(SPEC `r:num` LOWER_NONWRAP_R1) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; +(* Upper half, no wrap: the code-form a1 equals r DIV 523776 + 1. *) +let R1_IS_DIV_PLUS1_UPPER = prove( + `!r. r < 8380417 /\ ~(r MOD 523776 * 2 <= 523776) /\ + ~((&r:int) - (&(r MOD 523776) - &523776) = &8380416) ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = + r DIV 523776 + 1`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + MP_TAC(SPEC `r:num` UPPER_NONWRAP_R1) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +(* Upper nowrap: substitute Barrett = r DIV 523776 + 1, use INT_MOD_RESIDUE *) +let R0_SIGN_UPPER_NOWRAP_TAC = + MP_TAC(SPEC `r:num` R1_IS_DIV_PLUS1_UPPER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_MUL] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `(a:int) - (b + &1) * c = a - b * c - c`] THEN + REWRITE_TAC[INT_ARITH `x - &523776 > &0 <=> x > &523776`; + INT_ARITH `x - &523776 - &8380417 > &0 <=> x > &8904193`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Lower nowrap: substitute Barrett = r DIV 523776, use INT_MOD_RESIDUE *) +let R0_SIGN_LOWER_NOWRAP_TAC = + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Wrap: derive 8118528 < r, use DECOMPOSE_32_R1_EQUIV to get
Barrett = 0 *) +let R0_SIGN_WRAP_TAC = + SUBGOAL_THEN `8118528 < r` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> t = `&8380416:int`)) (concl th) && + not(is_neg(concl th)))) THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ; + INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`; + GSYM INT_OF_NUM_ADD] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + REWRITE_TAC[INT_ARITH `x - &1 > &0 <=> x > &1`; + INT_ARITH `(x - &523776) - &1 > &0 <=> x > &523777`; + INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; +(* For r < 8380417 the positivity test on the code-form centred remainder agrees with that on decompose_32_r0 r. *) +let DECOMPOSE_32_R0_SIGN = time prove( + `!r. r < 8380417 ==> + let a1 = (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 in + let a0':int = if (&r:int) - &a1 * &523776 > &4190208 + then &r - &a1 * &523776 - &8380417 + else &r - &a1 * &523776 in + (decompose_32_r0 r > &0 <=> a0' > &0) /\ + (decompose_32_r0 r <= &0 <=> ~(a0' > &0))`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[INT_ARITH `(x:int) <= &0 <=> ~(x > &0)`] THEN + MATCH_MP_TAC(TAUT `(p <=> q) ==> (p <=> q) /\ (~p <=> ~q)`) THEN + REWRITE_TAC[decompose_32_r0; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ONCE_REWRITE_TAC[COND_RAND] THEN REWRITE_TAC[SND; FST] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + TRY R0_SIGN_LOWER_NOWRAP_TAC THEN + TRY
R0_SIGN_UPPER_NOWRAP_TAC THEN + TRY R0_SIGN_WRAP_TAC THEN + TRY( + (* Contradiction: lower nowrap with > 4190208 *) + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - &((((r + 127) DIV 128 * 1025 + 2097152) DIV + 4194304) MOD 16) * &523776 = &(r MOD 523776)` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(&(r MOD 523776) > (&4190208:int))` MP_TAC THENL + [REWRITE_TAC[INT_NOT_LT; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[INT_OF_NUM_GT] THEN ASM_ARITH_TAC + ));; +(* For r < 8380417 and h <= 1 the spec mldsa_use_hint_32 h r equals the code form mldsa_use_hint_32_code r h (note the swapped argument order). *) +let MLDSA_USE_HINT_32_EQUIV = prove( + `!r h. r < 8380417 /\ h <= 1 + ==> mldsa_use_hint_32 h r = mldsa_use_hint_32_code r h`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[MLDSA_USE_HINT_32_UNFOLD] THEN + REWRITE_TAC[mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R0_SIGN) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN STRIP_TAC THEN + ASM_CASES_TAC `h = 0` THENL + [ASM_REWRITE_TAC[ARITH_RULE `~(0 = 1)`]; ALL_TAC] THEN + SUBGOAL_THEN `h = 1` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `decompose_32_r0 r > &0` THEN ASM_REWRITE_TAC[] THEN + ASM_MESON_TAC[]);; + +(* ------------------------------------------------------------------------- *) +(* Element-level correctness: the per-coefficient word computed by the *) +(* assembly equals the FIPS 204 UseHint of the input. *) +(* The assembly evaluates the code-aligned Barrett form directly *) +(* (vpmulhuw + vpmulhrsw), which agrees with mldsa_use_hint_32_code; the *) +(* code form equals the FIPS spec by MLDSA_USE_HINT_32_EQUIV.
*) +(* ------------------------------------------------------------------------- *) + +(* Lane lemma: high 16 bits of the unsigned 16x16->32 multiply by 1025. *) +let LANE_MULHUW = prove(`!u:16 word. val u < 65536 ==> + val(word_subword (word_mul ((word_zx u):int32) ((word_zx (word 1025:16 word)):int32)) (16,16):16 word) + = (val u * 1025) DIV 65536`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[VAL_WORD_SUBWORD; DIMINDEX_16] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(word_mul ((word_zx (u:16 word)):int32) ((word_zx (word 1025:16 word)):int32)) = val u * 1025` SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD_MUL; VAL_WORD_ZX_GEN; DIMINDEX_16; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(u:16 word) MOD 4294967296 = val u` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `u:16 word` VAL_BOUND) THEN REWRITE_TAC[DIMINDEX_16] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word 1025:16 word) = 1025` SUBST1_TAC THENL [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `u:16 word` VAL_BOUND) THEN REWRITE_TAC[DIMINDEX_16] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `val(u:16 word) * 1025 < 65536 * 65536` MP_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(65536 = 0)`] THEN ARITH_TAC);; + +(* Lane lemma: the vpmulhrsw rounding-multiply lane by 512, for a 16-bit input + below 1024 (the range of the m16 = round(t/B) intermediate). This is the + second Barrett step: a1_lane = (((m16 * 512) >> 14) + 1) >> 1. *) +let LANE_MULHRSW = prove(`!u:16 word.
val u < 1024 ==> + val(word_subword (word_add (word_ushr (word_mul ((word_sx u):int32)((word_sx (word 512:16 word)):int32)) 14) (word 1:int32)) (1,16) :16 word) + = ((val u * 512) DIV 16384 + 1) DIV 2`, + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `val((word_sx (u:16 word)):int32) = val u` ASSUME_TAC THENL + [MATCH_MP_TAC VAL_WORD_SX_SMALL THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val((word_sx (word 512:16 word)):int32) = 512` ASSUME_TAC THENL + [SUBGOAL_THEN `val(word 512:16 word) = 512` ASSUME_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + ASM_SIMP_TAC[VAL_WORD_SX_SMALL; ARITH_RULE `512 < 32768`]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul ((word_sx (u:16 word)):int32)((word_sx (word 512:16 word)):int32)) = val u * 512` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_MUL] THEN ASM_REWRITE_TAC[] THEN + MATCH_MP_TAC MOD_LT THEN REWRITE_TAC[DIMINDEX_32] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[VAL_WORD_SUBWORD; DIMINDEX_16] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(word_add (word_ushr (word_mul ((word_sx (u:16 word)):int32)((word_sx (word 512:16 word)):int32)) 14) (word 1:int32)) = (val u * 512) DIV 16384 + 1` SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_USHR; VAL_WORD; DIMINDEX_32] THEN ASM_REWRITE_TAC[] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `(val(u:16 word) * 512) DIV 16384 <= 31` MP_TAC THENL + [SUBGOAL_THEN `val(u:16 word) * 512 < 1024 * 512` MP_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(16384=0)`] THEN ARITH_TAC; ALL_TAC] THEN ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `(val(u:16 word) * 512) DIV 16384 + 1 <= 32` MP_TAC THENL + [SUBGOAL_THEN `(val(u:16 word) * 512) DIV 16384 <= 31` MP_TAC THENL + [SUBGOAL_THEN `val(u:16 word) * 512 < 1024 * 512` MP_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(16384=0)`] THEN ARITH_TAC; ALL_TAC] THEN ARITH_TAC; ALL_TAC] THEN + ARITH_TAC);; + +(* Arithmetic
bridge: the assembly's two rounding steps (vpmulhuw by 1025 giving + m16 = (t*1025) DIV 2^16, then vpmulhrsw by 512 giving ((m16*512) DIV 2^14 + 1) DIV 2) + equal the single code-form Barrett division (t*1025 + 2^21) DIV 2^22. *) +let MUL_DIV_512 = prove(`!m. (m * 512) DIV 16384 = m DIV 32`, + GEN_TAC THEN REWRITE_TAC[ARITH_RULE `16384 = 512 * 32`] THEN + REWRITE_TAC[GSYM DIV_DIV] THEN AP_THM_TAC THEN AP_TERM_TAC THEN + ONCE_REWRITE_TAC[MULT_SYM] THEN SIMP_TAC[DIV_MULT; ARITH_RULE `~(512=0)`]);; +(* Two-step rounding equals the single rounded division by 4194304, via MUL_DIV_512 and ROUND_DIV. *) +let A1_TWOSTEP = prove(`!t. ((((t * 1025) DIV 65536) * 512) DIV 16384 + 1) DIV 2 = + (t * 1025 + 2097152) DIV 4194304`, + GEN_TAC THEN REWRITE_TAC[MUL_DIV_512] THEN + REWRITE_TAC[DIV_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPECL [`t * 1025`; `2097152`] ROUND_DIV) THEN + REWRITE_TAC[ARITH_RULE `~(2097152 = 0)`; ARITH_RULE `2 * 2097152 = 4194304`]);; + +(* Composed per-lane Barrett a1: the full x86 decompose a1 lane (pre-shift t via + VAL_T, vpmulhuw by 1025 via LANE_MULHUW, vpmulhrsw by 512 via LANE_MULHRSW) on + a coefficient a:int32 with val a < Q equals the code-form a1 Barrett value + ((val a + 127) DIV 128 * 1025 + 2097152) DIV 4194304. The (16,16)/(0,16) lane + selections pick the relevant 16-bit halves; the multipliers appear as raw int32 + numerals (word 1025)/(word 512), bridged to the word_zx/word_sx lane forms of the + helper lemmas by WORD_REDUCE_CONV. *) +let A1_LANE = prove(`!a:int32.
val a < 8380417 ==> + val(word_subword + (word_add + (word_ushr + (word_mul + ((word_sx + (word_subword + (word_mul + ((word_zx + (word_subword + (word_ushr (word_add (word 127) a) 7) (0,16) + :16 word)):int32) + (word 1025:int32)) + (16,16) :16 word)):int32) + (word 512:int32)) + 14) + (word 1:int32)) + (1,16) :16 word) + = ((val a + 127) DIV 128 * 1025 + 2097152) DIV 4194304`, + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `(word 1025:int32) = word_zx (word 1025:16 word)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + SUBGOAL_THEN `(word 512:int32) = word_sx (word 512:16 word)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + ABBREV_TAC `t = (val(a:int32) + 127) DIV 128` THEN + SUBGOAL_THEN `t < 65473` ASSUME_TAC THENL + [EXPAND_TAC "t" THEN + MATCH_MP_TAC(ARITH_RULE `x <= 8380543 ==> x DIV 128 < 65473`) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + ABBREV_TAC + `u0 = word_subword + (word_ushr (word_add (word 127) (a:int32)) 7) (0,16) :16 word` THEN + SUBGOAL_THEN `val(u0:16 word) = t` ASSUME_TAC THENL + [EXPAND_TAC "u0" THEN + REWRITE_TAC[VAL_WORD_SUBWORD; DIMINDEX_16] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPEC `a:int32` VAL_T) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[ARITH_RULE `2 EXP 0 = 1`; DIV_1] THEN + MATCH_MP_TAC MOD_LT THEN EXPAND_TAC "t" THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(u0:16 word) < 65536` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `val(word_subword + (word_mul ((word_zx (u0:16 word)):int32) + ((word_zx (word 1025:16 word)):int32)) (16,16) :16 word) + = (t * 1025) DIV 65536` + ASSUME_TAC THENL + [MP_TAC(SPEC `u0:16 word` LANE_MULHUW) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ABBREV_TAC + `m16 = word_subword + (word_mul ((word_zx (u0:16 word)):int32) + ((word_zx (word 1025:16 word)):int32)) (16,16) :16 word` THEN + SUBGOAL_THEN `val(m16:16 word) = (t * 1025) DIV 65536` ASSUME_TAC
THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `val(m16:16 word) < 1024` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(65536 = 0)`] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `m16:16 word` LANE_MULHRSW) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN + MATCH_ACCEPT_TAC A1_TWOSTEP);; + +(* ------------------------------------------------------------------------- *) +(* Decompose bound (no-wrap) and a0 upper bound, arch-independent num lemmas *) +(* shared with the AArch64 proof. *) +(* ------------------------------------------------------------------------- *) +let A1_BOUND_NOWRAP = prove( + `!a. a <= 8118528 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 <= 15`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8118528 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `67108802` (SPEC `d * 1025 + 2097152` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 1025 <= 63426 * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; +(* For a <= 8118528, a is strictly below (quotient + 1) * 523776, i.e. the remainder a - quotient * 523776 is below 523776. *) +let A0_UPPER_32 = prove( + `!a.
a <= 8118528 + ==> a < (((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 + 1) * 523776`, + GEN_TAC THEN DISCH_TAC THEN + ABBREV_TAC `nv = ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304` THEN + SUBGOAL_THEN `nv * 4194304 <= (a + 127) DIV 128 * 1025 + 2097152` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN + MP_TAC(SPECL [`(a + 127) DIV 128 * 1025 + 2097152`; `4194304`] (CONJUNCT1 DIVISION_SIMP)) THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(a + 127) DIV 128 <= 63426` ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `8118528 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 4194304 <= 63426 * 1025 + 2097152` ASSUME_TAC THENL + [SUBGOAL_THEN `(a + 127) DIV 128 * 1025 <= 63426 * 1025` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ASM_ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 15` ASSUME_TAC THENL + [CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_ARITH_TAC);; + +(* ------------------------------------------------------------------------- *) +(* The per-element x86 word model of UseHint (param 65/87), matching the *) +(* scalar form SIMD_SIMPLIFY produces for one coefficient lane, plus the *) +(* value/threshold lemmas used by ELEMENT_CORRECT.
*) +(* ------------------------------------------------------------------------- *) +let mldsa_use_hint_32_x86_asm = new_definition + `mldsa_use_hint_32_x86_asm (a:int32) (h:int32) : int32 = + let t = word_ushr (word_add (word 127) a) 7 in + let m16 = word_subword + (word_mul (word_zx (word_subword t (0,16) :16 word) :int32) + (word 1025)) (16,16) :16 word in + let a1lane = word_subword + (word_add + (word_ushr (word_mul (word_sx m16 :int32) (word 512)) 14) + (word 1)) (1,16) :16 word in + let a1 = word_zx a1lane :int32 in + let m:int32 = (if word_igt a (word 8118528) then word 4294967295 + else word 0) in + let a1' = word_and a1 (word_not m) in + let a0 = word_add (word_sub a (word_mul a1 (word 523776))) m in + let delta:int32 = word_or (word_neg(word(bitval(word_ile a0 (word 0))))) + (word 1) in + word_and (word_add a1' (word_mul delta h)) (word 15)`;; + +let A1_LANE_VAL = prove( + `!a:int32. val a < 8380417 + ==> val(word_zx (word_subword + (word_add (word_ushr (word_mul (word_sx (word_subword + (word_mul (word_zx (word_subword (word_ushr (word_add (word 127) a) 7) (0,16) :16 word) :int32) + (word 1025:int32)) (16,16) :16 word) :int32) (word 512:int32)) 14) (word 1:int32)) (1,16) :16 word) :int32) + = ((val a + 127) DIV 128 * 1025 + 2097152) DIV 4194304`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN + `val(word_subword + (word_add (word_ushr (word_mul (word_sx (word_subword + (word_mul (word_zx (word_subword (word_ushr (word_add (word 127) (a:int32)) 7) (0,16) :16 word) :int32) + (word 1025:int32)) (16,16) :16 word) :int32) (word 512:int32)) 14) (word 1:int32)) (1,16) :16 word) + = ((val(a:int32) + 127) DIV 128 * 1025 + 2097152) DIV 4194304` + (fun th -> REWRITE_TAC[th]) THENL + [MATCH_MP_TAC A1_LANE THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(SPEC `val(a:int32)` A1_BOUND) THEN ASM_REWRITE_TAC[] THEN ARITH_TAC);; + +let WORD_IGT_THRESHOLD_X86_32 = 
BITBLAST_RULE + `!a:int32. val a < 8380417 ==> (word_igt a (word 8118528:int32) <=> val a > 8118528)`;; + +let WRAP_A0_NEGATIVE_X86 = BITBLAST_RULE + `!a:int32. val a < 8380417 /\ val a > 8118528 + ==> bit 31 (word_add (word_sub a (word 8380416:int32)) (word 4294967295:int32))`;; + +(* Bare (no +word 0) sign form, used in the no-wrap a0-sign bridge. *) +let WORD_SUB_SIGN_BARE = BITBLAST_RULE + `!a:int32 b:int32. val b <= 7856640 /\ val a <= 8118528 ==> + ((bit 31 (word_sub a b) \/ word_sub a b = word 0) <=> val a <= val b)`;; + +(* Small word identities used to simplify the lane after the wrap/no-wrap mask + is resolved (these specific shapes are not in the standard word library). *) +let WORD_ADD_LID = WORD_RULE `!x:N word. word_add (word 0) x = x`;; +let WORD_SUB_RZERO = WORD_RULE `!x:N word. word_sub x (word 0) = x`;; + +(* word_ile against 0 reduces to "sign bit set or value zero". *) +let WORD_ILE_ZERO_X86_32 = WORD_ILE_ZERO_32;; + +(* ------------------------------------------------------------------------- *) +(* Element correctness (value form): the per-lane x86 word model equals the *) +(* scalar FIPS UseHint code for one coefficient, for val a < Q, val h <= 1. *) +(* Case split is wrap (a > 31*GAMMA2 => a1 lane = 16, masked 0) vs no-wrap. *) +(* ------------------------------------------------------------------------- *) +let ELEMENT_CORRECT = prove( + `!a:int32 h:int32. 
+ val a < 8380417 /\ val h <= 1 + ==> val(mldsa_use_hint_32_x86_asm a h) = mldsa_use_hint_32_code (val a) (val h)`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[mldsa_use_hint_32_x86_asm; mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ABBREV_TAC `nv = ((val(a:int32) + 127) DIV 128 * 1025 + 2097152) DIV 4194304` THEN + ABBREV_TAC + `a1w:int32 = + word_zx (word_subword + (word_add (word_ushr (word_mul (word_sx (word_subword + (word_mul (word_zx (word_subword + (word_ushr (word_add (word 127) (a:int32)) 7) (0,16) :16 word) :int32) + (word 1025)) (16,16) :16 word) :int32) + (word 512)) 14) (word 1)) (1,16) :16 word) :int32` THEN + SUBGOAL_THEN `val(a1w:int32) = nv` ASSUME_TAC THENL + [EXPAND_TAC "a1w" THEN EXPAND_TAC "nv" THEN + MATCH_MP_TAC A1_LANE_VAL THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 16` ASSUME_TAC THENL + [MP_TAC(SPEC `val(a:int32)` A1_BOUND) THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_igt (a:int32) (word 8118528:int32) <=> val a > 8118528` + SUBST1_TAC THENL + [MP_TAC(SPEC `a:int32` WORD_IGT_THRESHOLD_X86_32) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(a:int32) > 8118528` THEN ASM_REWRITE_TAC[] THENL + [ + (* WRAP ZONE: nv = 16, a1 lane masked to 0. 
*) + SUBGOAL_THEN `nv = 16` SUBST_ALL_TAC THENL + [MP_TAC(SPEC `val(a:int32)` A1_WRAP) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `a1w:int32 = word 16` ASSUME_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ] THEN ASM_REWRITE_TAC[] THEN CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `word_and (word 16:int32) (word_not (word 4294967295)) = word 0` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + SUBGOAL_THEN `word_mul (word 16:int32) (word 523776) = word 8380416` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + REWRITE_TAC[WORD_ADD_LID] THEN + SUBGOAL_THEN `word_ile (word_add (word_sub (a:int32) (word 8380416)) (word 4294967295)) (word 0)` ASSUME_TAC THENL + [REWRITE_TAC[WORD_ILE_ZERO_X86_32] THEN DISJ1_TAC THEN + MP_TAC(SPEC `a:int32` WRAP_A0_NEGATIVE_X86) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV; + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + SUBGOAL_THEN `(&(val(a:int32)):int) > &4190208` ASSUME_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_GT; INT_GT](ASSUME `val(a:int32) > 8118528`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `~((&(val(a:int32)) - &8380417:int) > &0)` ASSUME_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT](ASSUME `val(a:int32) < 8380417`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[]] + ; + (* NO-WRAP ZONE: mask = 0, a1' = a1 lane = nv <= 15. 
*) + SUBGOAL_THEN `nv <= 15` ASSUME_TAC THENL + [MP_TAC(SPEC `val(a:int32)` A1_BOUND_NOWRAP) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `nv MOD 16 = nv` SUBST_ALL_TAC THENL + [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_NOT_0; WORD_AND_REFL] THEN + REWRITE_TAC[WORD_MUL_0; WORD_SUB_RZERO; WORD_ADD_0; WORD_ADD_LID] THEN + SUBGOAL_THEN `val(word_mul (a1w:int32) (word 523776:int32)) = nv * 523776` + ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_MUL; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `val(word 523776:int32) = 523776` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `nv * 523776 <= 15 * 523776` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 523776 <= 7856640` ASSUME_TAC THENL + [SUBGOAL_THEN `nv * 523776 <= 15 * 523776` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (a1w:int32) (word 523776:int32)) <= 7856640` + ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `word_ile (word_sub (a:int32) (word_mul a1w (word 523776:int32))) (word 0) + <=> ~(&(val a) - &nv * &523776 > &0)` SUBST1_TAC THENL + [REWRITE_TAC[WORD_ILE_ZERO_X86_32] THEN + MP_TAC(ISPECL [`a:int32`; `word_mul (a1w:int32) (word 523776:int32)`] + WORD_SUB_SIGN_BARE) THEN + ASM_REWRITE_TAC[] THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 523776` THENL + [ASM_REWRITE_TAC[] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] + (REWRITE_RULE[GSYM INT_OF_NUM_LE] + (ASSUME `val(a:int32) <= nv * 523776`))) THEN INT_ARITH_TAC; + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv * 523776 < val(a:int32)` ASSUME_TAC 
THENL + [UNDISCH_TAC `~(val(a:int32) <= nv * 523776)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] + (REWRITE_RULE[GSYM INT_OF_NUM_LT] + (ASSUME `nv * 523776 < val(a:int32)`))) THEN INT_ARITH_TAC]; + ALL_TAC] THEN + SUBGOAL_THEN `~(int_gt (&(val(a:int32)) - &nv * &523776) (&4190208))` + ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN + MP_TAC(SPEC `val(a:int32)` A0_UPPER_32) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL; + GSYM INT_OF_NUM_ADD] (ASSUME `val(a:int32) < (nv + 1) * 523776`)) THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_GT] THEN + RULE_ASSUM_TAC(REWRITE_RULE[INT_GT]) THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_0; WORD_ADD_0; WORD_AND_ONES_32] THEN + REWRITE_TAC[VAL_WORD_AND_15_32] THEN ASM_REWRITE_TAC[] THEN + MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_1_32] THEN + ASM_CASES_TAC `&0 < &(val(a:int32)) - &nv * &523776` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `~(&(val(a:int32)) - &nv * &523776 > &0) <=> F` SUBST1_TAC THENL + [REWRITE_TAC[INT_GT] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_AND_15_32; VAL_WORD_ADD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `16 = 2 EXP 4`; ARITH_RULE `4294967296 = 2 EXP 32`; + MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_1]; + SUBGOAL_THEN `~(&(val(a:int32)) - &nv * &523776 > &0) <=> T` SUBST1_TAC THENL + [REWRITE_TAC[INT_GT] THEN POP_ASSUM MP_TAC THEN INT_ARITH_TAC; ALL_TAC] 
THEN + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_AND_15_32; VAL_WORD_ADD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `16 = 2 EXP 4`; ARITH_RULE `4294967296 = 2 EXP 32`; + MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + CONV_TAC(ONCE_DEPTH_CONV WORD_RED_CONV) THEN + REWRITE_TAC[ARITH_RULE `4294967295 = 15 + 268435455 * 16`; + ARITH_RULE `n + (15 + 268435455 * 16) = (n + 15) + 268435455 * 16`; + MOD_MULT_ADD]]] THEN + ASM_REWRITE_TAC[] THEN ASM_INT_ARITH_TAC);; + +(* Word form, directly usable in the loop body discharge. *) +let ELEMENT_CORRECT_WORD = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> mldsa_use_hint_32_x86_asm a h = word(mldsa_use_hint_32_code (val a) (val h))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN + AP_TERM_TAC THEN + MP_TAC(SPECL [`a:int32`; `h:int32`] ELEMENT_CORRECT) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]));; + +(* ------------------------------------------------------------------------- *) +(* Per-lane bridge from the raw SIMD-simplified body output to the lane model. *) +(* These convert the assembly's surface encodings (shift-based multiply, the *) +(* zero high half of the 16x16->32 multiply, the h-(andnot dlt h)<<1 delta) *) +(* into the mldsa_use_hint_32_x86_asm model form. *) +(* ------------------------------------------------------------------------- *) + +(* a1 * 2*GAMMA2 = a1*523776 computed as (a1<<10 - a1)<<9. *) +let SHL_523776 = BITBLAST_RULE + `!a:int32. word_shl (word_sub (word_shl a 10) a) 9 = word_mul a (word 523776)`;; + +(* The high 16-bit half of the VPMULHRSW lane uses multiplier (word 0), so it + contributes nothing: the high a1 lane is zero. 
*) +let A1HI_ZERO = BITBLAST_RULE + `word_subword + (word_add + (word_ushr + (word_mul + (word_sx (word_subword + (word_mul (word_zx (word_subword + (word_ushr (word_add (word 127) (x:int32)) 7) (16,16) :16 word) :int32) + (word 0)) (16,16) :16 word) :int32) + (word 0)) 14) + (word 1)) (1,16) :16 word = word 0`;; + +(* Final commutativity closer used after the surface rewrites line up the lane. *) +let LANE_AC_CLOSE = prove + (`!d m a:int32. word_and (word_add d (word_and m a)) (word 15) = + word_and (word_add (word_and a m) d) (word 15)`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +(* Distribution of the final vpand-by-15 (a 256-bit broadcast) over the SIMD + word_join tree, pushing the &15 mask down to each 32-bit lane. Used to line + up the assembly's whole-vector mask with the per-lane &15 of the lane model. *) +let ANDDUP_256 = prove + (`!a b:int128. word_and (word_join a b:int256) (word_duplicate (word 15:int32)) = + word_join (word_and a (word_duplicate (word 15:int32):int128)) + (word_and b (word_duplicate (word 15:int32):int128))`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +let ANDDUP_128 = prove + (`!a b:int64. word_and (word_join a b:int128) (word_duplicate (word 15:int32)) = + word_join (word_and a (word_duplicate (word 15:int32):int64)) + (word_and b (word_duplicate (word 15:int32):int64))`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +let ANDDUP_64 = prove + (`!a b:int32. 
word_and (word_join a b:int64) (word_duplicate (word 15:int32)) = + word_join (word_and a (word_duplicate (word 15:int32):int32)) + (word_and b (word_duplicate (word 15:int32):int32))`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +let DUP15_32 = prove + (`word_duplicate (word 15:int32):int32 = word 15`, CONV_TAC WORD_BLAST);; + +(* ------------------------------------------------------------------------- *) +(* Discharge of the 256-bit block store: rewrite the raw SIMD result back to *) +(* simd8 mldsa_use_hint_32_x86_asm on the input lanes, using the per-lane *) +(* hint bounds. *) +(* ------------------------------------------------------------------------- *) + +(* word_subword distributes through word_not on each 32-bit lane. *) +let UH32_WSN = map (fun n -> prove( + subst [mk_small_numeral n, `n:num`] + `!z:int256. word_subword(word_not z) (n,32):int32 = word_not(word_subword z (n,32))`, + GEN_TAC THEN MATCH_MP_TAC WORD_SUBWORD_NOT THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_256] THEN ARITH_TAC)) [0;32;64;96;128;160;192;224];; + +(* Discharge of the 256-bit store equality: distribute the final &15 mask to + each 32-bit lane, isolate the eight lanes, and rewrite each raw lane to the + lane model via the LANE_USE_HINT bridge (built on the fly from the goal's own + lane 0, instantiated at each of the eight 32-bit offsets and discharged from + the per-lane hint bound). 
*) +let UH32_STORE_DISCHARGE_TAC : tactic = + fun (asl,w) -> + if not(is_eq w) || + not(can (find_term (fun t -> try fst(dest_const(fst(strip_comb t)))="simd8" with _->false)) w) + then failwith "not the store goal" else + let bignum = rand (lhs w) in + let dup15_lit = prove(mk_eq(bignum,`word_duplicate (word 15:int32):int256`), CONV_TAC WORD_BLAST) in + let phase1 = + REWRITE_TAC[dup15_lit] THEN REWRITE_TAC[ANDDUP_256;ANDDUP_128;ANDDUP_64;DUP15_32] THEN + REWRITE_TAC([WORD_SUBWORD_AND;WORD_SUBWORD_OR] @ UH32_WSN) THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) in + let phase2 (asl2,w2) = + let lhsI = lhs w2 in + let rec deepest t = if (try fst(dest_const(fst(strip_comb t)))="word_join" with _->false) then deepest(rand t) else t in + let lane0_raw = deepest lhsI in + let a1lane_x = `word_subword (word_add (word_ushr (word_mul (word_sx (word_subword (word_mul (word_zx (word_subword (word_ushr (word_add (word 127) (x:int32)) 7) (0,16) :16 word) :int32) (word 1025)) (16,16) :16 word) :int32) (word 512)) 14) (word 1)) (1,16) :16 word` in + let a0term = `word_add (word_sub x (word_mul (A:int32) (word 523776))) (if word_igt x (word 8118528) then word 4294967295 else word 0)` in + let lane0_fn = subst [`x:int32`,`word_subword (av:int256) (0,32):int32`; `y:int32`,`word_subword (hv:int256) (0,32):int32`] lane0_raw in + let luh = prove( + mk_imp(`val(y:int32) <= 1`, mk_eq(lane0_fn, list_mk_comb(`mldsa_use_hint_32_x86_asm`,[`x:int32`;`y:int32`]))), + DISCH_TAC THEN REWRITE_TAC[mldsa_use_hint_32_x86_asm] THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[SHL_523776] THEN REWRITE_TAC[A1HI_ZERO; JOIN_ZERO_ZX] THEN + ABBREV_TAC (mk_eq(`A:int32`, mk_comb(`word_zx:(16)word->int32`, a1lane_x))) THEN + MP_TAC(SPECL [a0term; `y:int32`] DELTA_EQ_BOUNDED) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN MATCH_ACCEPT_TAC LANE_AC_CLOSE) in + let lane_gen = GENL [`x:int32`;`y:int32`] luh in + let bound_hyp = snd(find (fun (_,th) -> try concl th = `!k. 
k < 8 ==> val(word_subword (hv:int256) (32*k,32):int32) <= 1` with _->false) asl2) in + let bridges = map (fun k -> + let off = mk_small_numeral (32*k) in + let xk = mk_comb(mk_comb(`word_subword:int256->num#num->int32`,`av:int256`),mk_pair(off,`32`)) in + let yk = mk_comb(mk_comb(`word_subword:int256->num#num->int32`,`hv:int256`),mk_pair(off,`32`)) in + let raw_bnd = MP (SPEC (mk_small_numeral k) bound_hyp) + (EQT_ELIM(NUM_REDUCE_CONV (mk_binop `(<):num->num->bool` (mk_small_numeral k) `8`))) in + let raw_bnd = CONV_RULE(ONCE_DEPTH_CONV NUM_REDUCE_CONV) raw_bnd in + MP (SPECL [xk;yk] lane_gen) raw_bnd) (0--7) in + (REWRITE_TAC[simd8;simd4;simd2;DIMINDEX_32] THEN CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC bridges THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV)) (asl2,w2) in + (phase1 THEN phase2) (asl,w);; + +(* Broadcast constants as 256-bit duplicates of their 32-bit lane value. *) +let DUPLITS = map (fun (n,c) -> prove(mk_eq(mk_comb(`word:num->int256`, mk_numeral(Num.num_of_string n)), + mk_comb(`word_duplicate:int32->int256`, c)), CONV_TAC WORD_BLAST)) + ["3423913227525323174502430081042878883520180111764122672559515536195711", `word 127:int32`; + "27633945340263435069803077425739770516599878854789179050185066335437825", `word 1025:int32`; + "13803492696795003664135781114125621955608915096245911876775369720726016", `word 512:int32`; + "218875081946729975600369013236132924539112762223623301674088649976692072704", `word 8118528:int32`; + "404399200101416122972727962327899080730729934460329449514903409786895", `word 15:int32`];; + +(* ------------------------------------------------------------------------- *) +(* Loop body (one iteration): block i is read, eight UseHints are computed *) +(* in the YMM lanes, and the corrected block is stored back; pointers and the *) +(* loop counter advance by one block. 
Blocks below i (already done) and at *) +(* or above i+1 (untouched) are preserved by the single 256-bit store. *) +(* ------------------------------------------------------------------------- *) + +let POLY_USE_HINT_32_AVX2_ASM_BODY_BLOCK_TAC : tactic = + REPEAT STRIP_TAC THEN + ENSURES_INIT_TAC "s0" THEN + MP_TAC(SPECL [`a:int64`;`i:num`] ALIGNED_BLOCK) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(SPECL [`h:int64`;`i:num`] ALIGNED_BLOCK) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `read (memory :> bytes256 (word_add a (word(32*i)))) s0 = xb i` ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + if can (term_match [] + `!b. i <= b /\ b < 32 ==> read(memory :> bytes256(word_add a (word(32*b)))) s0 = xb b`) (concl th) + then MP_TAC(SPEC `i:num` th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `read (memory :> bytes256 (word_add h (word(32*i)))) s0 = yb i` ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + if can (term_match [] + `!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32*b)))) s0 = yb b`) (concl th) + then MP_TAC(SPEC `i:num` th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `!k. k < 8 ==> val(word_subword ((yb:num->int256) i) (32*k,32):int32) <= 1` ASSUME_TAC THENL + [GEN_TAC THEN DISCH_TAC THEN + FIRST_ASSUM(fun th -> if can (term_match [] + `!b k. b < 32 /\ k < 8 ==> val(word_subword ((yb:num->int256) b) (32*k,32):int32) <= 1`) (concl th) + then MP_TAC(SPECL [`i:num`;`k:num`] th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + SUBGOAL_THEN + `!b. i+1 <= b /\ b < 32 ==> read(memory :> bytes256(word_add a (word(32*b)))) s0 = xb b` + ASSUME_TAC THENL + [REPEAT STRIP_TAC THEN + FIRST_ASSUM(fun th -> if can (term_match [] + `!b. 
i <= b /\ b < 32 ==> read(memory :> bytes256(word_add a (word(32*b)))) s0 = xb b`) (concl th) + then MP_TAC(SPEC `b:num` th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + EVERY (map (fun n -> X86_STEPS_TAC POLY_USE_HINT_32_AVX2_ASM_EXEC [n] THEN SIMD_SIMPLIFY_TAC[]) (1--24)) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THEN + TRY(REWRITE_TAC[ARITH_RULE `32 * (i + 1) = 32 * i + 32`] THEN CONV_TAC WORD_RULE) THEN + TRY(X_GEN_TAC `b:num` THEN DISCH_TAC THEN ASM_CASES_TAC `b < i` THENL + [UNDISCH_TAC `b:num < i` THEN + UNDISCH_TAC `!b. b < i ==> read (memory :> bytes256 (word_add a (word (32 * b)))) s24 = simd8 mldsa_use_hint_32_x86_asm (xb b) (yb b)` THEN + MESON_TAC[]; + SUBGOAL_THEN `b:num = i` SUBST_ALL_TAC THENL + [UNDISCH_TAC `b:num < i + 1` THEN UNDISCH_TAC `~(b:num < i)` THEN ARITH_TAC; ALL_TAC] THEN + FIRST_X_ASSUM(fun th -> + try let l,_ = dest_eq (concl th) in + if l = `read (memory :> bytes256 (word_add a (word (32 * i)))) s24` + then SUBST1_TAC th else failwith "no" + with _ -> failwith "no") THEN + ABBREV_TAC `av:int256 = (xb:num->int256) i` THEN + ABBREV_TAC `hv:int256 = (yb:num->int256) i` THEN + UH32_STORE_DISCHARGE_TAC]) THEN + TRY(REWRITE_TAC[VAL_WORD_ADD; VAL_WORD; DIMINDEX_64] THEN + FIRST_ASSUM(fun th -> if concl th = `i < 32` then MP_TAC th else failwith "no") THEN + SPEC_TAC(`i:num`,`i:num`) THEN + CONV_TAC EXPAND_CASES_CONV THEN CONV_TAC NUM_REDUCE_CONV);; + +(* ------------------------------------------------------------------------- *) +(* Block-function correctness: the routine maps each 256-bit input block *) +(* through the SIMD UseHint, over 32 loop iterations. *) +(* ------------------------------------------------------------------------- *) + +let POLY_USE_HINT_32_AVX2_ASM_BLOCK_CORRECT = prove + (`!a h xb yb pc. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, 0xbc) (a, 1024) /\ nonoverlapping (a, 1024) (h, 1024) /\ + (!b k. 
b < 32 /\ k < 8 ==> val(word_subword (yb b:int256) (32*k,32):int32) <= 1) + ==> ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST poly_use_hint_32_avx2_asm_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [a; h] s /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = xb b) /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32 * b)))) s = yb b)) + (\s. read RIP s = word(pc + 0xbc) /\ + (!b. b < 32 ==> + read(memory :> bytes256(word_add a (word(32 * b)))) s = + simd8 mldsa_use_hint_32_x86_asm (xb b) (yb b))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, MAYCHANGE [memory :> bytes(a, 1024)])`, + MAP_EVERY X_GEN_TAC [`a:int64`;`h:int64`;`xb:num->int256`;`yb:num->int256`;`pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; NONOVERLAPPING_CLAUSES; ALL; + fst POLY_USE_HINT_32_AVX2_ASM_EXEC] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN REWRITE_TAC[SOME_FLAGS] THEN + ENSURES_WHILE_PUP_TAC `32` `pc + 0x50` `pc + 0xba` + `\i s. + (read RDI s = word_add a (word(32 * i)) /\ + read RSI s = word_add h (word(32 * i)) /\ + read RAX s = word(32 * i) /\ + read YMM5 s = (word_duplicate (word 127:int32):int256) /\ + read YMM8 s = (word_duplicate (word 1025:int32):int256) /\ + read YMM7 s = (word_duplicate (word 512:int32):int256) /\ + read YMM4 s = (word_duplicate (word 8118528:int32):int256) /\ + read YMM6 s = (word 0:int256) /\ + read YMM3 s = (word_duplicate (word 15:int32):int256) /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32 * b)))) s = yb b) /\ + (!b. i <= b /\ b < 32 + ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = xb b) /\ + (!b. b < i + ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = + simd8 mldsa_use_hint_32_x86_asm (xb b) (yb b))) + /\ + (read ZF s <=> i = 32)` THEN + REWRITE_TAC[ARITH_RULE `~(32 = 0)`] THEN CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + REPEAT CONJ_TAC THENL + [ + (* INIT: run the constant-setup block to the loop top. 
*) + REWRITE_TAC[MULT_CLAUSES; WORD_ADD_0] THEN + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC POLY_USE_HINT_32_AVX2_ASM_EXEC (1--17) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC DUPLITS THEN + REWRITE_TAC[ARITH_RULE `b < 0 <=> F`; LE_0] THEN ASM_REWRITE_TAC[] + ; + (* BODY *) + POLY_USE_HINT_32_AVX2_ASM_BODY_BLOCK_TAC + ; + (* BACKEDGE *) + REPEAT STRIP_TAC THEN X86_SIM_TAC POLY_USE_HINT_32_AVX2_ASM_EXEC (1--1) + ; + (* EXIT: the invariant at i = 32 is the postcondition. *) + REWRITE_TAC[ARITH_RULE `32 <= b /\ b < 32 <=> F`] THEN + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC POLY_USE_HINT_32_AVX2_ASM_EXEC (1--1) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] + ]);; + +(* ------------------------------------------------------------------------- *) +(* Core correctness theorem (coefficient form, FIPS 204 UseHint). *) +(* Input bounds appear as preconditions; the result is stated directly. *) +(* This must be kept in sync with the CBMC specification in *) +(* mldsa/src/native/x86_64/src/arith_native_x86_64.h *) +(* ------------------------------------------------------------------------- *) + +let POLY_USE_HINT_32_AVX2_ASM_CORRECT = prove + (`!a h x y pc. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, 0xbc) (a, 1024) /\ nonoverlapping (a, 1024) (h, 1024) + ==> ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST poly_use_hint_32_avx2_asm_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [a; h] s /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add h (word(4 * i)))) s = y i) /\ + (!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)) + (\s. read RIP s = word(pc + 0xbc) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + word(mldsa_use_hint_32 (val(y i)) (val(x i)))) /\ + (!i. 
i < 256 ==> + val(read(memory :> bytes32(word_add a (word(4 * i)))) s) < 16)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, MAYCHANGE [memory :> bytes(a, 1024)])`, + MAP_EVERY X_GEN_TAC [`a:int64`;`h:int64`;`x:num->int32`;`y:num->int32`;`pc:num`] THEN + STRIP_TAC THEN + ASM_CASES_TAC `(!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)` + THENL + [FIRST_X_ASSUM(CONJUNCTS_THEN ASSUME_TAC) THEN + MATCH_MP_TAC ENSURES_PREPOSTCONDITION_THM THEN + MAP_EVERY EXISTS_TAC + [`\s. bytes_loaded s (word pc) (BUTLAST poly_use_hint_32_avx2_asm_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [a; h] s /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = pack8 x b) /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32 * b)))) s = pack8 y b)`; + `\s. read RIP s = word(pc + 0xbc) /\ + (!b. b < 32 ==> + read(memory :> bytes256(word_add a (word(32 * b)))) s = + simd8 mldsa_use_hint_32_x86_asm (pack8 x b) (pack8 y b))`] THEN + CONJ_TAC THENL + [ + (* precondition strengthening: coefficient reads ==> block reads *) + X_GEN_TAC `s:x86state` THEN REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + CONJ_TAC THEN X_GEN_TAC `b:num` THEN DISCH_TAC THEN + MATCH_MP_TAC PACK8_MERGE THEN ASM_REWRITE_TAC[] + ; + CONJ_TAC THENL + [ + (* postcondition weakening: block result ==> coefficient FIPS result + bound *) + X_GEN_TAC `s:x86state` THEN REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN + `!i. 
i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + mldsa_use_hint_32_x86_asm (x i) (y i)` + ASSUME_TAC THENL + [X_GEN_TAC `i:num` THEN DISCH_TAC THEN + SUBGOAL_THEN `4 * i = 4 * (8 * (i DIV 8) + i MOD 8)` SUBST1_TAC THENL + [AP_TERM_TAC THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`a:int64`;`s:x86state`;`i DIV 8`;`i MOD 8`] BLOCK_SPLIT) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `i DIV 8` th)) THEN + ANTS_TAC THENL [UNDISCH_TAC `i:num < 256` THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + MP_TAC(SPECL [`mldsa_use_hint_32_x86_asm`;`pack8 x (i DIV 8)`;`pack8 y (i DIV 8)`;`i MOD 8`] SIMD8_LANE) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + MP_TAC(SPECL [`x:num->int32`;`i DIV 8`;`i MOD 8`] PACK8_LANE) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + MP_TAC(SPECL [`y:num->int32`;`i DIV 8`;`i MOD 8`] PACK8_LANE) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN DISCH_THEN SUBST1_TAC THEN + SUBGOAL_THEN `8 * (i DIV 8) + i MOD 8 = i` SUBST1_TAC THENL + [ARITH_TAC; REFL_TAC]; ALL_TAC] THEN + CONJ_TAC THENL + [ + X_GEN_TAC `i:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `i:num` th)) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN + SUBGOAL_THEN `val(x (i:num):int32) < 8380417 /\ val(y (i:num):int32) <= 1` STRIP_ASSUME_TAC THENL + [ASM_SIMP_TAC[]; ALL_TAC] THEN + MP_TAC(SPECL [`x (i:num):int32`;`y (i:num):int32`] ELEMENT_CORRECT_WORD) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + AP_TERM_TAC THEN + MP_TAC(SPECL [`val(x (i:num):int32)`;`val(y (i:num):int32)`] MLDSA_USE_HINT_32_EQUIV) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]) + ; + X_GEN_TAC `i:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `i:num` th)) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN 
+ SUBGOAL_THEN `val(x (i:num):int32) < 8380417 /\ val(y (i:num):int32) <= 1` STRIP_ASSUME_TAC THENL + [ASM_SIMP_TAC[]; ALL_TAC] THEN + MP_TAC(SPECL [`x (i:num):int32`;`y (i:num):int32`] ELEMENT_CORRECT_WORD) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + MATCH_MP_TAC(ARITH_RULE `n < 16 ==> n MOD 4294967296 < 16`) THEN + REWRITE_TAC[mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN + REWRITE_TAC[MOD_LT_EQ; ARITH_EQ] + ] + ; + (* the block-function correctness specialised at xb = pack8 x, yb = pack8 y *) + MATCH_MP_TAC POLY_USE_HINT_32_AVX2_ASM_BLOCK_CORRECT THEN + ASM_REWRITE_TAC[] THEN REPEAT STRIP_TAC THEN + MP_TAC(SPECL [`y:num->int32`;`b:num`;`k:num`] PACK8_LANE) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `8 * b + k` th)) THEN + ANTS_TAC THENL [UNDISCH_TAC `b:num < 32` THEN UNDISCH_TAC `k:num < 8` THEN ARITH_TAC; + REWRITE_TAC[]] + ]] + ; + MATCH_MP_TAC ENSURES_PRECONDITION_THM THEN + EXISTS_TAC `\s:x86state. F` THEN + REWRITE_TAC[ENSURES_TRIVIAL] THEN + GEN_TAC THEN POP_ASSUM MP_TAC THEN MESON_TAC[]]);; + +(* ========================================================================= *) +(* Public subroutine correctness (with return). *) +(* ========================================================================= *) + +let POLY_USE_HINT_32_AVX2_ASM_NOIBT_SUBROUTINE_CORRECT = prove + (`!a h x y pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_32_avx2_asm_tmc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. 
bytes_loaded s (word pc) poly_use_hint_32_avx2_asm_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add h (word(4 * i)))) s = y i) /\ + (!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + word(mldsa_use_hint_32 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add a (word(4 * i)))) s) < 16)) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + X86_PROMOTE_RETURN_NOSTACK_TAC poly_use_hint_32_avx2_asm_tmc POLY_USE_HINT_32_AVX2_ASM_CORRECT);; + +let POLY_USE_HINT_32_AVX2_ASM_SUBROUTINE_CORRECT = prove + (`!a h x y pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_32_avx2_asm_mc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. bytes_loaded s (word pc) poly_use_hint_32_avx2_asm_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add h (word(4 * i)))) s = y i) /\ + (!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + word(mldsa_use_hint_32 (val(y i)) (val(x i)))) /\ + (!i. 
i < 256 ==> + val(read(memory :> bytes32(word_add a (word(4 * i)))) s) < 16)) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE POLY_USE_HINT_32_AVX2_ASM_NOIBT_SUBROUTINE_CORRECT));; + +(* ========================================================================= *) +(* Constant-time and memory safety proof. *) +(* ========================================================================= *) + +needs "s2n_bignum/x86/proofs/consttime.ml";; +needs "mldsa_native/x86_64/proofs/subroutine_signatures.ml";; + +let NORMALIZE_AND_EXPAND_YMM_TAC : tactic = + RULE_ASSUM_TAC(REWRITE_RULE[WORD_ADD_0]) THEN EXPAND_MAYCHANGE_YMM_REGS_TAC;; + +let full_spec,public_vars = mk_safety_spec + ~keep_maychanges:true + (assoc "mldsa_poly_use_hint_32_x86" subroutine_signatures) + (REWRITE_RULE[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; SOME_FLAGS] + POLY_USE_HINT_32_AVX2_ASM_CORRECT) + POLY_USE_HINT_32_AVX2_ASM_EXEC;; + +let POLY_USE_HINT_32_AVX2_ASM_SAFE = time prove + (full_spec, + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; SOME_FLAGS] THEN + GEN_PROVE_SAFETY_SPEC_TAC ~public_vars:public_vars + ~tac_before_maychange_simp:NORMALIZE_AND_EXPAND_YMM_TAC + POLY_USE_HINT_32_AVX2_ASM_EXEC + [BYTES_LOADED_APPEND_CLAUSE] X86_SINGLE_STEP_TAC);; + +let POLY_USE_HINT_32_AVX2_ASM_NOIBT_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e a h pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_32_avx2_asm_tmc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. + bytes_loaded s (word pc) poly_use_hint_32_avx2_asm_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + read events s = e) + (\s. 
read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h pc stackpointer returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; stackpointer,8] + [a,1024; stackpointer,8])) + (\s s'. true)`, + X86_PROMOTE_RETURN_NOSTACK_TAC poly_use_hint_32_avx2_asm_tmc POLY_USE_HINT_32_AVX2_ASM_SAFE THEN + DISCHARGE_SAFETY_PROPERTY_TAC);; + +let POLY_USE_HINT_32_AVX2_ASM_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e a h pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_32_avx2_asm_mc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. + bytes_loaded s (word pc) poly_use_hint_32_avx2_asm_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + read events s = e) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h pc stackpointer returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; stackpointer,8] + [a,1024; stackpointer,8])) + (\s s'. true)`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE POLY_USE_HINT_32_AVX2_ASM_NOIBT_SUBROUTINE_SAFE));; + diff --git a/proofs/hol_light/x86_64/proofs/poly_use_hint_88_avx2_asm.ml b/proofs/hol_light/x86_64/proofs/poly_use_hint_88_avx2_asm.ml new file mode 100644 index 000000000..d6a9e98fe --- /dev/null +++ b/proofs/hol_light/x86_64/proofs/poly_use_hint_88_avx2_asm.ml @@ -0,0 +1,1607 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Use hint to correct high bits of decomposition (ML-DSA, param 44). 
*) +(* x86_64 AVX2 variant (GAMMA2 = (Q-1)/88). *) +(* ========================================================================= *) + +needs "s2n_bignum/x86/proofs/base.ml";; +needs "mldsa_native/common/mldsa_specs.ml";; +needs "mldsa_native/x86_64/proofs/mldsa_utils.ml";; + +(**** print_literal_from_elf "x86_64/mldsa/poly_use_hint_88_avx2_asm.o";; + ****) + +let poly_use_hint_88_avx2_asm_mc = define_assert_from_elf + "poly_use_hint_88_avx2_asm_mc" "x86_64/mldsa/poly_use_hint_88_avx2_asm.o" +(*** BYTECODE START ***) +[ + 0xf3; 0x0f; 0x1e; 0xfa; (* ENDBR64 *) + 0xb9; 0x7f; 0x00; 0x00; 0x00; + (* MOV (% ecx) (Imm32 (word 127)) *) + 0x31; 0xc0; (* XOR (% eax) (% eax) *) + 0xc5; 0xd1; 0xef; 0xed; (* VPXOR (%_% xmm5) (%_% xmm5) (%_% xmm5) *) + 0x41; 0xb8; 0x0b; 0x2c; 0x00; 0x00; + (* MOV (% r8d) (Imm32 (word 11275)) *) + 0xc4; 0x41; 0x79; 0x6e; 0xc0; + (* VMOVD (%_% xmm8) (% r8d) *) + 0xc4; 0x42; 0x7d; 0x58; 0xc0; + (* VPBROADCASTD (%_% ymm8) (%_% xmm8) *) + 0xc5; 0xf9; 0x6e; 0xe1; (* VMOVD (%_% xmm4) (% ecx) *) + 0xb9; 0x00; 0x6c; 0x7e; 0x00; + (* MOV (% ecx) (Imm32 (word 8285184)) *) + 0x41; 0xb9; 0x80; 0x00; 0x00; 0x00; + (* MOV (% r9d) (Imm32 (word 128)) *) + 0xc4; 0xc1; 0x79; 0x6e; 0xf9; + (* VMOVD (%_% xmm7) (% r9d) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xff; + (* VPBROADCASTD (%_% ymm7) (%_% xmm7) *) + 0x41; 0xba; 0x2b; 0x00; 0x00; 0x00; + (* MOV (% r10d) (Imm32 (word 43)) *) + 0xc4; 0xc1; 0x79; 0x6e; 0xf2; + (* VMOVD (%_% xmm6) (% r10d) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xf6; + (* VPBROADCASTD (%_% ymm6) (%_% xmm6) *) + 0xc5; 0xf9; 0x6e; 0xd9; (* VMOVD (%_% xmm3) (% ecx) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xe4; + (* VPBROADCASTD (%_% ymm4) (%_% xmm4) *) + 0xc4; 0xe2; 0x7d; 0x58; 0xdb; + (* VPBROADCASTD (%_% ymm3) (%_% xmm3) *) + 0xc5; 0xfd; 0x6f; 0x07; (* VMOVDQA (%_% ymm0) (Memop Word256 (%% (rdi,0))) *) + 0xc5; 0xfd; 0x6f; 0x0e; (* VMOVDQA (%_% ymm1) (Memop Word256 (%% (rsi,0))) *) + 0xc5; 0x5d; 0xfe; 0xd0; (* VPADDD (%_% ymm10) (%_% ymm4) (%_% ymm0) *) + 0xc4; 0xc1; 0x2d; 0x72; 
0xd2; 0x07; + (* VPSRLD (%_% ymm10) (%_% ymm10) (Imm8 (word 7)) *) + 0xc4; 0x41; 0x2d; 0xe4; 0xd0; + (* VPMULHUW (%_% ymm10) (%_% ymm10) (%_% ymm8) *) + 0xc4; 0x62; 0x2d; 0x0b; 0xd7; + (* VPMULHRSW (%_% ymm10) (%_% ymm10) (%_% ymm7) *) + 0xc5; 0x7d; 0x66; 0xdb; (* VPCMPGTD (%_% ymm11) (%_% ymm0) (%_% ymm3) *) + 0xc4; 0x41; 0x25; 0xdf; 0xca; + (* VPANDN (%_% ymm9) (%_% ymm11) (%_% ymm10) *) + 0xc4; 0xc1; 0x1d; 0x72; 0xf2; 0x01; + (* VPSLLD (%_% ymm12) (%_% ymm10) (Imm8 (word 1)) *) + 0xc4; 0x41; 0x1d; 0xfe; 0xe2; + (* VPADDD (%_% ymm12) (%_% ymm12) (%_% ymm10) *) + 0xc4; 0xc1; 0x2d; 0x72; 0xf4; 0x05; + (* VPSLLD (%_% ymm10) (%_% ymm12) (Imm8 (word 5)) *) + 0xc4; 0x41; 0x2d; 0xfa; 0xd4; + (* VPSUBD (%_% ymm10) (%_% ymm10) (%_% ymm12) *) + 0xc4; 0xc1; 0x2d; 0x72; 0xf2; 0x0b; + (* VPSLLD (%_% ymm10) (%_% ymm10) (Imm8 (word 11)) *) + 0xc4; 0xc1; 0x7d; 0xfa; 0xc2; + (* VPSUBD (%_% ymm0) (%_% ymm0) (%_% ymm10) *) + 0xc4; 0xc1; 0x7d; 0xfe; 0xc3; + (* VPADDD (%_% ymm0) (%_% ymm0) (%_% ymm11) *) + 0xc5; 0xfd; 0x66; 0xc5; (* VPCMPGTD (%_% ymm0) (%_% ymm0) (%_% ymm5) *) + 0xc5; 0xfd; 0xdf; 0xc1; (* VPANDN (%_% ymm0) (%_% ymm0) (%_% ymm1) *) + 0xc5; 0xfd; 0x72; 0xf0; 0x01; + (* VPSLLD (%_% ymm0) (%_% ymm0) (Imm8 (word 1)) *) + 0xc5; 0xf5; 0xfa; 0xc0; (* VPSUBD (%_% ymm0) (%_% ymm1) (%_% ymm0) *) + 0xc4; 0xc1; 0x7d; 0xfe; 0xc1; + (* VPADDD (%_% ymm0) (%_% ymm0) (%_% ymm9) *) + 0xc4; 0xe3; 0x7d; 0x4c; 0xc6; 0x00; + (* VPBLENDVB (%_% ymm0) (%_% ymm0) (%_% ymm6) (%_% ymm0) *) + 0xc5; 0xfd; 0x66; 0xce; (* VPCMPGTD (%_% ymm1) (%_% ymm0) (%_% ymm6) *) + 0xc5; 0xf5; 0xdf; 0xc0; (* VPANDN (%_% ymm0) (%_% ymm1) (%_% ymm0) *) + 0xc5; 0xfd; 0x7f; 0x07; (* VMOVDQA (Memop Word256 (%% (rdi,0))) (%_% ymm0) *) + 0x48; 0x83; 0xc7; 0x20; (* ADD (% rdi) (Imm8 (word 32)) *) + 0x48; 0x83; 0xc6; 0x20; (* ADD (% rsi) (Imm8 (word 32)) *) + 0x48; 0x83; 0xc0; 0x20; (* ADD (% rax) (Imm8 (word 32)) *) + 0x48; 0x3d; 0x00; 0x04; 0x00; 0x00; + (* CMP (% rax) (Imm32 (word 1024)) *) + 0x0f; 0x85; 0x75; 0xff; 
0xff; 0xff; + (* JNE (Imm32 (word 4294967157)) *) + 0xc3 (* RET *) +];; +(*** BYTECODE END ***) + +let poly_use_hint_88_avx2_asm_tmc = + define_trimmed "poly_use_hint_88_avx2_asm_tmc" poly_use_hint_88_avx2_asm_mc;; + +let POLY_USE_HINT_88_AVX2_ASM_EXEC = + X86_MK_CORE_EXEC_RULE poly_use_hint_88_avx2_asm_tmc;; + +(* ========================================================================= *) +(* Per-element scalar correctness chain for the x86_64 AVX2 poly_use_hint_88 *) +(* routine (ML-DSA-44, GAMMA2 = (Q-1)/88, 2*GAMMA2 = 190464). The numeric *) +(* Barrett model and the pure num/int decompose lemmas are arch-independent *) +(* (shared with the AArch64-88 proof); the x86 lane decomposition mirrors the *) +(* poly_use_hint_32 proof with the 88 constants. *) +(* ========================================================================= *) + +(* ------------------------------------------------------------------------- *) +(* Numeric (code-aligned) form of UseHint, matching the Barrett computation *) +(* performed by the assembly. Bridged to the FIPS 204 spec mldsa_use_hint_88 *) +(* via MLDSA_USE_HINT_88_EQUIV below. *) +(* ------------------------------------------------------------------------- *) +let mldsa_use_hint_88_code = new_definition + `mldsa_use_hint_88_code (a:num) (h:num) = + let a1_raw = ((((a + 127) DIV 128) * 11275 + 8388608) DIV 16777216) in + let a1 = if a1_raw > 43 then 0 else a1_raw in + let a0:int = &a - &a1 * &190464 in + let a0' = if a0 > &4190208 then a0 - &8380417 else a0 in + if h = 0 then a1 + else if a0' > &0 then if a1 = 43 then 0 else a1 + 1 + else if a1 = 0 then 43 else a1 - 1`;; + +(* ========================================================================= *) +(* Arithmetic bridge: the assembly's two rounding steps (vpmulhuw by 11275 *) +(* giving m16 = (t*11275) DIV 2^16, then vpmulhrsw by 128 giving *) +(* ((m16*128) DIV 2^14 + 1) DIV 2) equal the single code-form Barrett *) +(* division (t*11275 + 2^23) DIV 2^24. 
*) +(* ========================================================================= *) + +let MUL_DIV_128 = prove(`!m. (m * 128) DIV 16384 = m DIV 128`, + GEN_TAC THEN REWRITE_TAC[ARITH_RULE `16384 = 128 * 128`] THEN + REWRITE_TAC[GSYM DIV_DIV] THEN AP_THM_TAC THEN AP_TERM_TAC THEN + ONCE_REWRITE_TAC[MULT_SYM] THEN SIMP_TAC[DIV_MULT; ARITH_RULE `~(128=0)`]);; + +let A1_TWOSTEP_88 = prove(`!t. ((((t * 11275) DIV 65536) * 128) DIV 16384 + 1) DIV 2 = + (t * 11275 + 8388608) DIV 16777216`, + GEN_TAC THEN REWRITE_TAC[MUL_DIV_128] THEN + REWRITE_TAC[DIV_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPECL [`t * 11275`; `8388608`] ROUND_DIV) THEN + REWRITE_TAC[ARITH_RULE `~(8388608 = 0)`; ARITH_RULE `2 * 8388608 = 16777216`]);; + +(* ------------------------------------------------------------------------- *) +(* x86 lane lemmas (vpsrld $7 / vpmulhuw 11275 / vpmulhrsw 128). *) +(* ------------------------------------------------------------------------- *) + +(* Lane lemma: high 16 bits of the unsigned 16x16->32 multiply by 11275. *) +let LANE_MULHUW_88 = prove(`!u:16 word. 
val u < 65536 ==> + val(word_subword (word_mul ((word_zx u):int32) ((word_zx (word 11275:16 word)):int32)) (16,16):16 word) + = (val u * 11275) DIV 65536`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[VAL_WORD_SUBWORD; DIMINDEX_16] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(word_mul ((word_zx (u:16 word)):int32) ((word_zx (word 11275:16 word)):int32)) = val u * 11275` SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD_MUL; VAL_WORD_ZX_GEN; DIMINDEX_16; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(u:16 word) MOD 4294967296 = val u` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `u:16 word` VAL_BOUND) THEN REWRITE_TAC[DIMINDEX_16] THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word 11275:16 word) = 11275` SUBST1_TAC THENL [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `u:16 word` VAL_BOUND) THEN REWRITE_TAC[DIMINDEX_16] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `val(u:16 word) * 11275 < 65536 * 65536` MP_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(65536 = 0)`] THEN ARITH_TAC);; + +(* Lane lemma: the vpmulhrsw rounding-multiply lane by 128, for a 16-bit input + below 11265 (the range of the m16 = high16(t*11275) intermediate). *) +let LANE_MULHRSW_88 = prove(`!u:16 word. 
val u < 11265 ==> + val(word_subword (word_add (word_ushr (word_mul ((word_sx u):int32)((word_sx (word 128:16 word)):int32)) 14) (word 1:int32)) (1,16) :16 word) + = ((val u * 128) DIV 16384 + 1) DIV 2`, + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `val((word_sx (u:16 word)):int32) = val u` ASSUME_TAC THENL + [MATCH_MP_TAC VAL_WORD_SX_SMALL THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val((word_sx (word 128:16 word)):int32) = 128` ASSUME_TAC THENL + [SUBGOAL_THEN `val(word 128:16 word) = 128` ASSUME_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + ASM_SIMP_TAC[VAL_WORD_SX_SMALL; ARITH_RULE `128 < 32768`]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul ((word_sx (u:16 word)):int32)((word_sx (word 128:16 word)):int32)) = val u * 128` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_MUL] THEN ASM_REWRITE_TAC[] THEN + MATCH_MP_TAC MOD_LT THEN REWRITE_TAC[DIMINDEX_32] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[VAL_WORD_SUBWORD; DIMINDEX_16] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `val(word_add (word_ushr (word_mul ((word_sx (u:16 word)):int32)((word_sx (word 128:16 word)):int32)) 14) (word 1:int32)) = (val u * 128) DIV 16384 + 1` SUBST1_TAC THENL + [REWRITE_TAC[VAL_WORD_ADD; VAL_WORD_USHR; VAL_WORD; DIMINDEX_32] THEN ASM_REWRITE_TAC[] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `(val(u:16 word) * 128) DIV 16384 <= 88` MP_TAC THENL + [SUBGOAL_THEN `val(u:16 word) * 128 < 11265 * 128` MP_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(16384=0)`] THEN ARITH_TAC; ALL_TAC] THEN ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `(val(u:16 word) * 128) DIV 16384 + 1 <= 89` MP_TAC THENL + [SUBGOAL_THEN `(val(u:16 word) * 128) DIV 16384 <= 88` MP_TAC THENL + [SUBGOAL_THEN `val(u:16 word) * 128 < 11265 * 128` MP_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(16384=0)`] THEN ARITH_TAC; ALL_TAC] THEN ARITH_TAC; ALL_TAC] THEN + ARITH_TAC);; + +(* Composed 
per-lane Barrett a1: the full x86 decompose a1 lane (pre-shift t via + VAL_T, vpmulhuw by 11275 via LANE_MULHUW_88, vpmulhrsw by 128 via + LANE_MULHRSW_88) on a coefficient a:int32 with val a < Q equals the code-form + a1 Barrett value ((val a + 127) DIV 128 * 11275 + 8388608) DIV 16777216. *) +let A1_LANE_88 = prove(`!a:int32. val a < 8380417 ==> + val(word_subword + (word_add + (word_ushr + (word_mul + ((word_sx + (word_subword + (word_mul + ((word_zx + (word_subword + (word_ushr (word_add (word 127) a) 7) (0,16) + :16 word)):int32) + (word 11275:int32)) + (16,16) :16 word)):int32) + (word 128:int32)) + 14) + (word 1:int32)) + (1,16) :16 word) + = ((val a + 127) DIV 128 * 11275 + 8388608) DIV 16777216`, + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `(word 11275:int32) = word_zx (word 11275:16 word)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + SUBGOAL_THEN `(word 128:int32) = word_sx (word 128:16 word)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + ABBREV_TAC `t = (val(a:int32) + 127) DIV 128` THEN + SUBGOAL_THEN `t < 65473` ASSUME_TAC THENL + [EXPAND_TAC "t" THEN + MATCH_MP_TAC(ARITH_RULE `x <= 8380543 ==> x DIV 128 < 65473`) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + ABBREV_TAC + `u0 = word_subword + (word_ushr (word_add (word 127) (a:int32)) 7) (0,16) :16 word` THEN + SUBGOAL_THEN `val(u0:16 word) = t` ASSUME_TAC THENL + [EXPAND_TAC "u0" THEN + REWRITE_TAC[VAL_WORD_SUBWORD; DIMINDEX_16] THEN CONV_TAC NUM_REDUCE_CONV THEN + MP_TAC(SPEC `a:int32` VAL_T) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[ARITH_RULE `2 EXP 0 = 1`; DIV_1] THEN + MATCH_MP_TAC MOD_LT THEN EXPAND_TAC "t" THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(u0:16 word) < 65536` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `val(word_subword + (word_mul ((word_zx (u0:16 word)):int32) + ((word_zx (word 11275:16 word)):int32)) (16,16) :16 word) + = (t * 11275) DIV 65536` + ASSUME_TAC THENL + 
[MP_TAC(SPEC `u0:16 word` LANE_MULHUW_88) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ABBREV_TAC + `m16 = word_subword + (word_mul ((word_zx (u0:16 word)):int32) + ((word_zx (word 11275:16 word)):int32)) (16,16) :16 word` THEN + SUBGOAL_THEN `val(m16:16 word) = (t * 11275) DIV 65536` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `val(m16:16 word) < 11265` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + SIMP_TAC[RDIV_LT_EQ; ARITH_RULE `~(65536 = 0)`] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `m16:16 word` LANE_MULHRSW_88) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN + MATCH_ACCEPT_TAC A1_TWOSTEP_88);; + +(* ------------------------------------------------------------------------- *) +(* Decompose num lemmas (arch-independent). *) +(* ------------------------------------------------------------------------- *) + +let A1_BOUND_88 = prove( + `!a. a < 8380417 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 44`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8380416 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `751819508` (SPEC `d * 11275 + 8388608` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 11275 <= 65472 * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A1_BOUND_NOWRAP_88 = prove( + `!a. 
a <= 8285184 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 43`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8285184 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `738196808` (SPEC `d * 11275 + 8388608` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 11275 <= 64728 * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +(* Barrett equivalence for _88: 45-interval case analysis *) +let BARRETT_INTERVAL_88 = prove( + `!a lo hi k. + lo <= a /\ a <= hi /\ + k * 131072 <= (2 * lo * 1477838209) DIV 4294967296 + 65536 /\ + (2 * hi * 1477838209) DIV 4294967296 + 65536 < (k + 1) * 131072 /\ + k * 16777216 <= (lo + 127) DIV 128 * 11275 + 8388608 /\ + (hi + 127) DIV 128 * 11275 + 8388608 < (k + 1) * 16777216 + ==> ((2 * a * 1477838209) DIV 4294967296 + 65536) DIV 131072 = k /\ + ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + CONJ_TAC THEN MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THENL + [CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(2 * lo * 1477838209) DIV 4294967296 + 65536` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 65536 <= y + 65536 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(2 * hi * 1477838209) DIV 4294967296 + 65536` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 65536 <= y + 65536 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]; + CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(lo + 127) DIV 128 * 11275 + 8388608` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 8388608 <= y + 8388608 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(hi + 127) DIV 128 * 11275 
+ 8388608` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 8388608 <= y + 8388608 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]]);; + +let A1_WRAP_88 = prove( + `!a. 8285184 < a /\ a < 8380417 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 = 44`, + GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `44 <= ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216` + ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `a + 127` (SPEC `8285185 + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `d * 11275 + 8388608` (SPEC `738208083` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `64729 * 11275 <= d * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + MP_TAC(SPEC `a:num` A1_BOUND_88) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_ARITH_TAC);; + +let A0_UPPER_88 = prove( + `!a. 
a <= 8285184 + ==> a < (((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 + 1) * 190464`, + GEN_TAC THEN DISCH_TAC THEN + ABBREV_TAC `nv = ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216` THEN + SUBGOAL_THEN `nv * 16777216 <= (a + 127) DIV 128 * 11275 + 8388608` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN MP_TAC(SPECL [`(a + 127) DIV 128 * 11275 + 8388608`; `16777216`] (CONJUNCT1 DIVISION_SIMP)) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(a + 127) DIV 128 <= 64728` ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `8285184 + 127` (SPEC `a + 127` DIV_MONO))) THEN ANTS_TAC THENL [ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 16777216 <= 64728 * 11275 + 8388608` ASSUME_TAC THENL + [SUBGOAL_THEN `(a + 127) DIV 128 * 11275 <= 64728 * 11275` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ASM_ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 43` ASSUME_TAC THENL [CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_ARITH_TAC);; + +(* ------------------------------------------------------------------------- *) +(* x86 helper lemmas (BITBLAST), value form of the a1 lane, threshold/sign. *) +(* ------------------------------------------------------------------------- *) + +let A1_LANE_VAL_88 = prove( + `!a:int32. 
val a < 8380417 + ==> val(word_zx (word_subword + (word_add (word_ushr (word_mul (word_sx (word_subword + (word_mul (word_zx (word_subword (word_ushr (word_add (word 127) a) 7) (0,16) :16 word) :int32) + (word 11275:int32)) (16,16) :16 word) :int32) (word 128:int32)) 14) (word 1:int32)) (1,16) :16 word) :int32) + = ((val a + 127) DIV 128 * 11275 + 8388608) DIV 16777216`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[VAL_WORD_ZX_GEN; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN + `val(word_subword + (word_add (word_ushr (word_mul (word_sx (word_subword + (word_mul (word_zx (word_subword (word_ushr (word_add (word 127) (a:int32)) 7) (0,16) :16 word) :int32) + (word 11275:int32)) (16,16) :16 word) :int32) (word 128:int32)) 14) (word 1:int32)) (1,16) :16 word) + = ((val(a:int32) + 127) DIV 128 * 11275 + 8388608) DIV 16777216` + (fun th -> REWRITE_TAC[th]) THENL + [MATCH_MP_TAC A1_LANE_88 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(SPEC `val(a:int32)` A1_BOUND_88) THEN ASM_REWRITE_TAC[] THEN ARITH_TAC);; + +let WORD_IGT_THRESHOLD_88 = BITBLAST_RULE + `!a:int32. val a < 8380417 ==> (word_igt a (word 8285184:int32) <=> val a > 8285184)`;; + +let WRAP_A0_NEGATIVE_88 = BITBLAST_RULE + `!a:int32. val a < 8380417 /\ val a > 8285184 + ==> bit 31 (word_add (word_sub a (word 8380416:int32)) (word 4294967295:int32))`;; + +let WORD_SUB_SIGN_BARE_88 = BITBLAST_RULE + `!a:int32 b:int32. val b <= 8189952 /\ val a <= 8285184 ==> + ((bit 31 (word_sub a b) \/ word_sub a b = word 0) <=> val a <= val b)`;; + +(* Sign / threshold helpers for tmp = a1' + delta*h (delta in {+1,-1}). *) +let WORD_ILT_0_SMALL = BITBLAST_RULE + `!x:int32. val x <= 43 ==> ~(word_ilt x (word 0))`;; +let WORD_IGT_43_BOUND = BITBLAST_RULE + `!x:int32. val x <= 43 ==> ~(word_igt x (word 43:int32))`;; +let WORD_IGT_43_ADD1 = BITBLAST_RULE + `!x:int32. val x <= 42 ==> ~(word_igt (word_add x (word 1:int32)) (word 43:int32))`;; +let WORD_ILT_0_ADD1 = BITBLAST_RULE + `!x:int32. 
val x <= 42 ==> ~(word_ilt (word_add x (word 1:int32)) (word 0:int32))`;; +let WORD_IGT_43_SUB1 = BITBLAST_RULE + `!x:int32. val x <= 43 /\ ~(val x = 0) ==> + ~(word_igt (word_add x (word 4294967295:int32)) (word 43:int32))`;; +let WORD_ILT_0_SUB1 = BITBLAST_RULE + `!x:int32. val x <= 43 /\ ~(val x = 0) ==> + ~(word_ilt (word_add x (word 4294967295:int32)) (word 0:int32))`;; + +let WORD_AND_FULL_32 = BITBLAST_RULE `!x:int32. word_and x (word 4294967295) = x`;; + +(* Final-clamp leaf results, evaluated at a1 = word nv (nv <= 43). These are the + value of the vpblendvb/vpcmpgtd/vpandn clamp for the three reachable deltas: + h=0 (tmp = nv), h=1 with a0>0 (tmp = nv+1), h=1 with a0<=0 (tmp = nv-1, or -1 + wrapping to 43 at nv=0). Types are pinned to int32 throughout to avoid type + invention when these are MATCH_MP_TAC'd into the element proof. *) +let TMP_RESULT_H0 = prove( + `!nv:num. nv <= 43 ==> + val(word_and + (if word_ilt (word nv:int32) (word 0:int32) then (word 43:int32) else (word nv:int32)) + (word_not (word_neg (word (bitval (word_igt + (if word_ilt (word nv:int32) (word 0:int32) then (word 43:int32) else (word nv:int32)) + (word 43:int32))):int32)))) + = nv`, + GEN_TAC THEN DISCH_TAC THEN + SUBGOAL_THEN `~(word_ilt (word nv:int32) (word 0:int32))` ASSUME_TAC THENL + [MATCH_MP_TAC WORD_ILT_0_SMALL THEN REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ASM_REWRITE_TAC[]]; ALL_TAC] THEN + SUBGOAL_THEN `~(word_igt (word nv:int32) (word 43:int32))` ASSUME_TAC THENL + [MATCH_MP_TAC WORD_IGT_43_BOUND THEN REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ASM_REWRITE_TAC[]]; ALL_TAC] THEN + ASM_REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV THEN + 
REWRITE_TAC[WORD_AND_FULL_32] THEN + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC);; + +let TMP_RESULT_POS = prove( + `!nv:num. nv <= 43 ==> + val(word_and + (if word_ilt (word_add (word nv:int32) (word 1:int32)) (word 0:int32) + then (word 43:int32) else word_add (word nv:int32) (word 1:int32)) + (word_not (word_neg (word (bitval (word_igt + (if word_ilt (word_add (word nv:int32) (word 1:int32)) (word 0:int32) + then (word 43:int32) else word_add (word nv:int32) (word 1:int32)) + (word 43:int32))):int32)))) + = (if nv = 43 then 0 else nv + 1)`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `nv = 43` THEN ASM_REWRITE_TAC[] THENL + [CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[BITVAL_CLAUSES] THEN + CONV_TAC WORD_REDUCE_CONV THEN CONV_TAC NUM_REDUCE_CONV + ; + SUBGOAL_THEN `val(word nv:int32) <= 42` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + UNDISCH_TAC `nv <= 43` THEN UNDISCH_TAC `~(nv = 43)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `word nv:int32` WORD_ILT_0_ADD1) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(SPEC `word nv:int32` WORD_IGT_43_ADD1) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[WORD_AND_FULL_32] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC]);; + +let TMP_RESULT_NEG = prove( + `!nv:num. 
nv <= 43 ==> + val(word_and + (if word_ilt (word_add (word nv:int32) (word 4294967295:int32)) (word 0:int32) + then (word 43:int32) else word_add (word nv:int32) (word 4294967295:int32)) + (word_not (word_neg (word (bitval (word_igt + (if word_ilt (word_add (word nv:int32) (word 4294967295:int32)) (word 0:int32) + then (word 43:int32) else word_add (word nv:int32) (word 4294967295:int32)) + (word 43:int32))):int32)))) + = (if nv = 0 then 43 else nv - 1)`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `nv = 0` THEN ASM_REWRITE_TAC[] THENL + [CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[BITVAL_CLAUSES] THEN + CONV_TAC WORD_REDUCE_CONV THEN CONV_TAC NUM_REDUCE_CONV + ; + SUBGOAL_THEN `val(word nv:int32) <= 43 /\ ~(val(word nv:int32) = 0)` STRIP_ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `word nv:int32` WORD_ILT_0_SUB1) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(SPEC `word nv:int32` WORD_IGT_43_SUB1) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[WORD_AND_FULL_32] THEN + REWRITE_TAC[VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL + [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `nv + 4294967295 = (nv - 1) + 1 * 4294967296` SUBST1_TAC THENL + [UNDISCH_TAC `~(nv = 0)` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC]);; + +(* h = 1 leaf for general delta sign b = (a0 <= 0): the vpblendvb/clamp value at + a1 = word nv with delta = word_or(word_neg(word(bitval b)))(word 1). 
*) +let H1_LEAF = prove( + `!nv:num b:bool. nv <= 43 ==> + val(word_and + (if word_ilt (word_add (word nv:int32) (word_or (word_neg (word (bitval b):int32)) (word 1:int32))) (word 0:int32) + then (word 43:int32) + else word_add (word nv:int32) (word_or (word_neg (word (bitval b):int32)) (word 1:int32))) + (word_not (word_neg (word (bitval (word_igt + (if word_ilt (word_add (word nv:int32) (word_or (word_neg (word (bitval b):int32)) (word 1:int32))) (word 0:int32) + then (word 43:int32) + else word_add (word nv:int32) (word_or (word_neg (word (bitval b):int32)) (word 1:int32))) + (word 43:int32))):int32)))) + = (if b then (if nv = 0 then 43 else nv - 1) else (if nv = 43 then 0 else nv + 1))`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + BOOL_CASES_TAC `b:bool` THEN REWRITE_TAC[BITVAL_CLAUSES] THEN + CONV_TAC WORD_REDUCE_CONV THENL + [MATCH_MP_TAC TMP_RESULT_NEG THEN ASM_REWRITE_TAC[]; + MATCH_MP_TAC TMP_RESULT_POS THEN ASM_REWRITE_TAC[]]);; + +(* Reconcile the H1_LEAF result (keyed on b = a0 <= 0) with the code RHS + (keyed on a0 > 0): the two if-forms are equal since the conditions are + complementary and the branches are swapped. *) +let H1_RHS = prove( + `!x:int (A:num) B. (if ~(&0 < x) then A else B) = (if &0 < x then B else A)`, + REPEAT GEN_TAC THEN ASM_CASES_TAC `&0 < x:int` THEN ASM_REWRITE_TAC[]);; + +(* ------------------------------------------------------------------------- *) +(* The per-element x86 word model of UseHint (param 44), matching the scalar *) +(* form one coefficient lane computes. Structure mirrors *) +(* mldsa_use_hint_32_x86_asm but with the 88 constants: *) +(* t = (a + 127) >>u 7 (vpaddd 127, vpsrld 7) *) +(* m16 = high16(t * 11275) (vpmulhuw) *) +(* a1lane = ((m16 * 128) >> 14 + 1) >> 1 [bits (1,16)] (vpmulhrsw) *) +(* a1 = zx a1lane *) +(* m = (a > 8285184 ? -1 : 0) (vpcmpgtd) *) +(* a1' = a1 & ~m (vpandn) *) +(* a0 = a - a1*190464 + m (190464 via the 3*/31*/2048 shift chain) *) +(* delta = (a0 <= 0 ? 
-1 : +1) *) +(* tmp = a1' + delta*h *) +(* tmp2 = (tmp < 0 ? 43 : tmp) (vpblendvb against 43; only tmp=-1 *) +(* reaches the negative selection) *) +(* result = (tmp2 > 43 ? 0 : tmp2) (vpcmpgtd 43, vpandn) *) +(* ------------------------------------------------------------------------- *) +let mldsa_use_hint_88_x86_asm = new_definition + `mldsa_use_hint_88_x86_asm (a:int32) (h:int32) : int32 = + let t = word_ushr (word_add (word 127) a) 7 in + let m16 = word_subword + (word_mul (word_zx (word_subword t (0,16) :16 word) :int32) + (word 11275)) (16,16) :16 word in + let a1lane = word_subword + (word_add + (word_ushr (word_mul (word_sx m16 :int32) (word 128)) 14) + (word 1)) (1,16) :16 word in + let a1 = word_zx a1lane :int32 in + let m:int32 = (if word_igt a (word 8285184) then word 4294967295 + else word 0) in + let a1' = word_and a1 (word_not m) in + let a0 = word_add (word_sub a (word_mul a1 (word 190464))) m in + let delta:int32 = word_or (word_neg(word(bitval(word_ile a0 (word 0))))) + (word 1) in + let tmp = word_add a1' (word_mul delta h) in + let tmp2:int32 = if word_ilt tmp (word 0) then word 43 else tmp in + let neg_mask:int32 = word_neg(word(bitval(word_igt tmp2 (word 43)))) in + word_and tmp2 (word_not neg_mask)`;; + +(* ------------------------------------------------------------------------- *) +(* Element correctness (value form): the per-lane x86 word model equals the *) +(* scalar code form for one coefficient, for val a < Q and val h <= 1. *) +(* Case split is wrap (a > 8285184 => a1 lane = 44, masked to 0) vs no-wrap; *) +(* the no-wrap case further splits on h and on the a0-sign delta. *) +(* ------------------------------------------------------------------------- *) +let ELEMENT_CORRECT_88 = prove( + `!a:int32 h:int32. 
+ val a < 8380417 /\ val h <= 1 + ==> val(mldsa_use_hint_88_x86_asm a h) = mldsa_use_hint_88_code (val a) (val h)`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[mldsa_use_hint_88_x86_asm; mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ABBREV_TAC `nv = ((val(a:int32) + 127) DIV 128 * 11275 + 8388608) DIV 16777216` THEN + ABBREV_TAC + `a1w:int32 = + word_zx (word_subword + (word_add (word_ushr (word_mul (word_sx (word_subword + (word_mul (word_zx (word_subword + (word_ushr (word_add (word 127) (a:int32)) 7) (0,16) :16 word) :int32) + (word 11275)) (16,16) :16 word) :int32) + (word 128)) 14) (word 1)) (1,16) :16 word) :int32` THEN + SUBGOAL_THEN `val(a1w:int32) = nv` ASSUME_TAC THENL + [EXPAND_TAC "a1w" THEN EXPAND_TAC "nv" THEN + MATCH_MP_TAC A1_LANE_VAL_88 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 44` ASSUME_TAC THENL + [MP_TAC(SPEC `val(a:int32)` A1_BOUND_88) THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_igt (a:int32) (word 8285184:int32) <=> val a > 8285184` + SUBST1_TAC THENL + [MP_TAC(SPEC `a:int32` WORD_IGT_THRESHOLD_88) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(a:int32) > 8285184` THEN ASM_REWRITE_TAC[] THENL + [ + (* WRAP ZONE: nv = 44, a1 lane masked to 0. 
*) + SUBGOAL_THEN `nv = 44` SUBST_ALL_TAC THENL + [MP_TAC(SPEC `val(a:int32)` A1_WRAP_88) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `a1w:int32 = word 44` ASSUME_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ] THEN ASM_REWRITE_TAC[] THEN CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + FIRST_X_ASSUM(fun th -> if concl th = `a1w:int32 = word 44` then SUBST_ALL_TAC th else failwith "no") THEN + CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `word_and (word 44:int32) (word_not (word 4294967295)) = word 0` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + SUBGOAL_THEN `word_mul (word 44:int32) (word 190464) = word 8380416` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + SUBGOAL_THEN `word_ile (word_add (word_sub (a:int32) (word 8380416)) (word 4294967295)) (word 0)` ASSUME_TAC THENL + [REWRITE_TAC[WORD_ILE_ZERO_32] THEN DISJ1_TAC THEN + MP_TAC(SPEC `a:int32` WRAP_A0_NEGATIVE_88) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV; + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + SUBGOAL_THEN `(&(val(a:int32)):int) > &4190208` ASSUME_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_GT; INT_GT](ASSUME `val(a:int32) > 8285184`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `~((&(val(a:int32)) - &8380417:int) > &0)` ASSUME_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT](ASSUME `val(a:int32) < 8380417`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[]] + ; + (* NO-WRAP ZONE: mask = 0, a1' = a1 lane = nv <= 43. 
*) + SUBGOAL_THEN `nv <= 43` ASSUME_TAC THENL + [MP_TAC(SPEC `val(a:int32)` A1_BOUND_NOWRAP_88) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `(if nv > 43 then 0 else nv) = nv` SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_NOT_0; WORD_AND_REFL; WORD_ADD_0] THEN + SUBGOAL_THEN `val(word_mul (a1w:int32) (word 190464:int32)) = nv * 190464` ASSUME_TAC THENL + [REWRITE_TAC[VAL_WORD_MUL; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `val(word 190464:int32) = 190464` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; ALL_TAC] THEN + MATCH_MP_TAC MOD_LT THEN + SUBGOAL_THEN `nv * 190464 <= 43 * 190464` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 190464 <= 8189952` ASSUME_TAC THENL + [SUBGOAL_THEN `nv * 190464 <= 43 * 190464` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (a1w:int32) (word 190464:int32)) <= 8189952` ASSUME_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN + `word_ile (word_sub (a:int32) (word_mul a1w (word 190464:int32))) (word 0) + <=> ~(&(val a) - &nv * &190464 > (&0:int))` SUBST1_TAC THENL + [REWRITE_TAC[WORD_ILE_ZERO_32] THEN + SUBGOAL_THEN + `(bit 31 (word_sub (a:int32) (word_mul a1w (word 190464:int32))) \/ + word_sub a (word_mul a1w (word 190464:int32)) = word 0) + <=> val a <= val(word_mul a1w (word 190464:int32))` + SUBST1_TAC THENL + [MATCH_MP_TAC WORD_SUB_SIGN_BARE_88 THEN ASM_REWRITE_TAC[] THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 190464` THENL + [ASM_REWRITE_TAC[] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] + (REWRITE_RULE[GSYM INT_OF_NUM_LE] + (ASSUME `val(a:int32) <= nv * 190464`))) THEN INT_ARITH_TAC; + ASM_REWRITE_TAC[] THEN + 
SUBGOAL_THEN `nv * 190464 < val(a:int32)` ASSUME_TAC THENL + [UNDISCH_TAC `~(val(a:int32) <= nv * 190464)` THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] + (REWRITE_RULE[GSYM INT_OF_NUM_LT] + (ASSUME `nv * 190464 < val(a:int32)`))) THEN INT_ARITH_TAC]; + ALL_TAC] THEN + SUBGOAL_THEN `~(int_gt (&(val(a:int32)) - &nv * &190464) (&4190208))` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN + MP_TAC(SPEC `val(a:int32)` A0_UPPER_88) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL; + GSYM INT_OF_NUM_ADD] (ASSUME `val(a:int32) < (nv + 1) * 190464`)) THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_GT] THEN + RULE_ASSUM_TAC(REWRITE_RULE[INT_GT]) THEN ASM_REWRITE_TAC[] THEN + (* Replace a1w by word nv everywhere (val a1w = nv, nv <= 43). *) + SUBGOAL_THEN `a1w:int32 = word nv` SUBST_ALL_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ] THEN ASM_REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC SYM_CONV THEN MATCH_MP_TAC MOD_LT THEN + ASM_ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [ + (* h = 0: tmp = a1' = nv, no correction; leaf via TMP_RESULT_H0 *) + SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_0; WORD_ADD_0] THEN + MATCH_MP_TAC TMP_RESULT_H0 THEN ASM_REWRITE_TAC[] + ; + (* h = 1: tmp = nv + delta, delta = word_or(word_neg(word(bitval(a0<=0))))(word 1). + The whole vpblendvb/clamp value is given by H1_LEAF with + b = ~(&(val a) - &nv * &190464 > &0) (which equals a0 <= 0); the code RHS + then matches by case analysis on b. 
*) + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL + [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_1_32] THEN + MP_TAC(SPECL [`nv:num`; `~((&0:int) < &(val(a:int32)) - &nv * &190464)`] H1_LEAF) THEN + ANTS_TAC THENL + [UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; + DISCH_THEN(fun th -> REWRITE_TAC[th]) THEN MATCH_ACCEPT_TAC H1_RHS]]]);; + +(* Word form, directly usable in the loop body discharge. *) +let ELEMENT_CORRECT_WORD_88 = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> mldsa_use_hint_88_x86_asm a h = + word(mldsa_use_hint_88_code (val a) (val h))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN + AP_TERM_TAC THEN MP_TAC(SPECL [`a:int32`; `h:int32`] ELEMENT_CORRECT_88) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]));; + +(* ========================================================================= *) +(* FIPS 204 = code-aligned equivalence (arch-independent). *) +(* The decompose r1/r0 lemmas and tactics below are pure num/int. *) +(* ========================================================================= *) + +(* The per-variant divisor is 2*GAMMA2 = 190464. The generic Barrett DIV/MOD + tactics are shared from mldsa_utils.ml. 
*) +let LINEARIZE_DIV_MOD_TAC_88 = LINEARIZE_DIV_MOD_BY_TAC 190464;; +let DIV_190464_TAC k = DIV_EQ_K_BY_TAC 190464 k;; +let DIV_MOD_TO_DIV_TAC_88 = DIV_MOD_TO_DIV_BY_TAC 190464;; + +let DECOMPOSE_R1_LOWER_TAC_88 = + SUBGOAL_THEN `~((&r:int) - &(r MOD 190464) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN LINEARIZE_DIV_MOD_TAC_88; + ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + DIV_MOD_TO_DIV_TAC_88 THEN + CONV_TAC SYM_CONV THEN + LINEARIZE_DIV_MOD_TAC_88;; + +let DECOMPOSE_R1_UPPER_TAC_88 = + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~((&r:int) - (&(r MOD 190464) - &190464) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [REWRITE_TAC[INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`] THEN + ASM_SIMP_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + LINEARIZE_DIV_MOD_TAC_88; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 + 190464 = 190464 * (r DIV 190464 + 1)` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464 + 1`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN SUBST1_TAC THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, `q:num`) THEN + REPEAT GEN_TAC THEN 
ASM_ARITH_TAC;; + +let DECOMPOSE_R1_NOWRAP_TAC_88 = + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THEN + TRY DECOMPOSE_R1_LOWER_TAC_88 THEN TRY DECOMPOSE_R1_UPPER_TAC_88;; + +let DECOMPOSE_88_R1_EQUIV = time prove( + `!r. r < 8380417 ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = + decompose_88_r1 r`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_CASES_TAC `r <= 8285184` THENL + [ALL_TAC; + (* Wrap zone *) + SUBGOAL_THEN `8285184 < r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `decompose_88_r1 r = 0` SUBST1_TAC THENL + [REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THENL + [(* Lower wrap: r DIV 190464 = 44 *) + SUBGOAL_THEN `r DIV 190464 = 44` ASSUME_TAC THENL + [DIV_190464_TAC 44; ALL_TAC] THEN + SUBGOAL_THEN `44 * 190464 + r MOD 190464 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + ASM_ARITH_TAC; + (* Upper wrap: r DIV 190464 = 43 *) + SUBGOAL_THEN `r DIV 190464 = 43` ASSUME_TAC THENL + [DIV_190464_TAC 43; ALL_TAC] THEN + SUBGOAL_THEN `43 * 190464 + r MOD 190464 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + ALL_TAC] 
THEN + MP_TAC(SPEC `r:num` A1_WRAP_88) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONV_TAC NUM_REDUCE_CONV] THEN + (* Nowrap zone: Barrett <= 43, so if > 43 then 0 else Barrett = Barrett *) + SUBGOAL_THEN `((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 43` + ASSUME_TAC THENL + [MATCH_MP_TAC A1_BOUND_NOWRAP_88 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 > 43)` + (fun th -> REWRITE_TAC[th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + (* Nowrap zone: unfold and do interval cascade *) + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + let intervals = [ + (0, 95232); (95233, 285696); (285697, 476160); (476161, 666624); + (666625, 857088); (857089, 1047552); (1047553, 1238016); + (1238017, 1428480); (1428481, 1618944); (1618945, 1809408); + (1809409, 1999872); (1999873, 2190336); (2190337, 2380800); + (2380801, 2571264); (2571265, 2761728); (2761729, 2952192); + (2952193, 3142656); (3142657, 3333120); (3333121, 3523584); + (3523585, 3714048); (3714049, 3904512); (3904513, 4094976); + (4094977, 4285440); (4285441, 4475904); (4475905, 4666368); + (4666369, 4856832); (4856833, 5047296); (5047297, 5237760); + (5237761, 5428224); (5428225, 5618688); (5618689, 5809152); + (5809153, 5999616); (5999617, 6190080); (6190081, 6380544); + (6380545, 6571008); (6571009, 6761472); (6761473, 6951936); + (6951937, 7142400); (7142401, 7332864); (7332865, 7523328); + (7523329, 7713792); (7713793, 7904256); (7904257, 8094720); + (8094721, 8285184)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("r",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`r:num`; 
mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_88 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + DECOMPOSE_R1_NOWRAP_TAC_88 in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + +let R1_IS_DIV_LOWER_88 = prove( + `!r. r < 8380417 /\ r MOD 190464 * 2 <= 190464 /\ + ~((&r:int) - &(r MOD 190464) = &8380416) ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = r DIV 190464`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` LOWER_NONWRAP_R1_88) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +let R1_IS_DIV_PLUS1_UPPER_88 = prove( + `!r. 
r < 8380417 /\ ~(r MOD 190464 * 2 <= 190464) /\ + ~((&r:int) - (&(r MOD 190464) - &190464) = &8380416) ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = r DIV 190464 + 1`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` UPPER_NONWRAP_R1_88) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +let R0_SIGN_UPPER_NOWRAP_TAC_88 = + MP_TAC(SPEC `r:num` R1_IS_DIV_PLUS1_UPPER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_MUL] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `(a:int) - (b + &1) * c = a - b * c - c`] THEN + REWRITE_TAC[INT_ARITH `x - &190464 > &0 <=> x > &190464`; + INT_ARITH `x - &190464 - &8380417 > &0 <=> x > &8570881`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let R0_SIGN_LOWER_NOWRAP_TAC_88 = + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let R0_SIGN_WRAP_TAC_88 = + SUBGOAL_THEN `8285184 < r` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> t = `&8380416:int`)) (concl th) && + not(is_neg(concl th)))) THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ; + INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`; + GSYM INT_OF_NUM_ADD] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV 
let_CONV) THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + REWRITE_TAC[INT_ARITH `x - &1 > &0 <=> x > &1`; + INT_ARITH `(x - &190464) - &1 > &0 <=> x > &190465`; + INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let DECOMPOSE_88_R0_SIGN = time prove( + `!r. r < 8380417 ==> + let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + let a1 = if a1_raw > 43 then 0 else a1_raw in + let a0':int = if (&r:int) - &a1 * &190464 > &4190208 + then &r - &a1 * &190464 - &8380417 + else &r - &a1 * &190464 in + (decompose_88_r0 r > &0 <=> a0' > &0) /\ + (decompose_88_r0 r <= &0 <=> ~(a0' > &0))`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[INT_ARITH `(x:int) <= &0 <=> ~(x > &0)`] THEN + MATCH_MP_TAC(TAUT `(p <=> q) ==> (p <=> q) /\ (~p <=> ~q)`) THEN + REWRITE_TAC[decompose_88_r0; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ONCE_REWRITE_TAC[COND_RAND] THEN REWRITE_TAC[SND; FST] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + TRY R0_SIGN_LOWER_NOWRAP_TAC_88 THEN + TRY R0_SIGN_UPPER_NOWRAP_TAC_88 THEN + TRY R0_SIGN_WRAP_TAC_88 THEN + TRY( + (* Contradiction: lower nowrap with > 4190208 *) + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - &((let 
a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw)) * &190464 = &(r MOD 190464)` ASSUME_TAC THENL + [CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(&(r MOD 190464) > (&4190208:int))` MP_TAC THENL + [REWRITE_TAC[INT_NOT_LT; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[INT_OF_NUM_GT] THEN ASM_ARITH_TAC + ));; + +let MLDSA_USE_HINT_88_EQUIV = prove( + `!r h. r < 8380417 /\ h <= 1 + ==> mldsa_use_hint_88 h r = mldsa_use_hint_88_code r h`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[MLDSA_USE_HINT_88_UNFOLD] THEN + REWRITE_TAC[mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + DISCH_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R0_SIGN) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN STRIP_TAC THEN + ASM_CASES_TAC `h = 0` THENL + [ASM_REWRITE_TAC[ARITH_RULE `~(0 = 1)`]; ALL_TAC] THEN + SUBGOAL_THEN `h = 1` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `decompose_88_r0 r > &0` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[]);; +(* ========================================================================= *) +(* 88 framework + store-discharge + correctness chain. *) +(* Loaded on top of the 88 program defs and the per-element scalar chain. *) +(* The YMM constant map is {ymm8=11275, ymm4=127, ymm7=128, ymm6=43, *) +(* ymm3=8285184, ymm5=0}; the a1<=43 reduction and the VPBLENDVB clamp are *) +(* handled by UH88_STORE_DISCHARGE_TAC. *) +(* ========================================================================= *) + +(* a1 * 2*GAMMA2 = a1*190464 built by the (3a<<5 - 3a)<<11 shift chain. *) +let SHL_190464 = BITBLAST_RULE + `!a:int32. 
word_shl (word_sub (word_shl (word_add (word_shl a 1) a) 5) (word_add (word_shl a 1) a)) 11 = + word_mul a (word 190464)`;; + +(* The (16,16) high half of the VPMULHRSW lane uses multiplier (word 0), so it + contributes nothing: the high a1 lane is zero. *) +let A1HI_ZERO_88 = BITBLAST_RULE + `word_subword + (word_add + (word_ushr + (word_mul + (word_sx (word_subword + (word_mul (word_zx (word_subword + (word_ushr (word_add (word 127) (x:int32)) 7) (16,16) :16 word) :int32) + (word 0)) (16,16) :16 word) :int32) + (word 0)) 14) + (word 1)) (1,16) :16 word = word 0`;; + +(* The final >43 clamp in the andn form equals the model's neg-mask form. *) +let TAIL_CLOSE_88 = prove + (`!z:int32. word_and (word_not (if word_igt z (word 43) then word 4294967295 else word 0)) z = + word_and z (word_not (word_neg (word (bitval (word_igt z (word 43))))))`, + GEN_TAC THEN COND_CASES_TAC THEN ASM_REWRITE_TAC[BITVAL_CLAUSES] THEN CONV_TAC WORD_BLAST);; + +(* Commutativity closer used after the surface rewrites line up the lane. *) +let LANE_AC_CLOSE_88 = prove + (`!d y m a:int32. word_add (word_mul d y) (word_and (word_not m) a) = + word_add (word_and a (word_not m)) (word_mul d y)`, + REPEAT GEN_TAC THEN CONV_TAC WORD_BLAST);; + +let WORD_ILT_45 = BITBLAST_RULE `!V:int32. val V <= 45 ==> ~(word_ilt V (word 0))`;; + +(* word_subword distributes through word_not on each 32-bit lane. *) +let UH88_WSN = map (fun n -> prove( + subst [mk_small_numeral n, `n:num`] + `!z:int256. word_subword(word_not z) (n,32):int32 = word_not(word_subword z (n,32))`, + GEN_TAC THEN MATCH_MP_TAC WORD_SUBWORD_NOT THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_256] THEN ARITH_TAC)) [0;32;64;96;128;160;192;224];; + +(* VPBLENDVB lane collapse: the post-blend byte-mux of a 32-bit lane (43 if the + lane is negative, else the lane itself) collapses to the clamp-against-43, + for a lane value below 46 or equal to -1. The four bytes of the blend select + 43 / 0 by the lane's per-byte sign bits. 
*) +let BLEND_COLLAPSE_LE45 = BITBLAST_RULE + `val(V:int32) <= 45 + ==> word (val (if bit 7 V then word 43 else word_subword V (0,8):byte) * 1 + + val (if bit 15 V then word 0 else word_subword V (8,8):byte) * 256 + + val (if bit 23 V then word 0 else word_subword V (16,8):byte) * 65536 + + val (if bit 31 V then word 0 else word_subword V (24,8):byte) * 16777216) = V`;; + +let BLEND_COLLAPSE = prove + (`val(V:int32) <= 45 \/ V = word 4294967295 + ==> word (val (if bit 7 V then word 43 else word_subword V (0,8):byte) * 1 + + val (if bit 15 V then word 0 else word_subword V (8,8):byte) * 256 + + val (if bit 23 V then word 0 else word_subword V (16,8):byte) * 65536 + + val (if bit 31 V then word 0 else word_subword V (24,8):byte) * 16777216) = + (if word_ilt V (word 0) then word 43 else V)`, + STRIP_TAC THENL + [MP_TAC(SPEC `V:int32` (GEN `V:int32` BLEND_COLLAPSE_LE45)) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN ASM_SIMP_TAC[WORD_ILT_45]; + ASM_REWRITE_TAC[] THEN + CONV_TAC(BINOP_CONV(TRY_CONV(LAND_CONV NUM_REDUCE_CONV))) THEN + CONV_TAC WORD_REDUCE_CONV THEN CONV_TAC NUM_REDUCE_CONV]);; +let BLEND_COLLAPSE_GEN = GEN `V:int32` BLEND_COLLAPSE;; + +(* ------------------------------------------------------------------------- *) +(* The post-VPBLENDVB store discharge. The store goal has the form *) +(* word_and (word_not (word_join_of_8 (if word_igt (word_subword B (32k,32)) *) +(* (word 43) then -1 else 0))) B = simd8 mldsa_use_hint_88_x86_asm av hv *) +(* where B (the post-blend abbreviation, word(SUM val(byte_i)*256^i)) is *) +(* defined by a chain of abbreviation assumptions feeding from PRE (the *) +(* pre-blend value). The tactic distributes the equality to the 32-bit *) +(* lanes, bridges each word_subword B (32k,32) to the clamp form via byte *) +(* extraction + BLEND_COLLAPSE, then collapses the per-lane chain to lane *) +(* model and discharges via the LUH/RANGE lane lemmas built from the goal. 
*) +(* ------------------------------------------------------------------------- *) + +(* Per-lane byte-MOD extraction over 32 abstract byte values, lane k. *) +let UH88_BYTEMOD_LANE = + let bvars = map (fun i -> mk_var("b"^string_of_int i, `:num`)) (0--31) in + let pow256 i = mk_numeral(Num.power_num (Num.num_of_int 256) (Num.num_of_int i)) in + let mkadd l = end_itlist (fun a b -> mk_binop `(+):num->num->num` a b) l in + let mkmul a b = mk_binop `( * ):num->num->num` a b in + let bigsum_v = mkadd (map (fun i -> mkmul (el i bvars) (pow256 i)) (0--31)) in + let lane0_v = mkadd [mkmul (el 0 bvars) `1`; mkmul (el 1 bvars) `256`; + mkmul (el 2 bvars) `65536`; mkmul (el 3 bvars) `16777216`] in + let bounds = end_itlist (fun a b -> mk_conj(a,b)) (map (fun b -> mk_binop `(<):num->num->bool` b `256`) bvars) in + let regroup_th = ARITH_RULE (mk_eq(bigsum_v, mk_binop `(+):num->num->num` lane0_v + (mkmul `4294967296` (mkadd (map (fun i -> mkmul (el i bvars) (pow256 (i-4))) (4--31)))))) in + fun k -> + if k=0 then + prove(mk_imp(bounds, mk_eq(mk_binop `MOD` (mk_binop `DIV` bigsum_v `2 EXP 0`) `2 EXP 32`, lane0_v)), + STRIP_TAC THEN CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + REWRITE_TAC[DIV_1] THEN ONCE_REWRITE_TAC[regroup_th] THEN + REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC) + else + let off = 4*k in + let lower = mkadd (map (fun i -> mkmul (el i bvars) (pow256 i)) (0--(off-1))) in + let upper = mkadd (map (fun i -> mkmul (el i bvars) (pow256 (i-off))) (off--31)) in + let lane = mkadd [mkmul (el off bvars) `1`; mkmul (el (off+1) bvars) `256`; + mkmul (el (off+2) bvars) `65536`; mkmul (el (off+3) bvars) `16777216`] in + let p256_4k = pow256 off in + let split_big = mk_eq(bigsum_v, mk_binop `(+):num->num->num` lower (mkmul p256_4k upper)) in + let conc = mk_eq(mk_binop `MOD` (mk_binop `DIV` bigsum_v (mk_binop `EXP` `2` (mk_small_numeral(32*k)))) `2 EXP 32`, lane) in + let div_prefix = + STRIP_TAC THEN CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) 
THEN + GEN_REWRITE_TAC (LAND_CONV o LAND_CONV o LAND_CONV) [ARITH_RULE split_big] THEN + SUBGOAL_THEN (mk_binop `(<):num->num->bool` lower p256_4k) ASSUME_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_SIMP_TAC[DIV_MULT_ADD; ARITH_RULE (mk_neg(mk_eq(p256_4k,`0`)))] THEN + ASM_SIMP_TAC[DIV_LT] THEN REWRITE_TAC[ADD_CLAUSES] in + if off+4 > 31 then + prove(mk_imp(bounds,conc), div_prefix THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC) + else + let uppertail = mkadd (map (fun i -> mkmul (el i bvars) (pow256 (i-off-4))) ((off+4)--31)) in + let split_up = mk_eq(upper, mk_binop `(+):num->num->num` lane (mkmul `4294967296` uppertail)) in + prove(mk_imp(bounds,conc), + div_prefix THEN GEN_REWRITE_TAC (LAND_CONV o LAND_CONV) [ARITH_RULE split_up] THEN + REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC);; + +let UH88_BYTEMOD_LANES = map UH88_BYTEMOD_LANE (0--7);; + +(* Relabel a single lane byte of PRE to the per-lane subword V form. *) +let uh88_byte_relabel pre k m c = + let tmpl = `(if bit pbit (pp:int256) then (cc:byte) else word_subword pp (ppos,8):byte) = + (if bit vbit (word_subword pp (lane,32):int32) then (cc:byte) + else word_subword (word_subword pp (lane,32):int32) (vpos,8):byte)` in + let inst = subst [ mk_small_numeral(32*k+8*m+7), `pbit:num`; + mk_small_numeral(8*m+7), `vbit:num`; + mk_small_numeral(32*k+8*m), `ppos:num`; + mk_small_numeral(8*m), `vpos:num`; + mk_small_numeral(32*k), `lane:num`; + c, `cc:byte`; pre, `pp:int256` ] tmpl in + BITBLAST_RULE inst;; + +let UH88_STORE_DISCHARGE_TAC : tactic = + fun (asl,w) -> + if not(is_eq w) || not(can(find_term(fun t->try fst(dest_const(fst(strip_comb t)))="simd8" with _->false)) w) + then failwith "not store goal" else + let isu t = is_var t && (let v=fst(dest_var t) in String.length v>0 && v.[0]='_') in + let chaindefs = List.filter_map (fun (_,th)-> let c=concl th in if is_eq c && isu(rand c) then Some th else None) asl in + let chain_syms = map SYM chaindefs in + let collapse = TOP_DEPTH_CONV 
(WORD_SIMPLE_SUBWORD_CONV ORELSEC + GEN_REWRITE_CONV I [WORD_SUBWORD_AND] ORELSEC + FIRST_CONV (map (fun th -> GEN_REWRITE_CONV I [th]) UH88_WSN)) in + let onelevel = GEN_REWRITE_CONV ONCE_DEPTH_CONV chain_syms THENC collapse in + let nchain = List.length chaindefs in + let lane_collapse t = let c0 = collapse t in + let rec iter th n = if n=0 then th else iter (TRANS th (onelevel(rand(concl th)))) (n-1) in + iter c0 (nchain+2) in + let bdef = List.find (fun th -> let l=lhs(concl th) in (try fst(dest_const(fst(strip_comb l)))="word" with _->false)) chaindefs in + let bvar = rand(concl bdef) in + let bytesum = rand(lhs(concl bdef)) in + let prevar = find_term (fun t -> isu t && type_of t=`:int256` && not(t=bvar)) bytesum in + let bound_hyp = snd(List.find (fun (_,th)-> try concl th = `!k. k < 8 ==> val(word_subword (hv:int256) (32*k,32):int32) <= 1` with _->false) asl) in + let bvals = map lhand (striplist (dest_binop `(+):num->num->num`) bytesum) in + let xk k = mk_comb(mk_comb(`word_subword:int256->num#num->int32`,`av:int256`),mk_pair(mk_small_numeral(32*k),`32`)) in + let yk k = mk_comb(mk_comb(`word_subword:int256->num#num->int32`,`hv:int256`),mk_pair(mk_small_numeral(32*k),`32`)) in + let bndk k = CONV_RULE(ONCE_DEPTH_CONV NUM_REDUCE_CONV) + (MP (SPEC (mk_small_numeral k) bound_hyp) (EQT_ELIM(NUM_REDUCE_CONV(mk_binop `(<):num->num->bool` (mk_small_numeral k) `8`)))) in + let vk k = mk_comb(mk_comb(`word_subword:int256->num#num->int32`,prevar),mk_pair(mk_small_numeral(32*k),`32`)) in + let vcollk k = lane_collapse (vk k) in + let surfaceT0 = rand(concl(vcollk 0)) in + let surfaceTxy = subst [`x:int32`, xk 0; `y:int32`, yk 0] surfaceT0 in + let luh_lhs = subst [surfaceTxy,`tt:int32`] + `word_and (word_not (if word_igt (if word_ilt (tt:int32) (word 0) then word 43 else tt) (word 43) + then word 4294967295 else word 0)) (if word_ilt (tt:int32) (word 0) then word 43 else tt)` in + let LUH0 = prove(mk_imp(`val(y:int32) <= 1`, mk_eq(luh_lhs, `mldsa_use_hint_88_x86_asm 
x y`)), + DISCH_TAC THEN REWRITE_TAC[mldsa_use_hint_88_x86_asm] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN REWRITE_TAC[SHL_190464] THEN + REWRITE_TAC[A1HI_ZERO_88; JOIN_ZERO_ZX] THEN ASM_SIMP_TAC[DELTA_EQ_BOUNDED] THEN + REWRITE_TAC[LANE_AC_CLOSE_88] THEN REWRITE_TAC[TAIL_CLOSE_88]) in + let LUH_GEN = GENL [`x:int32`;`y:int32`] LUH0 in + let a1p_term = (match snd(strip_comb surfaceTxy) with [_;a]->a|_->failwith "T") in + let dsh_term = (match snd(strip_comb surfaceTxy) with [s0;_]->rand s0|_->failwith "T") in + let A1P_BOUND = BITBLAST_RULE (mk_binop `(<=):num->num->bool` (mk_comb(`val:int32->num`,a1p_term)) `44`) in + let RANGE_FINAL_BB = BITBLAST_RULE + `!A1P:int32 DSH:int32. val A1P <= 44 /\ (DSH = word 0 \/ DSH = word 2) + ==> (val(word_add (word_sub (word 1) DSH) A1P) <= 45 \/ + word_add (word_sub (word 1) DSH) A1P = word 4294967295)` in + let Y0_BB = BITBLAST_RULE `!A1P:int32. val A1P <= 44 ==> val(word_add (word_sub (word 0) (word 0)) A1P) <= 45` in + let DSH_VALS = BITBLAST_RULE (mk_imp(`val(y:int32) <= 1`, + mk_disj(mk_eq(dsh_term,`word 0:int32`), mk_eq(dsh_term,`word 2:int32`)))) in + let range_goal = mk_imp(`val(y:int32) <= 1`, + mk_disj(mk_binop `(<=):num->num->bool` (mk_comb(`val:int32->num`,surfaceTxy)) `45`, + mk_eq(surfaceTxy,`word 4294967295:int32`))) in + let RANGE0 = prove(range_goal, + DISCH_TAC THEN ABBREV_TAC (mk_eq(`A1P:int32`, a1p_term)) THEN + SUBGOAL_THEN `val(A1P:int32) <= 44` ASSUME_TAC THENL + [EXPAND_TAC "A1P" THEN ACCEPT_TAC A1P_BOUND; ALL_TAC] THEN + SUBGOAL_THEN `(y:int32) = word 0 \/ y = word 1` STRIP_ASSUME_TAC THENL + [UNDISCH_TAC `val(y:int32) <= 1` THEN SPEC_TAC(`y:int32`,`y:int32`) THEN + REWRITE_TAC[GSYM VAL_EQ_0; GSYM VAL_EQ_1] THEN ARITH_TAC; ALL_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[WORD_AND_0; WORD_SHL_0] THENL + [DISJ1_TAC THEN MATCH_MP_TAC Y0_BB THEN ASM_REWRITE_TAC[]; + MP_TAC DSH_VALS THEN ASM_REWRITE_TAC[] THEN + ABBREV_TAC (mk_eq(`DSH:int32`, dsh_term)) THEN DISCH_TAC THEN + MATCH_MP_TAC RANGE_FINAL_BB THEN 
ASM_REWRITE_TAC[]]) in + let RANGE_GEN = GENL [`x:int32`;`y:int32`] RANGE0 in + let mk_blane k = + let subB = mk_comb(mk_comb(`word_subword:int256->num#num->int32`,bvar),mk_pair(mk_small_numeral(32*k),`32`)) in + let thA1 = GEN_REWRITE_CONV ONCE_DEPTH_CONV [GSYM bdef] subB in + let inst_wsw = PART_MATCH (lhs o snd o dest_imp) WORD_SUBWORD_WORD (rand(concl thA1)) in + let side_th = EQT_ELIM((REWRITE_CONV[DIMINDEX_256] THENC NUM_REDUCE_CONV)(fst(dest_imp(concl inst_wsw)))) in + let stepA = TRANS thA1 (MP inst_wsw side_th) in + let bml_i = INST (map2 (fun i bv -> (bv, mk_var("b"^string_of_int i,`:num`))) (0--31) bvals) (List.nth UH88_BYTEMOD_LANES k) in + let bounds_th = end_itlist CONJ (map (fun bv -> CONV_RULE(RAND_CONV NUM_REDUCE_CONV)(REWRITE_RULE[DIMINDEX_8](ISPEC (rand bv) VAL_BOUND))) bvals) in + let stepB = TRANS stepA (AP_TERM `word:num->int32` (MP bml_i bounds_th)) in + let relabels = map (fun m -> uh88_byte_relabel prevar k m (if m=0 then `word 43:byte` else `word 0:byte`)) (0--3) in + let stepC = CONV_RULE (RAND_CONV (RAND_CONV (REWRITE_CONV relabels))) stepB in + let v = vk k in + let rng = ONCE_REWRITE_RULE[SYM(vcollk k)] (MP (SPECL [xk k;yk k] RANGE_GEN) (bndk k)) in + let bc = MP (SPEC v BLEND_COLLAPSE_GEN) rng in + TRANS stepC bc in + let blanes = map mk_blane (0--7) in + let vcolls = map vcollk (0--7) in + let luhs = map (fun k -> MP (SPECL [xk k;yk k] LUH_GEN) (bndk k)) (0--7) in + (MATCH_MP_TAC LANES8_EQ THEN CONV_TAC EXPAND_CASES_CONV THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[WORD_SUBWORD_AND] THEN REWRITE_TAC UH88_WSN THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + REWRITE_TAC[simd8;simd4;simd2;DIMINDEX_32] THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC blanes THEN REWRITE_TAC vcolls THEN REWRITE_TAC luhs) (asl,w);; + +(* ------------------------------------------------------------------------- *) +(* Broadcast constants as 256-bit duplicates of their 32-bit 
lane value. *) +(* ------------------------------------------------------------------------- *) +let DUPLITS_88 = map (fun (n,c) -> prove(mk_eq(mk_comb(`word:num->int256`, mk_numeral(Num.num_of_string n)), + mk_comb(`word_duplicate:int32->int256`, c)), CONV_TAC WORD_BLAST)) + ["3423913227525323174502430081042878883520180111764122672559515536195711", `word 127:int32`; + "303973398742897785767833851683137475682598667402680969552035729689816075", `word 11275:int32`; + "3450873174198750916033945278531405488902228774061477969193842430181504", `word 128:int32`; + "1159277706957392885855153492006644031428092478786277755276056441389099", `word 43:int32`; + "223368118819536749293045209988780814485663464087451345989979032820788390912", `word 8285184:int32`];; + +(* ------------------------------------------------------------------------- *) +(* Loop body (one iteration). *) +(* ------------------------------------------------------------------------- *) + +let POLY_USE_HINT_88_AVX2_ASM_BODY_BLOCK_TAC : tactic = + REPEAT STRIP_TAC THEN + ENSURES_INIT_TAC "s0" THEN + MP_TAC(SPECL [`a:int64`;`i:num`] ALIGNED_BLOCK) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(SPECL [`h:int64`;`i:num`] ALIGNED_BLOCK) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + SUBGOAL_THEN `read (memory :> bytes256 (word_add a (word(32*i)))) s0 = xb i` ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + if can (term_match [] + `!b. i <= b /\ b < 32 ==> read(memory :> bytes256(word_add a (word(32*b)))) s0 = xb b`) (concl th) + then MP_TAC(SPEC `i:num` th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `read (memory :> bytes256 (word_add h (word(32*i)))) s0 = yb i` ASSUME_TAC THENL + [FIRST_ASSUM(fun th -> + if can (term_match [] + `!b. 
b < 32 ==> read(memory :> bytes256(word_add h (word(32*b)))) s0 = yb b`) (concl th) + then MP_TAC(SPEC `i:num` th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `!k. k < 8 ==> val(word_subword ((yb:num->int256) i) (32*k,32):int32) <= 1` ASSUME_TAC THENL + [GEN_TAC THEN DISCH_TAC THEN + FIRST_ASSUM(fun th -> if can (term_match [] + `!b k. b < 32 /\ k < 8 ==> val(word_subword ((yb:num->int256) b) (32*k,32):int32) <= 1`) (concl th) + then MP_TAC(SPECL [`i:num`;`k:num`] th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + SUBGOAL_THEN + `!b. i+1 <= b /\ b < 32 ==> read(memory :> bytes256(word_add a (word(32*b)))) s0 = xb b` + ASSUME_TAC THENL + [REPEAT STRIP_TAC THEN + FIRST_ASSUM(fun th -> if can (term_match [] + `!b. i <= b /\ b < 32 ==> read(memory :> bytes256(word_add a (word(32*b)))) s0 = xb b`) (concl th) + then MP_TAC(SPEC `b:num` th) else failwith "no") THEN + ANTS_TAC THENL [ASM_ARITH_TAC; DISCH_THEN ACCEPT_TAC]; ALL_TAC] THEN + EVERY (map (fun n -> X86_STEPS_TAC POLY_USE_HINT_88_AVX2_ASM_EXEC [n] THEN SIMD_SIMPLIFY_TAC[] THEN ABBREV_BIG_TAC) (1--28)) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + REPEAT CONJ_TAC THEN + TRY(REWRITE_TAC[ARITH_RULE `32 * (i + 1) = 32 * i + 32`] THEN CONV_TAC WORD_RULE) THEN + TRY(X_GEN_TAC `b:num` THEN DISCH_TAC THEN ASM_CASES_TAC `b < i` THENL + [UNDISCH_TAC `b:num < i` THEN + UNDISCH_TAC `!b. 
b < i ==> read (memory :> bytes256 (word_add a (word (32 * b)))) s28 = simd8 mldsa_use_hint_88_x86_asm (xb b) (yb b)` THEN + MESON_TAC[]; + SUBGOAL_THEN `b:num = i` SUBST_ALL_TAC THENL + [UNDISCH_TAC `b:num < i + 1` THEN UNDISCH_TAC `~(b:num < i)` THEN ARITH_TAC; ALL_TAC] THEN + FIRST_X_ASSUM(fun th -> + try let l,_ = dest_eq (concl th) in + if l = `read (memory :> bytes256 (word_add a (word (32 * i)))) s28` + then SUBST1_TAC th else failwith "no" + with _ -> failwith "no") THEN + ABBREV_TAC `av:int256 = (xb:num->int256) i` THEN + ABBREV_TAC `hv:int256 = (yb:num->int256) i` THEN + UH88_STORE_DISCHARGE_TAC]) THEN + TRY(REWRITE_TAC[VAL_WORD_ADD; VAL_WORD; DIMINDEX_64] THEN + FIRST_ASSUM(fun th -> if concl th = `i < 32` then MP_TAC th else failwith "no") THEN + SPEC_TAC(`i:num`,`i:num`) THEN + CONV_TAC EXPAND_CASES_CONV THEN CONV_TAC NUM_REDUCE_CONV);; + +(* ------------------------------------------------------------------------- *) +(* Block-function correctness. *) +(* ------------------------------------------------------------------------- *) + +let POLY_USE_HINT_88_AVX2_ASM_BLOCK_CORRECT = prove + (`!a h xb yb pc. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, 0xdd) (a, 1024) /\ nonoverlapping (a, 1024) (h, 1024) /\ + (!b k. b < 32 /\ k < 8 ==> val(word_subword (yb b:int256) (32*k,32):int32) <= 1) + ==> ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST poly_use_hint_88_avx2_asm_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [a; h] s /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = xb b) /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32 * b)))) s = yb b)) + (\s. read RIP s = word(pc + 0xdd) /\ + (!b. 
b < 32 ==> + read(memory :> bytes256(word_add a (word(32 * b)))) s = + simd8 mldsa_use_hint_88_x86_asm (xb b) (yb b))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, MAYCHANGE [memory :> bytes(a, 1024)])`, + MAP_EVERY X_GEN_TAC [`a:int64`;`h:int64`;`xb:num->int256`;`yb:num->int256`;`pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; NONOVERLAPPING_CLAUSES; ALL; + fst POLY_USE_HINT_88_AVX2_ASM_EXEC] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN REWRITE_TAC[SOME_FLAGS] THEN + ENSURES_WHILE_PUP_TAC `32` `pc + 0x52` `pc + 0xd7` + `\i s. + (read RDI s = word_add a (word(32 * i)) /\ + read RSI s = word_add h (word(32 * i)) /\ + read RAX s = word(32 * i) /\ + read YMM8 s = (word_duplicate (word 11275:int32):int256) /\ + read YMM4 s = (word_duplicate (word 127:int32):int256) /\ + read YMM7 s = (word_duplicate (word 128:int32):int256) /\ + read YMM6 s = (word_duplicate (word 43:int32):int256) /\ + read YMM3 s = (word_duplicate (word 8285184:int32):int256) /\ + read YMM5 s = (word 0:int256) /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32 * b)))) s = yb b) /\ + (!b. i <= b /\ b < 32 + ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = xb b) /\ + (!b. 
b < i + ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = + simd8 mldsa_use_hint_88_x86_asm (xb b) (yb b))) + /\ + (read ZF s <=> i = 32)` THEN + REWRITE_TAC[ARITH_RULE `~(32 = 0)`] THEN CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + REPEAT CONJ_TAC THENL + [ + REWRITE_TAC[MULT_CLAUSES; WORD_ADD_0] THEN + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC POLY_USE_HINT_88_AVX2_ASM_EXEC (1--17) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC DUPLITS_88 THEN + REWRITE_TAC[ARITH_RULE `b < 0 <=> F`; LE_0] THEN ASM_REWRITE_TAC[] + ; + POLY_USE_HINT_88_AVX2_ASM_BODY_BLOCK_TAC + ; + REPEAT STRIP_TAC THEN X86_SIM_TAC POLY_USE_HINT_88_AVX2_ASM_EXEC (1--1) + ; + REWRITE_TAC[ARITH_RULE `32 <= b /\ b < 32 <=> F`] THEN + ENSURES_INIT_TAC "s0" THEN + X86_STEPS_TAC POLY_USE_HINT_88_AVX2_ASM_EXEC (1--1) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] + ]);; + +(* ------------------------------------------------------------------------- *) +(* Core correctness theorem (coefficient form, FIPS 204 UseHint). *) +(* This must be kept in sync with the CBMC specification in *) +(* mldsa/src/native/x86_64/src/arith_native_x86_64.h *) +(* ------------------------------------------------------------------------- *) + +let POLY_USE_HINT_88_AVX2_ASM_CORRECT = prove + (`!a h x y pc. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, 0xdd) (a, 1024) /\ nonoverlapping (a, 1024) (h, 1024) + ==> ensures x86 + (\s. bytes_loaded s (word pc) (BUTLAST poly_use_hint_88_avx2_asm_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [a; h] s /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add h (word(4 * i)))) s = y i) /\ + (!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)) + (\s. read RIP s = word(pc + 0xdd) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + word(mldsa_use_hint_88 (val(y i)) (val(x i)))) /\ + (!i. 
i < 256 ==> + val(read(memory :> bytes32(word_add a (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, MAYCHANGE [memory :> bytes(a, 1024)])`, + MAP_EVERY X_GEN_TAC [`a:int64`;`h:int64`;`x:num->int32`;`y:num->int32`;`pc:num`] THEN + STRIP_TAC THEN + ASM_CASES_TAC `(!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)` + THENL + [FIRST_X_ASSUM(CONJUNCTS_THEN ASSUME_TAC) THEN + MATCH_MP_TAC ENSURES_PREPOSTCONDITION_THM THEN + MAP_EVERY EXISTS_TAC + [`\s. bytes_loaded s (word pc) (BUTLAST poly_use_hint_88_avx2_asm_tmc) /\ + read RIP s = word pc /\ + C_ARGUMENTS [a; h] s /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add a (word(32 * b)))) s = pack8 x b) /\ + (!b. b < 32 ==> read(memory :> bytes256(word_add h (word(32 * b)))) s = pack8 y b)`; + `\s. read RIP s = word(pc + 0xdd) /\ + (!b. b < 32 ==> + read(memory :> bytes256(word_add a (word(32 * b)))) s = + simd8 mldsa_use_hint_88_x86_asm (pack8 x b) (pack8 y b))`] THEN + CONJ_TAC THENL + [ + X_GEN_TAC `s:x86state` THEN REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + CONJ_TAC THEN X_GEN_TAC `b:num` THEN DISCH_TAC THEN + MATCH_MP_TAC PACK8_MERGE THEN ASM_REWRITE_TAC[] + ; + CONJ_TAC THENL + [ + X_GEN_TAC `s:x86state` THEN REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN + `!i. 
i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + mldsa_use_hint_88_x86_asm (x i) (y i)` + ASSUME_TAC THENL + [X_GEN_TAC `i:num` THEN DISCH_TAC THEN + SUBGOAL_THEN `4 * i = 4 * (8 * (i DIV 8) + i MOD 8)` SUBST1_TAC THENL + [AP_TERM_TAC THEN ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`a:int64`;`s:x86state`;`i DIV 8`;`i MOD 8`] BLOCK_SPLIT) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `i DIV 8` th)) THEN + ANTS_TAC THENL [UNDISCH_TAC `i:num < 256` THEN ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + MP_TAC(SPECL [`mldsa_use_hint_88_x86_asm`;`pack8 x (i DIV 8)`;`pack8 y (i DIV 8)`;`i MOD 8`] SIMD8_LANE) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN + MP_TAC(SPECL [`x:num->int32`;`i DIV 8`;`i MOD 8`] PACK8_LANE) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + MP_TAC(SPECL [`y:num->int32`;`i DIV 8`;`i MOD 8`] PACK8_LANE) THEN + ANTS_TAC THENL [SIMP_TAC[DIVISION; ARITH_EQ]; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN DISCH_THEN SUBST1_TAC THEN + SUBGOAL_THEN `8 * (i DIV 8) + i MOD 8 = i` SUBST1_TAC THENL + [ARITH_TAC; REFL_TAC]; ALL_TAC] THEN + CONJ_TAC THENL + [ + X_GEN_TAC `i:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `i:num` th)) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN + SUBGOAL_THEN `val(x (i:num):int32) < 8380417 /\ val(y (i:num):int32) <= 1` STRIP_ASSUME_TAC THENL + [ASM_SIMP_TAC[]; ALL_TAC] THEN + MP_TAC(SPECL [`x (i:num):int32`;`y (i:num):int32`] ELEMENT_CORRECT_WORD_88) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + AP_TERM_TAC THEN + MP_TAC(SPECL [`val(x (i:num):int32)`;`val(y (i:num):int32)`] MLDSA_USE_HINT_88_EQUIV) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]) + ; + X_GEN_TAC `i:num` THEN DISCH_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `i:num` th)) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC 
THEN + SUBGOAL_THEN `val(x (i:num):int32) < 8380417 /\ val(y (i:num):int32) <= 1` STRIP_ASSUME_TAC THENL + [ASM_SIMP_TAC[]; ALL_TAC] THEN + MP_TAC(SPECL [`x (i:num):int32`;`y (i:num):int32`] ELEMENT_CORRECT_WORD_88) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC(ONCE_DEPTH_CONV NUM_REDUCE_CONV) THEN + MATCH_MP_TAC(ARITH_RULE `n < 44 ==> n MOD 4294967296 < 44`) THEN + REWRITE_TAC[mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN + ASM_ARITH_TAC + ] + ; + MATCH_MP_TAC POLY_USE_HINT_88_AVX2_ASM_BLOCK_CORRECT THEN + ASM_REWRITE_TAC[] THEN REPEAT STRIP_TAC THEN + MP_TAC(SPECL [`y:num->int32`;`b:num`;`k:num`] PACK8_LANE) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + FIRST_X_ASSUM(fun th -> MP_TAC(SPEC `8 * b + k` th)) THEN + ANTS_TAC THENL [UNDISCH_TAC `b:num < 32` THEN UNDISCH_TAC `k:num < 8` THEN ARITH_TAC; + REWRITE_TAC[]] + ]] + ; + MATCH_MP_TAC ENSURES_PRECONDITION_THM THEN + EXISTS_TAC `\s:x86state. F` THEN + REWRITE_TAC[ENSURES_TRIVIAL] THEN + GEN_TAC THEN POP_ASSUM MP_TAC THEN MESON_TAC[]]);; + +(* ========================================================================= *) +(* Public subroutine correctness (with return). *) +(* ========================================================================= *) + +let POLY_USE_HINT_88_AVX2_ASM_NOIBT_SUBROUTINE_CORRECT = prove + (`!a h x y pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_88_avx2_asm_tmc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. bytes_loaded s (word pc) poly_use_hint_88_avx2_asm_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. 
i < 256 ==> read(memory :> bytes32(word_add h (word(4 * i)))) s = y i) /\ + (!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + word(mldsa_use_hint_88 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add a (word(4 * i)))) s) < 44)) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + X86_PROMOTE_RETURN_NOSTACK_TAC poly_use_hint_88_avx2_asm_tmc POLY_USE_HINT_88_AVX2_ASM_CORRECT);; + +let POLY_USE_HINT_88_AVX2_ASM_SUBROUTINE_CORRECT = prove + (`!a h x y pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_88_avx2_asm_mc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. bytes_loaded s (word pc) poly_use_hint_88_avx2_asm_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> read(memory :> bytes32(word_add h (word(4 * i)))) s = y i) /\ + (!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1)) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = + word(mldsa_use_hint_88 (val(y i)) (val(x i)))) /\ + (!i. 
i < 256 ==> + val(read(memory :> bytes32(word_add a (word(4 * i)))) s) < 44)) + (MAYCHANGE [RSP] ,, MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(a, 1024)])`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE POLY_USE_HINT_88_AVX2_ASM_NOIBT_SUBROUTINE_CORRECT));; + +(* ========================================================================= *) +(* Constant-time and memory safety proof. *) +(* ========================================================================= *) + +needs "s2n_bignum/x86/proofs/consttime.ml";; +needs "mldsa_native/x86_64/proofs/subroutine_signatures.ml";; + +let NORMALIZE_AND_EXPAND_YMM_TAC : tactic = + RULE_ASSUM_TAC(REWRITE_RULE[WORD_ADD_0]) THEN EXPAND_MAYCHANGE_YMM_REGS_TAC;; + +let full_spec,public_vars = mk_safety_spec + ~keep_maychanges:true + (assoc "mldsa_poly_use_hint_88_x86" subroutine_signatures) + (REWRITE_RULE[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; SOME_FLAGS] + POLY_USE_HINT_88_AVX2_ASM_CORRECT) + POLY_USE_HINT_88_AVX2_ASM_EXEC;; + +let POLY_USE_HINT_88_AVX2_ASM_SAFE = time prove + (full_spec, + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; SOME_FLAGS] THEN + GEN_PROVE_SAFETY_SPEC_TAC ~public_vars:public_vars + ~tac_before_maychange_simp:NORMALIZE_AND_EXPAND_YMM_TAC + POLY_USE_HINT_88_AVX2_ASM_EXEC + [BYTES_LOADED_APPEND_CLAUSE] X86_SINGLE_STEP_TAC);; + +let POLY_USE_HINT_88_AVX2_ASM_NOIBT_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e a h pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_88_avx2_asm_tmc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. + bytes_loaded s (word pc) poly_use_hint_88_avx2_asm_tmc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + read events s = e) + (\s. 
read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h pc stackpointer returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; stackpointer,8] + [a,1024; stackpointer,8])) + (\s s'. true)`, + X86_PROMOTE_RETURN_NOSTACK_TAC poly_use_hint_88_avx2_asm_tmc POLY_USE_HINT_88_AVX2_ASM_SAFE THEN + DISCHARGE_SAFETY_PROPERTY_TAC);; + +let POLY_USE_HINT_88_AVX2_ASM_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e a h pc stackpointer returnaddress. + aligned 32 a /\ aligned 32 h /\ + nonoverlapping (word pc, LENGTH poly_use_hint_88_avx2_asm_mc) (a, 1024) /\ + nonoverlapping (a, 1024) (h, 1024) /\ + nonoverlapping (stackpointer, 8) (a, 1024) + ==> ensures x86 + (\s. + bytes_loaded s (word pc) poly_use_hint_88_avx2_asm_mc /\ + read RIP s = word pc /\ + read RSP s = stackpointer /\ + read (memory :> bytes64 stackpointer) s = returnaddress /\ + C_ARGUMENTS [a; h] s /\ + read events s = e) + (\s. read RIP s = returnaddress /\ + read RSP s = word_add stackpointer (word 8) /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h pc stackpointer returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; stackpointer,8] + [a,1024; stackpointer,8])) + (\s s'. 
true)`, + MATCH_ACCEPT_TAC(ADD_IBT_RULE POLY_USE_HINT_88_AVX2_ASM_NOIBT_SUBROUTINE_SAFE));; + diff --git a/proofs/hol_light/x86_64/proofs/subroutine_signatures.ml b/proofs/hol_light/x86_64/proofs/subroutine_signatures.ml index 2695535a3..37954c6bd 100644 --- a/proofs/hol_light/x86_64/proofs/subroutine_signatures.ml +++ b/proofs/hol_light/x86_64/proofs/subroutine_signatures.ml @@ -70,6 +70,40 @@ let subroutine_signatures = [ ]) ); +("mldsa_poly_use_hint_32_x86", + ([(*args*) + ("a", "int32_t[static 256]", (*is const?*)"false"); + ("h", "int32_t[static 256]", (*is const?*)"true"); + ], + "void", + [(* input buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ("h", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* output buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + +("mldsa_poly_use_hint_88_x86", + ([(*args*) + ("a", "int32_t[static 256]", (*is const?*)"false"); + ("h", "int32_t[static 256]", (*is const?*)"true"); + ], + "void", + [(* input buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ("h", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* output buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + ("mldsa_pointwise_x86", ([(*args*) ("a", "int32_t[static 256]", (*is const?*)"false"); diff --git a/scripts/autogen b/scripts/autogen index 3ce73146f..800e5efed 100755 --- a/scripts/autogen +++ b/scripts/autogen @@ -2820,6 +2820,18 @@ def hol_light_asm_joblist(): f"-Imldsa/src/native/x86_64/src -Imldsa/src/common.h {x86_64_flags}", "x86_64", ), + ( + "poly_use_hint_32_avx2_asm.S", + "dev/x86_64/src", + f"-Imldsa/src/native/x86_64/src -Imldsa/src/common.h {x86_64_flags}", + "x86_64", + ), + ( + "poly_use_hint_88_avx2_asm.S", + "dev/x86_64/src", + f"-Imldsa/src/native/x86_64/src -Imldsa/src/common.h {x86_64_flags}", + "x86_64", + ), ( "ntt_avx2_asm.S", "dev/x86_64/src", diff --git 
a/test/bench/bench_components_mldsa.c b/test/bench/bench_components_mldsa.c index 948915dae..98f6eb63e 100644 --- a/test/bench/bench_components_mldsa.c +++ b/test/bench/bench_components_mldsa.c @@ -71,6 +71,7 @@ static int cmp_uint64_t(const void *a, const void *b) static int bench(void) { MLD_ALIGN int32_t data0[256]; + MLD_ALIGN int32_t data1[256]; MLD_ALIGN mld_poly poly_out; MLD_ALIGN mld_polyvecl polyvecl_a, polyvecl_b; MLD_ALIGN mld_polymat polymat; @@ -102,6 +103,9 @@ static int bench(void) chknorm_acc ^= mld_poly_chknorm((const mld_poly *)data0, MLDSA_GAMMA1 - MLDSA_BETA);) + BENCH("poly_use_hint", + mld_poly_use_hint((mld_poly *)data0, (mld_poly *)data1)); + return (int)chknorm_acc; }