diff --git a/polyval/src/field_element/armv8.rs b/polyval/src/field_element/armv8.rs index b2e5b9b..483ceab 100644 --- a/polyval/src/field_element/armv8.rs +++ b/polyval/src/field_element/armv8.rs @@ -109,12 +109,12 @@ unsafe fn karatsuba1(x: Simd128, y: Simd128) -> (Simd128, Simd128, Simd128) { // M H L // // m = x.hi^x.lo * y.hi^y.lo - let m = pmull::<0>( + let m = pmull( veorq_u8(x, vextq_u8(x, x, 8)), // x.hi^x.lo veorq_u8(y, vextq_u8(y, y, 8)), // y.hi^y.lo ); - let h = pmull::<1>(x, y); // h = x.hi * y.hi - let l = pmull::<0>(x, y); // l = x.lo * y.lo + let h = pmull2(x, y); // h = x.hi * y.hi + let l = pmull(x, y); // l = x.lo * y.lo (h, m, l) } @@ -175,18 +175,28 @@ unsafe fn mont_reduce(x23: Simd128, x01: Simd128) -> Simd128 { // [D1:D0] = [B0 ⊕ C1 : B1 ⊕ C0] // Output: [D1 ⊕ X3 : D0 ⊕ X2] let poly = vreinterpretq_u8_p128(POLY); - let a = pmull::<0>(x01, poly); + let a = pmull(x01, poly); let b = veorq_u8(x01, vextq_u8(a, a, 8)); - let c = pmull::<1>(b, poly); + let c = pmull2(b, poly); veorq_u8(x23, veorq_u8(c, b)) } /// Multiplies the low bits in `a` and `b`. #[inline] #[target_feature(enable = "aes,neon")] -unsafe fn pmull(a: Simd128, b: Simd128) -> Simd128 { +unsafe fn pmull(a: Simd128, b: Simd128) -> Simd128 { vreinterpretq_u8_p128(vmull_p64( - vgetq_lane_u64(vreinterpretq_u64_u8(a), LANE), - vgetq_lane_u64(vreinterpretq_u64_u8(b), LANE), + vgetq_lane_u64(vreinterpretq_u64_u8(a), 0), + vgetq_lane_u64(vreinterpretq_u64_u8(b), 0), + )) +} + +/// Multiplies the high bits in `a` and `b`. +#[inline] +#[target_feature(enable = "aes,neon")] +unsafe fn pmull2(a: Simd128, b: Simd128) -> Simd128 { + vreinterpretq_u8_p128(vmull_p64( + vgetq_lane_u64(vreinterpretq_u64_u8(a), 1), + vgetq_lane_u64(vreinterpretq_u64_u8(b), 1), )) }