diff --git a/benches/uint.rs b/benches/uint.rs index 10108bec..180a61f8 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -1045,6 +1045,14 @@ fn bench_sqrt(c: &mut Criterion) { ); }); + group.bench_function("floor_sqrt, one Limb", |b| { + b.iter_batched( + || Uint::new([Limb::random_from_rng(&mut rng)]), + |x| x.floor_sqrt(), + BatchSize::SmallInput, + ); + }); + group.bench_function("floor_sqrt_vartime, U256 one Limb", |b| { b.iter_batched( || U256::from_word(Limb::random_from_rng(&mut rng).0), diff --git a/src/uint.rs b/src/uint.rs index 1dd52181..54b50669 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -8,7 +8,7 @@ pub(crate) use ref_type::UintRef; use crate::{ Bounded, Choice, ConstOne, ConstZero, Constants, CtEq, CtOption, EncodedUint, FixedInteger, Int, Integer, Limb, NonZero, Odd, One, Unsigned, UnsignedWithMontyForm, Word, Zero, bitlen, - limb::nlimbs, modular::FixedMontyForm, primitives, traits::sealed::Sealed, + limb::nlimbs, modular::FixedMontyForm, traits::sealed::Sealed, }; use core::fmt; @@ -108,9 +108,6 @@ impl Uint { /// Total size of the represented integer in bits. pub const BITS: u32 = bitlen::from_limbs(LIMBS); - /// `floor(log2(Self::BITS))`. - pub(crate) const LOG2_BITS: u32 = primitives::u32_bits(Self::BITS) - 1; - /// Total size of the represented integer in bytes. pub const BYTES: usize = LIMBS * Limb::BYTES; diff --git a/src/uint/boxed/sqrt.rs b/src/uint/boxed/sqrt.rs index 32d58e1a..3ba1f5ab 100644 --- a/src/uint/boxed/sqrt.rs +++ b/src/uint/boxed/sqrt.rs @@ -1,10 +1,6 @@ //! [`BoxedUint`] square root operations. -use crate::{ - BitOps, BoxedUint, CheckedSquareRoot, ConcatenatingSquare, CtAssign, CtEq, CtGt, CtOption, - FloorSquareRoot, Limb, -}; -use core::mem; +use crate::{BoxedUint, CheckedSquareRoot, Choice, CtOption, FloorSquareRoot, NonZero}; impl BoxedUint { /// Computes `floor(√(self))` in constant time. @@ -18,46 +14,13 @@ impl BoxedUint { /// Computes √(`self`) in constant time. /// - /// Callers can check if `self` is a square by squaring the result. + /// Callers can check if `self` is a square by squaring the result, or use + /// `checked_sqrt`. #[must_use] pub fn floor_sqrt(&self) -> Self { - // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13. - // - // See Hast, "Note on computation of integer square roots" - // for the proof of the sufficiency of the bound on iterations. - // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf - - // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - // Will not overflow since `b <= BITS`. - let mut x = Self::one_with_precision(self.bits_precision()); - x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`) - - let mut nz_x = x.clone(); - let mut quo = Self::zero_with_precision(self.bits_precision()); - let mut rem = Self::zero_with_precision(self.bits_precision()); - let mut i = 0; - - // Repeat enough times to guarantee result has stabilized. - // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough. - while i < self.log2_bits() + 2 { - let x_nonzero = x.is_nonzero(); - nz_x.ct_assign(&x, x_nonzero); - - // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` - quo.limbs.copy_from_slice(&self.limbs); - rem.limbs.copy_from_slice(&nz_x.limbs); - quo.as_mut_uint_ref().div_rem(rem.as_mut_uint_ref()); - x.conditional_carrying_add_assign(&quo, x_nonzero); - x.shr1_assign(); - - i += 1; - } - - // At this point `x_prev == x_{n}` and `x == x_{n+1}` - // where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`. - // Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`. - x.ct_assign(&nz_x, x.ct_gt(&nz_x)); - x + let mut root = self.clone(); + root.floor_sqrt_assign(); + root } /// Computes `floor(√(self))`. @@ -73,47 +36,15 @@ impl BoxedUint { /// Computes √(`self`). /// - /// Callers can check if `self` is a square by squaring the result. + /// Callers can check if `self` is a square by squaring the result, or use + /// `checked_sqrt_vartime`. /// /// Variable time with respect to `self`. #[must_use] pub fn floor_sqrt_vartime(&self) -> Self { - // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 - - if self.is_zero_vartime() { - return Self::zero_with_precision(self.bits_precision()); - } - - // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - // Will not overflow since `b <= BITS`. - // The initial value of `x` is always greater than zero. - let mut x = Self::one_with_precision(self.bits_precision()); - x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`) - - let mut quo = Self::zero_with_precision(self.bits_precision()); - let mut rem = Self::zero_with_precision(self.bits_precision()); - - loop { - // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` - quo.limbs.copy_from_slice(&self.limbs); - rem.limbs.copy_from_slice(&x.limbs); - quo.as_mut_uint_ref().div_rem_vartime(rem.as_mut_uint_ref()); - quo.carrying_add_assign(&x, Limb::ZERO); - quo.shr1_assign(); - - // If `quo` is the same as `x` or greater, we reached convergence - // (`x` is guaranteed to either go down or oscillate between - // `sqrt(self)` and `sqrt(self) + 1`) - if !x.cmp_vartime(&quo).is_gt() { - break; - } - x.limbs.copy_from_slice(&quo.limbs); - if x.is_zero_vartime() { - break; - } - } - - x + let mut root = self.clone(); + root.floor_sqrt_assign_vartime(); + root } /// Wrapped sqrt is just `floor(√(self))`. @@ -138,9 +69,9 @@ impl BoxedUint { /// only if the square root is exact. #[must_use] pub fn checked_sqrt(&self) -> CtOption { - let r = self.floor_sqrt(); - let s = r.wrapping_square(); - CtOption::new(r, self.ct_eq(&s)) + let mut root = self.clone(); + let exact = root.floor_sqrt_assign(); + CtOption::new(root, exact) } /// Perform checked sqrt, returning an [`Option`] which `is_some` @@ -149,24 +80,32 @@ impl BoxedUint { /// Variable time with respect to `self`. #[must_use] pub fn checked_sqrt_vartime(&self) -> Option { - let r = self.floor_sqrt_vartime(); - let s = r.wrapping_square(); - if self.cmp_vartime(&s).is_eq() { - Some(r) + let mut root = self.clone(); + if root.floor_sqrt_assign_vartime() { + Some(root) } else { None } } + /// Assigns `floor(√(self))` to `self` in constant time, and returns a [`Choice`] + /// indicating whether the square root is exact. + fn floor_sqrt_assign(&mut self) -> Choice { + let size = self.nlimbs(); + let mut buf = Self::zero_with_precision(self.bits_precision() * 2); + self.as_mut_uint_ref() + .sqrt_assign(buf.as_mut_uint_ref().split_at_mut(size)) + } + /// Assigns `floor(√(self))` to `self`, and returns a [`bool`] /// indicating whether the square root is exact. /// /// Variable time with respect to `self`. pub fn floor_sqrt_assign_vartime(&mut self) -> bool { - // TODO(tarcieri): more optimized implementation - let mut ret = self.floor_sqrt_vartime(); - mem::swap(&mut ret, self); - self.concatenating_square() == ret + let size = self.nlimbs(); + let mut buf = Self::zero_with_precision(self.bits_precision() * 2); + self.as_mut_uint_ref() + .sqrt_assign_vartime(buf.as_mut_uint_ref().split_at_mut(size)) } } @@ -192,14 +131,38 @@ impl FloorSquareRoot for BoxedUint { } } +impl CheckedSquareRoot for NonZero { + type Output = Self; + + fn checked_sqrt(&self) -> CtOption { + self.as_ref().checked_sqrt().map(NonZero::new_unchecked) + } + + fn checked_sqrt_vartime(&self) -> Option { + self.as_ref() + .checked_sqrt_vartime() + .map(NonZero::new_unchecked) + } +} + +impl FloorSquareRoot for NonZero { + fn floor_sqrt(&self) -> Self { + NonZero::new_unchecked(self.as_ref().floor_sqrt()) + } + + fn floor_sqrt_vartime(&self) -> Self { + NonZero::new_unchecked(self.as_ref().floor_sqrt_vartime()) + } +} + #[cfg(test)] #[allow(clippy::integer_division_remainder_used, reason = "test")] mod tests { - use crate::{BoxedUint, Limb}; + use crate::{BoxedUint, CheckedSquareRoot, FloorSquareRoot, Limb}; #[cfg(feature = "rand_core")] use { - crate::RandomBits, + crate::{ConcatenatingSquare, RandomBits}, chacha20::ChaCha8Rng, rand_core::{Rng, SeedableRng}, }; @@ -218,7 +181,7 @@ mod tests { for i in 0..half.limbs.len() / 2 { half.limbs[i] = Limb::MAX; } - let u256_max = !BoxedUint::zero_with_precision(256); + let u256_max = BoxedUint::max(256); assert_eq!(u256_max.floor_sqrt(), half); // Test edge cases that use up the maximum number of iterations. @@ -307,8 +270,26 @@ mod tests { let r = BoxedUint::from(*e); assert_eq!(l.floor_sqrt(), r); assert_eq!(l.floor_sqrt_vartime(), r); - assert!(l.checked_sqrt().is_some().to_bool()); - assert!(l.checked_sqrt_vartime().is_some()); + assert_eq!( + CheckedSquareRoot::checked_sqrt(&l).into_option().as_ref(), + Some(&r) + ); + assert_eq!( + CheckedSquareRoot::checked_sqrt_vartime(&l).as_ref(), + Some(&r) + ); + let nz_l = l.as_nz_vartime().unwrap(); + let nz_r = r.to_nz().unwrap(); + assert_eq!(FloorSquareRoot::floor_sqrt(nz_l), nz_r); + assert_eq!(FloorSquareRoot::floor_sqrt_vartime(nz_l), nz_r); + assert_eq!( + CheckedSquareRoot::checked_sqrt(nz_l).into_option().as_ref(), + Some(&nz_r) + ); + assert_eq!( + CheckedSquareRoot::checked_sqrt_vartime(nz_l).as_ref(), + Some(&nz_r) + ); } } @@ -362,8 +343,6 @@ mod tests { #[cfg(feature = "rand_core")] #[test] fn fuzz() { - use crate::{CheckedSquareRoot, ConcatenatingSquare}; - let mut rng = ChaCha8Rng::from_seed([7u8; 32]); let rounds = if cfg!(miri) { 10 } else { 50 }; for _ in 0..rounds { @@ -384,4 +363,19 @@ mod tests { assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2); } } + + #[test] + // test input from issue #1303 + fn sqrt_edge() { + let sq = BoxedUint::from_be_hex( + "0000000000000AB88E41D05334F646EE1B715C614C81DCE1C46E2E021EF2E7FA45E5DD2CC670725E52B091A0E08B1E9E6C43262581226A7BFDE534F704D6DC6C1FD5C93B3AE4FB73C9D26EBC09DB83040C2CD157884532EB5578CEF2BA95E87C7F6E328F062BB20D3C95D122EB4D26727901A424AF14143EBCEFE5D5958E9EDEA73649F5C76932D491C4A51C370F51E6DBEB66D56D7FB51A", + 1216, + ).unwrap(); + let expect = BoxedUint::from_be_hex( + "00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000346375360711FDFC13BE3CF276FCC535C2E29164A842D3B71E2C7DFA60AE0E24CD0CC65AD7A3B626E03DAC8260D921AB8E16C6757D2CA8DEE5DFA62E5C1CB3C90A49B9F1A192498905", + 1216, + ).unwrap(); + assert_eq!(sq.floor_sqrt(), expect); + assert_eq!(sq.floor_sqrt_vartime(), expect); + } } diff --git a/src/uint/ref_type.rs b/src/uint/ref_type.rs index 11db152c..3215cd22 100644 --- a/src/uint/ref_type.rs +++ b/src/uint/ref_type.rs @@ -10,6 +10,7 @@ mod mul; mod shl; mod shr; mod slice; +mod sqrt; mod sub; use crate::{Choice, Limb, NonZero, Odd, Uint, Word}; diff --git a/src/uint/ref_type/add.rs b/src/uint/ref_type/add.rs index 10091a8e..973384c7 100644 --- a/src/uint/ref_type/add.rs +++ b/src/uint/ref_type/add.rs @@ -4,7 +4,6 @@ use crate::{Choice, Limb}; impl UintRef { /// Perform an in-place carrying add of a limb, returning the carried limb value. #[inline] - #[track_caller] pub const fn add_assign_limb(&mut self, mut rhs: Limb) -> Limb { let mut i = 0; while i < self.limbs.len() { @@ -15,6 +14,9 @@ impl UintRef { } /// Perform an in-place carrying add of another [`UintRef`], returning the carried limb value. + /// + /// # Panics + /// If `self` is shorter than `rhs`. #[inline] #[track_caller] pub const fn carrying_add_assign(&mut self, rhs: &Self, carry: Limb) -> Limb { @@ -24,23 +26,30 @@ impl UintRef { /// Perform an in-place carrying add of another limb slice, returning the carried limb value. /// /// # Panics - /// If `self` and `rhs` have different lengths. + /// If `self` is shorter than `rhs`. #[inline] #[track_caller] pub const fn carrying_add_assign_slice(&mut self, rhs: &[Limb], mut carry: Limb) -> Limb { assert!( - self.limbs.len() == rhs.len(), + self.limbs.len() >= rhs.len(), "length mismatch in carrying_add_assign_slice" ); let mut i = 0; - while i < self.limbs.len() { + while i < rhs.len() { (self.limbs[i], carry) = self.limbs[i].carrying_add(rhs[i], carry); i += 1; } + while i < self.limbs.len() { + (self.limbs[i], carry) = self.limbs[i].overflowing_add(carry); + i += 1; + } carry } /// Perform an in-place carrying add of another limb slice, returning the carried limb value. + /// + /// # Panics + /// If `self` is shorter than `rhs`. #[inline] #[track_caller] pub const fn conditional_add_assign( @@ -55,7 +64,7 @@ impl UintRef { /// Perform an in-place carrying add of another limb slice, returning the carried limb value. /// /// # Panics - /// If `self` and `rhs` have different lengths. + /// If `self` is shorter than `rhs`. #[inline] #[track_caller] pub const fn conditional_add_assign_slice( @@ -65,15 +74,19 @@ impl UintRef { choice: Choice, ) -> Limb { assert!( - self.limbs.len() == rhs.len(), + self.limbs.len() >= rhs.len(), "length mismatch in conditional_add_assign_slice" ); let mut i = 0; - while i < self.limbs.len() { + while i < rhs.len() { (self.limbs[i], carry) = self.limbs[i].carrying_add(Limb::select(Limb::ZERO, rhs[i], choice), carry); i += 1; } + while i < self.limbs.len() { + (self.limbs[i], carry) = self.limbs[i].overflowing_add(carry); + i += 1; + } carry } @@ -102,3 +115,45 @@ impl UintRef { self.trailing_mut(i).add_assign_limb(carry) } } + +#[cfg(test)] +mod tests { + use crate::{Choice, Limb, UintRef, Unsigned}; + + #[test] + fn carrying_add_assign_mixed() { + let mut a = [Limb::MAX]; + let carry = + UintRef::new_mut(&mut a).carrying_add_assign(Limb::ONE.as_uint_ref(), Limb::ZERO); + assert_eq!((a, carry), ([Limb::ZERO], Limb::ONE)); + + let mut a = [Limb::MAX]; + let carry = UintRef::new_mut(&mut a).conditional_add_assign( + Limb::ONE.as_uint_ref(), + Limb::ZERO, + Choice::FALSE, + ); + assert_eq!((a, carry), ([Limb::MAX], Limb::ZERO)); + + let mut a = [Limb::MAX]; + let carry = UintRef::new_mut(&mut a).conditional_add_assign( + Limb::ONE.as_uint_ref(), + Limb::ZERO, + Choice::TRUE, + ); + assert_eq!((a, carry), ([Limb::ZERO], Limb::ONE)); + + let mut a = [Limb::MAX, Limb::MAX]; + let carry = + UintRef::new_mut(&mut a).carrying_add_assign(Limb::ZERO.as_uint_ref(), Limb::ONE); + assert_eq!((a, carry), ([Limb::ZERO, Limb::ZERO], Limb::ONE)); + + let mut a = [Limb::MAX, Limb::MAX]; + let carry = UintRef::new_mut(&mut a).conditional_add_assign( + Limb::MAX.as_uint_ref(), + Limb::ONE, + Choice::TRUE, + ); + assert_eq!((a, carry), ([Limb::MAX, Limb::ZERO], Limb::ONE)); + } +} diff --git a/src/uint/ref_type/sqrt.rs b/src/uint/ref_type/sqrt.rs new file mode 100644 index 00000000..1c04020b --- /dev/null +++ b/src/uint/ref_type/sqrt.rs @@ -0,0 +1,298 @@ +//! Square root calculation. + +use super::UintRef; +use crate::{Choice, Limb, NonZero, Uint, WideWord, Word, word}; + +impl UintRef { + #[inline] + pub(crate) const fn sqrt_assign(&mut self, buf: (&mut UintRef, &mut UintRef)) -> Choice { + let len = self.nlimbs(); + assert!(buf.0.nlimbs() >= len && buf.1.nlimbs() >= len); + + match len { + 0 => Choice::TRUE, + 1 => { + let rem; + (self.limbs[0].0, rem) = sqrt_rem_word(self.limbs[0].0); + Limb(rem).is_zero() + } + _ => { + let mut bits = self.bits(); + let is_zero = Choice::from_u32_nz(bits).not(); + // Set first limb to 1 if the input is zero + self.limbs[0] = Limb::select(self.limbs[0], Limb::ONE, is_zero); + bits = is_zero.select_u32(bits, 1); + + // Shift such that at least one of the top two bits are non-zero + let s_shift = (self.bits_precision() - bits) >> 1; + self.shl_assign(s_shift * 2); + + // Compute root and uncorrected remainder, placing them in buf.0 and buf.1 + self.sqrt_rem_normalized::(buf.0, buf.1); + + // Copy the shifted root to self and correct it + self.copy_from(buf.0.leading(self.nlimbs())); + self.shr_assign(s_shift); + + // Set root to zero if self was zero + self.limbs[0] = Limb::select(self.limbs[0], Limb::ZERO, is_zero); + + // Check if there is a remainder + buf.1.is_zero() + } + } + } + + #[inline] + pub(crate) const fn sqrt_assign_vartime(&mut self, buf: (&mut UintRef, &mut UintRef)) -> bool { + let bits = self.bits_vartime(); + if bits == 0 { + return true; + } + + let words = bits.div_ceil(Limb::BITS); + assert!(buf.0.nlimbs() >= words as usize && buf.1.nlimbs() >= words as usize); + + // Only consider the populated limbs + let out = self.leading_mut(words as usize); + + if words <= 2 { + // No shifts needed + sqrt_rem_small::(out.as_limbs(), buf.0.as_mut_limbs(), buf.1.as_mut_limbs()); + out.copy_from(buf.0.leading(out.nlimbs())); + } else { + // Shift such that at least one of the top two bits are non-zero + let lz = words * Limb::BITS - bits; + let s_shift = lz >> 1; + out.shl_assign_limb_vartime(s_shift * 2); + + // Compute root and uncorrected remainder, placing them in buf.0 and buf.1 + out.sqrt_rem_normalized::(buf.0, buf.1); + + // Copy the shifted root to self and correct it + out.copy_from(buf.0.leading(out.nlimbs())); + out.shr_assign_limb_vartime(s_shift); + } + + // Check if there is a remainder + buf.1.is_zero_vartime() + } + + /// Corresponds to Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.12 + #[allow(clippy::cast_possible_truncation)] + const fn sqrt_rem_normalized(&self, s: &mut UintRef, r: &mut UintRef) { + let len = self.nlimbs(); + + // Handle base case of a square root < 4 limbs + // Unlike the source material, we do not handle 4 limbs in the base case sqrt function + if len < 4 { + sqrt_rem_small::(self.as_limbs(), s.as_mut_limbs(), r.as_mut_limbs()); + return; + } + + let l = len >> 2; + let rt_len = (len - l * 2).div_ceil(2); + let (a_lo, a_hi) = self.split_at(l * 2); + let (a0, a1) = a_lo.split_at(l); + + // (s', r') = SqrtRem(a3•B + a2), leaving `l` spare low limbs in each + a_hi.sqrt_rem_normalized::(s.trailing_mut(l), r.trailing_mut(l)); + + // Split up buffers + let (s, s_tail) = s.split_at_mut(l + rt_len); + let (s_lo, s_hi) = s.split_at_mut(l); + let (r, r_tail) = r.split_at_mut(l + rt_len + 1); + + // Set r = r'B + a1 + r.leading_mut(l).copy_from(a1); + // Set u = s' + let u = s_tail.leading_mut(rt_len); + u.copy_from(s_hi); + + // We wish to divide r/2s', setting r to the quotient and u to the remainder, but + // 2s' doesn't fit within the available buffer. Compute r/s' and adjust the result + r.div_rem(u); + // Adjust the result for divisor 2s' + let r_mod2 = r.shr1_assign(); + let u_hi = u.conditional_add_assign(s_hi, Limb::ZERO, r_mod2.is_nonzero()); + + let (r_lo, r_hi) = r.split_at_mut(l); + + // s = s'B + q + s_lo.copy_from(r_lo); + let q_hi = r_hi.limbs[0]; + s_hi.add_assign_limb(q_hi); + debug_assert!(q_hi.0 & !1 == 0); + + // r = uB + a0 + r_lo.copy_from(a0); + r_hi.leading_mut(rt_len).copy_from(u); + r_hi.limbs[rt_len] = u_hi; + + // Compute q^2 + let q2 = s_tail.leading_mut(l * 2); + q2.fill(Limb::ZERO); + s_lo.wrapping_square(q2); + + // r -= q^2, producing a borrow if r < 0 + let (r0, r1) = r.split_at_mut(l * 2); + let borrow = r0.borrowing_sub_assign(q2, Limb::ZERO); + let swap = r1.borrowing_sub_assign_limb(q_hi, borrow).lsb_to_choice(); + let s_carry = Limb::select(Limb::ZERO, Limb::ONE, swap); + + // s -= 1 if r < 0 + s.borrowing_sub_assign_limb(s_carry, Limb::ZERO); + + // r += 2s + 1 if r < 0 + let s_mul = s_carry.shl(1); + r.carrying_add_assign_mul_limb(s, s_mul, s_carry); + + // Clear upper limbs of the result buffers + s_tail.fill(Limb::ZERO); + r_tail.fill(Limb::ZERO); + } +} + +/// Compute the square root and remainder of a [`Word`]. +/// +/// Adapted from Hacker's Delight 2E by Henry S. Warren Jr., +/// Fig. 11-4, "Integer square root, hardware algorithm" +/// based on Toepler's algorithm. +#[inline] +const fn sqrt_rem_word(value: Word) -> (Word, Word) { + let mut m = 1 << (Word::BITS - 2); + let mut x = value; + let mut y = 0; + while m != 0 { + let b = y | m; + y >>= 1; + let mask = word::choice_to_mask(word::choice_from_le(b, x)); + x = x.wrapping_sub(b & mask); + y |= m & mask; + m >>= 2; + } + (y, x) +} + +/// Compute the root and remainder for `n+1` limbs, given the root and remainder +/// for `n` limbs (n=1 or 2). The input value must be normalized such that at least one of the +/// top two bits is set, producing an `s1` with at least `Word::BITS/2` bits. +/// +/// This is essentially a specialized version of the root expansion portion of `sqrt_rem_normalized`. +#[allow(clippy::cast_possible_truncation)] +#[inline] +const fn sqrt_rem_expand_by_word(s1: Word, r1: WideWord, next: Word) -> (WideWord, WideWord) { + const HALF_WIDTH: u32 = Word::BITS >> 1; + debug_assert!((s1 >> (HALF_WIDTH - 1)) > 0); + + // Split the lower word into lower and upper halves + let (a0, a1) = ( + (next & ((1 << HALF_WIDTH) - 1)) as WideWord, + (next >> HALF_WIDTH) as WideWord, + ); + let d = (r1 << HALF_WIDTH) | a1; + + // Divide by (r1B + a1) by s1 (not 2s1 which could overflow a limb) + let (q, _) = Uint::<2>::from_wide_word(d).div_rem_limb(NonZero::::new_unwrap(Limb(s1))); + // Correct the quotient + let q = (q.limbs[0].0 >> 1) as WideWord; + // Recompute the remainder u + let u = d - (q << 1) * s1 as WideWord; + // Set s = s1B + q + let s = ((s1 as WideWord) << HALF_WIDTH) + q; + + // Set r' = uB + a0 + let r_pre = (u << HALF_WIDTH) | a0; + let q2 = q.pow(2); + let swap = word::choice_to_wide_mask(word::choice_from_wide_lt(r_pre, q2)); + + // Subtract 1 from s if r' < q^2 + let s = s - (swap & 1); + // Set r = r' - q2, adding back 2s + 1 if r would be negative + let r = r_pre.wrapping_add(((s << 1) | 1) & swap) - q2; + + (s, r) +} + +/// Compute the square root and remainder for a 1 to 3 limb input, which must be normalized. +/// This is the base case square root calculation. +#[allow(clippy::cast_possible_truncation)] +#[inline] +const fn sqrt_rem_small(value: &[Limb], s: &mut [Limb], r: &mut [Limb]) { + let len = value.len(); + assert!(len < 4, "value exceeds maximum size for sqrt_rem_small"); + + if len == 0 { + return; + } + + if len == 1 { + let base = value[0].0; + (s[0].0, r[0].0) = if VARTIME { + let root = base.isqrt(); + (root, base.wrapping_sub(root.wrapping_pow(2))) + } else { + sqrt_rem_word(base) + }; + return; + } + + // Compute root, remainder for two words + let (s1, r1) = { + if VARTIME { + let base = word::join(value[len - 2].0, value[len - 1].0); + let root = base.isqrt(); + (root as Word, base.wrapping_sub(root.wrapping_pow(2))) + } else { + let (s0, r0) = sqrt_rem_word(value[len - 1].0); + let (s1, r1) = sqrt_rem_expand_by_word(s0, r0 as WideWord, value[len - 2].0); + (s1 as Word, r1) + } + }; + if len == 2 { + s[0].0 = s1; + (r[0].0, r[1].0) = word::split_wide(r1); + return; + } + + // Expand to root, remainder for three words + let (s1, r1) = sqrt_rem_expand_by_word(s1, r1, value[len - 3].0); + (s[0].0, s[1].0) = word::split_wide(s1); + (r[0].0, r[1].0) = word::split_wide(r1); +} + +#[cfg(test)] +mod tests { + use super::{sqrt_rem_expand_by_word, sqrt_rem_word}; + use crate::{WideWord, Word, word}; + + #[test] + fn sqrt_rem_word_expected() { + fn check(val: Word) { + let (root, rem) = sqrt_rem_word(val); + let check = root.pow(2) + rem; + assert_eq!(check, val, "val: {val}, root: {root}, rem: {rem}"); + } + + check(0); + check(1); + check(2); + check(Word::MAX); + } + + #[test] + fn sqrt_rem_expand_by_word_expected() { + fn check(val: WideWord) { + let (lo, hi) = word::split_wide(val); + let s1 = hi.isqrt(); + let r1 = WideWord::from(hi - s1.pow(2)); + let (s, r) = sqrt_rem_expand_by_word(s1, r1, lo); + assert_eq!(s, val.isqrt()); + assert_eq!(r, val - s.pow(2)); + } + + check(1 << (WideWord::BITS - 1)); + check(2 << (WideWord::BITS - 2)); + check(WideWord::MAX); + } +} diff --git a/src/uint/ref_type/sub.rs b/src/uint/ref_type/sub.rs index 926ed3b9..4335881f 100644 --- a/src/uint/ref_type/sub.rs +++ b/src/uint/ref_type/sub.rs @@ -17,7 +17,11 @@ impl UintRef { /// Perform an in-place borrowing subtraction of another [`UintRef`], returning the carried limb /// value. + /// + /// # Panics + /// If `self` is shorter than `rhs`. #[inline] + #[track_caller] pub const fn borrowing_sub_assign(&mut self, rhs: &Self, borrow: Limb) -> Limb { self.borrowing_sub_assign_slice(&rhs.limbs, borrow) } @@ -26,36 +30,98 @@ impl UintRef { /// value. /// /// # Panics - /// If `self` and `rhs` have different lengths. + /// If `self` is shorter than `rhs`. #[inline] + #[track_caller] pub const fn borrowing_sub_assign_slice(&mut self, rhs: &[Limb], mut borrow: Limb) -> Limb { - assert!(self.limbs.len() == rhs.len(), "length mismatch"); + assert!( + self.limbs.len() >= rhs.len(), + "length mismatch in borrowing_sub_assign_slice" + ); let mut i = 0; - while i < self.limbs.len() { + while i < rhs.len() { (self.limbs[i], borrow) = self.limbs[i].borrowing_sub(rhs[i], borrow); i += 1; } + while i < self.limbs.len() { + (self.limbs[i], borrow) = self.limbs[i].borrowing_sub(Limb::ZERO, borrow); + i += 1; + } borrow } /// Perform in-place wrapping subtraction, returning the truthy value as the second element of /// the tuple if an underflow has occurred. - pub(crate) fn conditional_borrowing_sub_assign( + /// + /// # Panics + /// If `self` is shorter than `rhs`. + #[track_caller] + pub(crate) const fn conditional_borrowing_sub_assign( &mut self, rhs: &Self, choice: Choice, ) -> Choice { - debug_assert!(self.bits_precision() <= rhs.bits_precision()); + assert!( + self.limbs.len() >= rhs.limbs.len(), + "length mismatch in conditional_borrowing_sub_assign" + ); let mask = Limb::select(Limb::ZERO, Limb::MAX, choice); let mut borrow = Limb::ZERO; - for i in 0..self.nlimbs() { - let masked_rhs = *rhs.limbs.get(i).unwrap_or(&Limb::ZERO) & mask; - let (limb, b) = self.limbs[i].borrowing_sub(masked_rhs, borrow); - self.limbs[i] = limb; - borrow = b; + let mut i = 0; + while i < rhs.limbs.len() { + let masked_rhs = rhs.limbs[i].bitand(mask); + (self.limbs[i], borrow) = self.limbs[i].borrowing_sub(masked_rhs, borrow); + i += 1; + } + while i < self.limbs.len() { + (self.limbs[i], borrow) = self.limbs[i].borrowing_sub(Limb::ZERO, borrow); + i += 1; } borrow.lsb_to_choice() } } + +#[cfg(test)] +mod tests { + use crate::{Choice, Limb, UintRef, Unsigned}; + + #[test] + fn borrowing_sub_assign_mixed() { + let mut a = [Limb::MAX]; + let borrow = + UintRef::new_mut(&mut a).borrowing_sub_assign(Limb::MAX.as_uint_ref(), Limb::MAX); + assert_eq!((a, borrow), ([Limb::MAX], Limb::MAX)); + + let mut a = [Limb::MAX]; + let borrow = UintRef::new_mut(&mut a) + .conditional_borrowing_sub_assign(Limb::MAX.as_uint_ref(), Choice::FALSE); + assert_eq!((a, borrow.to_bool()), ([Limb::MAX], false)); + + let mut a = [Limb::MAX - Limb::ONE]; + let borrow = UintRef::new_mut(&mut a) + .conditional_borrowing_sub_assign(Limb::MAX.as_uint_ref(), Choice::TRUE); + assert_eq!((a, borrow.to_bool()), ([Limb::MAX], true)); + + let mut a = [Limb::MAX - Limb::ONE, Limb::ONE]; + let borrow = + UintRef::new_mut(&mut a).borrowing_sub_assign(Limb::MAX.as_uint_ref(), Limb::ZERO); + assert_eq!((a, borrow), ([Limb::MAX, Limb::ZERO], Limb::ZERO)); + + let mut a = [Limb::MAX - Limb::ONE, Limb::ZERO]; + let borrow = + UintRef::new_mut(&mut a).borrowing_sub_assign(Limb::MAX.as_uint_ref(), Limb::ZERO); + assert_eq!((a, borrow), ([Limb::MAX, Limb::MAX], Limb::MAX)); + + let mut a = [Limb::MAX - Limb::ONE, Limb::ONE]; + let borrow = UintRef::new_mut(&mut a) + .conditional_borrowing_sub_assign(Limb::MAX.as_uint_ref(), Choice::TRUE); + assert_eq!((a, borrow.to_bool()), ([Limb::MAX, Limb::ZERO], false)); + + let mut a = [Limb::MAX - Limb::ONE, Limb::ZERO]; + let borrow = UintRef::new_mut(&mut a) + .conditional_borrowing_sub_assign(Limb::MAX.as_uint_ref(), Choice::TRUE); + assert_eq!((a, borrow.to_bool()), ([Limb::MAX, Limb::MAX], true)); + } +} diff --git a/src/uint/sqrt.rs b/src/uint/sqrt.rs index 6a9ed921..ba64c89c 100644 --- a/src/uint/sqrt.rs +++ b/src/uint/sqrt.rs @@ -1,6 +1,8 @@ //! [`Uint`] square root operations. -use crate::{CheckedSquareRoot, CtEq, CtOption, FloorSquareRoot, Limb, NonZero, Uint}; +use ctutils::Choice; + +use crate::{CheckedSquareRoot, CtOption, FloorSquareRoot, NonZero, Uint}; impl Uint { /// Computes `floor(√(self))` in constant time. @@ -14,12 +16,13 @@ impl Uint { /// Computes `floor(√(self))` in constant time. /// - /// Callers can check if `self` is a square by squaring the result. + /// Callers can check if `self` is a square by squaring the result, or use + /// `checked_sqrt`. #[must_use] pub const fn floor_sqrt(&self) -> Self { - let (self_nz, self_is_nz) = self.to_nz_or_one(); - let root_nz = self_nz.floor_sqrt(); - Self::select(&Self::ZERO, root_nz.as_ref(), self_is_nz) + let mut root = *self; + root.floor_sqrt_assign(); + root } /// Computes `floor(√(self))`. @@ -40,11 +43,9 @@ impl Uint { /// Variable time with respect to `self`. #[must_use] pub const fn floor_sqrt_vartime(&self) -> Self { - if let Some(self_nz) = self.as_nz_vartime() { - self_nz.floor_sqrt_vartime().get_copy() - } else { - Self::ZERO - } + let mut root = *self; + root.floor_sqrt_assign_vartime(); + root } /// Wrapped sqrt is just `floor(√(self))`. @@ -69,121 +70,79 @@ impl Uint { /// only if the square root is exact. #[must_use] pub fn checked_sqrt(&self) -> CtOption { - let (self_nz, self_is_nz) = self.to_nz_or_one(); - self_nz - .checked_sqrt() - .map(|nz| Self::select(&Self::ZERO, nz.as_ref(), self_is_nz)) + let mut root = *self; + let exact = root.floor_sqrt_assign(); + CtOption::new(root, exact) } /// Perform checked sqrt, returning an [`Option`] which `is_some` /// only if the square root is exact. /// /// Variable time with respect to `self`. + #[must_use] pub fn checked_sqrt_vartime(&self) -> Option { - if let Some(self_nz) = self.as_nz_vartime() { - self_nz.checked_sqrt_vartime().map(NonZero::get) + let mut root = *self; + if root.floor_sqrt_assign_vartime() { + Some(root) } else { - Some(Self::ZERO) + None } } + + /// Assigns `floor(√(self))` to `self` and returns a [`Choice`] indicating + /// whether the square root is exact. + const fn floor_sqrt_assign(&mut self) -> Choice { + let mut buf = (Uint::::ZERO, Uint::::ZERO); + self.as_mut_uint_ref() + .sqrt_assign((buf.0.as_mut_uint_ref(), buf.1.as_mut_uint_ref())) + } + + /// Assigns `floor(√(self))` to `self` and returns a [`bool`] indicating + /// whether the square root is exact. + /// + /// Variable time with respect to `self`. + const fn floor_sqrt_assign_vartime(&mut self) -> bool { + let mut buf = (Uint::::ZERO, Uint::::ZERO); + self.as_mut_uint_ref() + .sqrt_assign_vartime((buf.0.as_mut_uint_ref(), buf.1.as_mut_uint_ref())) + } } impl NonZero> { /// Computes `floor(√(self))` in constant time. /// - /// Callers can check if `self` is a square by squaring the result. + /// Callers can check if `self` is a square by squaring the result, or + /// use `checked_sqrt`. #[must_use] pub const fn floor_sqrt(&self) -> Self { - // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13. - // - // See Hast, "Note on computation of integer square roots" - // for the proof of the sufficiency of the bound on iterations. - // https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf - - let rt_bits = self.as_ref().bits().div_ceil(2); - // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < 2^b`. - // Will not overflow since `b <= BITS`. - let mut x = Uint::::ZERO.set_bit_vartime(rt_bits, true); - // Compute `self.0 / x_0` by shifting. - let mut q = self.as_ref().shr(rt_bits); - // The first division has been performed. - let mut i = 1; - - loop { - // Calculate `x_{i+1} = floor((x_i + self_nz / x_i) / 2)`, leaving `x` unmodified - // if it would increase. - x = Uint::select(&x.wrapping_add(&q).shr1(), &x, Uint::lt(&x, &q)); - - // We repeat enough times to guarantee the result has stabilized. - // TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough. - i += 1; - if i >= Uint::::LOG2_BITS + 2 { - return x.to_nz().expect_copied("ensured non-zero"); - } - - (q, _) = self - .as_ref() - .div_rem(x.to_nz().expect_ref("ensured non-zero")); - } + NonZero::new_unchecked(self.as_ref().floor_sqrt()) } /// Computes `floor(√(self))`. /// - /// Callers can check if `self` is a square by squaring the result. + /// Callers can check if `self` is a square by squaring the result, or + /// use `checked_sqrt_vartime`. /// /// Variable time with respect to `self`. #[must_use] pub const fn floor_sqrt_vartime(&self) -> Self { - // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13 - - let bits = self.as_ref().bits_vartime(); - if bits <= Limb::BITS { - let rt = self.as_ref().limbs[0].0.isqrt(); - return Uint::from_word(rt) - .to_nz() - .expect_copied("ensured non-zero"); - } - let rt_bits = bits.div_ceil(2); - - // The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`. - // Will not overflow since `b <= BITS`. - let mut x = Uint::ZERO.set_bit_vartime(rt_bits, true); - // Compute `self / x_0` by shifting. - let mut q = self.as_ref().shr_vartime(rt_bits); - - loop { - // Terminate if `x_{i+1}` >= `x`. - if q.cmp_vartime(&x).is_ge() { - return x.to_nz().expect_copied("ensured non-zero"); - } - // Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)` - x = x.wrapping_add(&q).shr_vartime(1); - q = self - .as_ref() - .wrapping_div_vartime(x.to_nz().expect_ref("ensured non-zero")); - } + NonZero::new_unchecked(self.as_ref().floor_sqrt_vartime()) } /// Perform checked sqrt, returning a [`CtOption`] which `is_some` /// only if the square root is exact. #[must_use] pub fn checked_sqrt(&self) -> CtOption { - let r = self.floor_sqrt(); - let s = r.wrapping_square(); - CtOption::new(r, self.as_ref().ct_eq(&s)) + self.as_ref().checked_sqrt().map(NonZero::new_unchecked) } /// Perform checked sqrt, returning an [`Option`] which `is_some` /// only if the square root is exact. #[must_use] pub fn checked_sqrt_vartime(&self) -> Option { - let r = self.floor_sqrt_vartime(); - let s = r.wrapping_square(); - if self.as_ref().cmp_vartime(&s).is_eq() { - Some(r) - } else { - None - } + self.as_ref() + .checked_sqrt_vartime() + .map(NonZero::new_unchecked) } } @@ -238,9 +197,9 @@ mod tests { #[cfg(feature = "rand_core")] use { - crate::{Random, U512}, + crate::{CheckedAdd, CheckedSquareRoot, Concat, FloorSquareRoot, Random, RandomBits, Uint}, chacha20::ChaCha8Rng, - rand_core::{Rng, SeedableRng}, + rand_core::SeedableRng, }; #[test] @@ -344,33 +303,52 @@ mod tests { #[cfg(feature = "rand_core")] #[test] - fn fuzz() { - use crate::{CheckedSquareRoot, FloorSquareRoot}; - - let mut rng = ChaCha8Rng::from_seed([7u8; 32]); - for _ in 0..50 { - let t = u64::from(rng.next_u32()); - let s = U256::from(t); - let s2 = s.checked_square().unwrap(); - assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s); - assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s); - assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool()); - assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some()); - - if let Some(nz) = s2.to_nz().into_option() { - assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s); - assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s); - assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool()); - assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some()); + fn sqrt_fuzz() { + fn check_size() + where + Uint: Concat>, + { + let mut rng = ChaCha8Rng::from_seed([7u8; 32]); + for _ in 0..5000 { + let s = Uint::::random_bits(&mut rng, Uint::::BITS / 2); + let s2 = s.checked_square().unwrap(); + assert_eq!(FloorSquareRoot::floor_sqrt(&s2), s); + assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&s2), s); + assert!(CheckedSquareRoot::checked_sqrt(&s2).is_some().to_bool()); + assert!(CheckedSquareRoot::checked_sqrt_vartime(&s2).is_some()); + + if let Some(nz) = s2.to_nz().into_option() { + assert_eq!(FloorSquareRoot::floor_sqrt(&nz).get(), s); + assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&nz).get(), s); + assert!(CheckedSquareRoot::checked_sqrt(&nz).is_some().to_bool()); + assert!(CheckedSquareRoot::checked_sqrt_vartime(&nz).is_some()); + } + + if let Some(sx) = s2.checked_add(&Uint::ONE).into_option() { + assert_eq!(FloorSquareRoot::floor_sqrt(&sx), s); + assert_eq!(FloorSquareRoot::floor_sqrt_vartime(&sx), s); + assert!(CheckedSquareRoot::checked_sqrt(&sx).is_none().to_bool()); + assert!(CheckedSquareRoot::checked_sqrt_vartime(&sx).is_none()); + } + } + + for _ in 0..50 { + let s = Uint::::random_from_rng(&mut rng); + let mut s2 = Uint::::ZERO; + s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs); + assert_eq!(s.concatenating_square().floor_sqrt(), s2); + assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2); } } - for _ in 0..50 { - let s = U256::random_from_rng(&mut rng); - let mut s2 = U512::ZERO; - s2.limbs[..s.limbs.len()].copy_from_slice(&s.limbs); - assert_eq!(s.concatenating_square().floor_sqrt(), s2); - assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2); + check_size::<1, 2>(); + check_size::<2, 4>(); + cpubits::cpubits! { + 64 => { + check_size::<3, 6>(); + } } + check_size::<4, 8>(); + check_size::<12, 24>(); } } diff --git a/src/word.rs b/src/word.rs index a0e472cd..102dfe02 100644 --- a/src/word.rs +++ b/src/word.rs @@ -29,6 +29,12 @@ cpubits::cpubits! { pub(crate) const fn choice_from_wide_le(x: WideWord, y: WideWord) -> Choice { Choice::from_u64_le(x, y) } + + /// Returns the truthy value if `x < y` and the falsy value otherwise. + #[inline] + pub(crate) const fn choice_from_wide_lt(x: WideWord, y: WideWord) -> Choice { + Choice::from_u64_lt(x, y) + } } 64 => { /// Unsigned integer type that the [`Limb`][`crate::Limb`] newtype wraps. @@ -55,6 +61,11 @@ cpubits::cpubits! { Choice::from_u128_le(x, y) } + /// Returns the truthy value if `x < y` and the falsy value otherwise. + #[inline] + pub(crate) const fn choice_from_wide_lt(x: WideWord, y: WideWord) -> Choice { + Choice::from_u128_lt(x, y) + } } }