Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions benches/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,14 @@ fn bench_sqrt(c: &mut Criterion) {
);
});

group.bench_function("floor_sqrt, one Limb", |b| {
b.iter_batched(
|| Uint::new([Limb::random_from_rng(&mut rng)]),
|x| x.floor_sqrt(),
BatchSize::SmallInput,
);
});

group.bench_function("floor_sqrt_vartime, U256 one Limb", |b| {
b.iter_batched(
|| U256::from_word(Limb::random_from_rng(&mut rng).0),
Expand Down
5 changes: 1 addition & 4 deletions src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub(crate) use ref_type::UintRef;
use crate::{
Bounded, Choice, ConstOne, ConstZero, Constants, CtEq, CtOption, EncodedUint, FixedInteger,
Int, Integer, Limb, NonZero, Odd, One, Unsigned, UnsignedWithMontyForm, Word, Zero, bitlen,
limb::nlimbs, modular::FixedMontyForm, primitives, traits::sealed::Sealed,
limb::nlimbs, modular::FixedMontyForm, traits::sealed::Sealed,
};
use core::fmt;

Expand Down Expand Up @@ -108,9 +108,6 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Total size of the represented integer in bits.
pub const BITS: u32 = bitlen::from_limbs(LIMBS);

/// `floor(log2(Self::BITS))`.
pub(crate) const LOG2_BITS: u32 = primitives::u32_bits(Self::BITS) - 1;

/// Total size of the represented integer in bytes.
pub const BYTES: usize = LIMBS * Limb::BYTES;

Expand Down
190 changes: 92 additions & 98 deletions src/uint/boxed/sqrt.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
//! [`BoxedUint`] square root operations.

use crate::{
BitOps, BoxedUint, CheckedSquareRoot, ConcatenatingSquare, CtAssign, CtEq, CtGt, CtOption,
FloorSquareRoot, Limb,
};
use core::mem;
use crate::{BoxedUint, CheckedSquareRoot, Choice, CtOption, FloorSquareRoot, NonZero};

impl BoxedUint {
/// Computes `floor(√(self))` in constant time.
Expand All @@ -18,46 +14,13 @@ impl BoxedUint {

/// Computes √(`self`) in constant time.
///
/// Callers can check if `self` is a square by squaring the result.
/// Callers can check if `self` is a square by squaring the result, or use
/// `checked_sqrt`.
#[must_use]
pub fn floor_sqrt(&self) -> Self {
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13.
//
// See Hast, "Note on computation of integer square roots"
// for the proof of the sufficiency of the bound on iterations.
// https://github.com/RustCrypto/crypto-bigint/files/12600669/ct_sqrt.pdf

// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
// Will not overflow since `b <= BITS`.
let mut x = Self::one_with_precision(self.bits_precision());
x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`)

let mut nz_x = x.clone();
let mut quo = Self::zero_with_precision(self.bits_precision());
let mut rem = Self::zero_with_precision(self.bits_precision());
let mut i = 0;

// Repeat enough times to guarantee result has stabilized.
// TODO (#378): the tests indicate that just `Self::LOG2_BITS` may be enough.
while i < self.log2_bits() + 2 {
let x_nonzero = x.is_nonzero();
nz_x.ct_assign(&x, x_nonzero);

// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
quo.limbs.copy_from_slice(&self.limbs);
rem.limbs.copy_from_slice(&nz_x.limbs);
quo.as_mut_uint_ref().div_rem(rem.as_mut_uint_ref());
x.conditional_carrying_add_assign(&quo, x_nonzero);
x.shr1_assign();

i += 1;
}

// At this point `x_prev == x_{n}` and `x == x_{n+1}`
// where `n == i - 1 == LOG2_BITS + 1 == floor(log2(BITS)) + 1`.
// Thus, according to Hast, `sqrt(self) = min(x_n, x_{n+1})`.
x.ct_assign(&nz_x, x.ct_gt(&nz_x));
x
let mut root = self.clone();
root.floor_sqrt_assign();
root
}

/// Computes `floor(√(self))`.
Expand All @@ -73,47 +36,15 @@ impl BoxedUint {

/// Computes √(`self`).
///
/// Callers can check if `self` is a square by squaring the result.
/// Callers can check if `self` is a square by squaring the result, or use
/// `checked_sqrt_vartime`.
///
/// Variable time with respect to `self`.
#[must_use]
pub fn floor_sqrt_vartime(&self) -> Self {
// Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13

if self.is_zero_vartime() {
return Self::zero_with_precision(self.bits_precision());
}

// The initial guess: `x_0 = 2^ceil(b/2)`, where `2^(b-1) <= self < b`.
// Will not overflow since `b <= BITS`.
// The initial value of `x` is always greater than zero.
let mut x = Self::one_with_precision(self.bits_precision());
x.unbounded_shl_assign_vartime((self.bits() + 1) >> 1); // ≥ √(`self`)

let mut quo = Self::zero_with_precision(self.bits_precision());
let mut rem = Self::zero_with_precision(self.bits_precision());

loop {
// Calculate `x_{i+1} = floor((x_i + self / x_i) / 2)`
quo.limbs.copy_from_slice(&self.limbs);
rem.limbs.copy_from_slice(&x.limbs);
quo.as_mut_uint_ref().div_rem_vartime(rem.as_mut_uint_ref());
quo.carrying_add_assign(&x, Limb::ZERO);
quo.shr1_assign();

// If `quo` is the same as `x` or greater, we reached convergence
// (`x` is guaranteed to either go down or oscillate between
// `sqrt(self)` and `sqrt(self) + 1`)
if !x.cmp_vartime(&quo).is_gt() {
break;
}
x.limbs.copy_from_slice(&quo.limbs);
if x.is_zero_vartime() {
break;
}
}

x
let mut root = self.clone();
root.floor_sqrt_assign_vartime();
root
}

/// Wrapped sqrt is just `floor(√(self))`.
Expand All @@ -138,9 +69,9 @@ impl BoxedUint {
/// only if the square root is exact.
#[must_use]
pub fn checked_sqrt(&self) -> CtOption<Self> {
let r = self.floor_sqrt();
let s = r.wrapping_square();
CtOption::new(r, self.ct_eq(&s))
let mut root = self.clone();
let exact = root.floor_sqrt_assign();
CtOption::new(root, exact)
}

/// Perform checked sqrt, returning an [`Option`] which `is_some`
Expand All @@ -149,24 +80,32 @@ impl BoxedUint {
/// Variable time with respect to `self`.
#[must_use]
pub fn checked_sqrt_vartime(&self) -> Option<Self> {
let r = self.floor_sqrt_vartime();
let s = r.wrapping_square();
if self.cmp_vartime(&s).is_eq() {
Some(r)
let mut root = self.clone();
if root.floor_sqrt_assign_vartime() {
Some(root)
} else {
None
}
}

/// Assigns `floor(√(self))` to `self` in constant time, and returns a [`Choice`]
/// indicating whether the square root is exact.
fn floor_sqrt_assign(&mut self) -> Choice {
let size = self.nlimbs();
let mut buf = Self::zero_with_precision(self.bits_precision() * 2);
self.as_mut_uint_ref()
.sqrt_assign(buf.as_mut_uint_ref().split_at_mut(size))
}

/// Assigns `floor(√(self))` to `self`, and returns a [`bool`]
/// indicating whether the square root is exact.
///
/// Variable time with respect to `self`.
pub fn floor_sqrt_assign_vartime(&mut self) -> bool {
// TODO(tarcieri): more optimized implementation
let mut ret = self.floor_sqrt_vartime();
mem::swap(&mut ret, self);
self.concatenating_square() == ret
let size = self.nlimbs();
let mut buf = Self::zero_with_precision(self.bits_precision() * 2);
self.as_mut_uint_ref()
.sqrt_assign_vartime(buf.as_mut_uint_ref().split_at_mut(size))
}
}

Expand All @@ -192,14 +131,38 @@ impl FloorSquareRoot for BoxedUint {
}
}

impl CheckedSquareRoot for NonZero<BoxedUint> {
type Output = Self;

fn checked_sqrt(&self) -> CtOption<Self> {
self.as_ref().checked_sqrt().map(NonZero::new_unchecked)
}

fn checked_sqrt_vartime(&self) -> Option<Self> {
self.as_ref()
.checked_sqrt_vartime()
.map(NonZero::new_unchecked)
}
}

impl FloorSquareRoot for NonZero<BoxedUint> {
fn floor_sqrt(&self) -> Self {
NonZero::new_unchecked(self.as_ref().floor_sqrt())
}

fn floor_sqrt_vartime(&self) -> Self {
NonZero::new_unchecked(self.as_ref().floor_sqrt_vartime())
}
}

#[cfg(test)]
#[allow(clippy::integer_division_remainder_used, reason = "test")]
mod tests {
use crate::{BoxedUint, Limb};
use crate::{BoxedUint, CheckedSquareRoot, FloorSquareRoot, Limb};

#[cfg(feature = "rand_core")]
use {
crate::RandomBits,
crate::{ConcatenatingSquare, RandomBits},
chacha20::ChaCha8Rng,
rand_core::{Rng, SeedableRng},
};
Expand All @@ -218,7 +181,7 @@ mod tests {
for i in 0..half.limbs.len() / 2 {
half.limbs[i] = Limb::MAX;
}
let u256_max = !BoxedUint::zero_with_precision(256);
let u256_max = BoxedUint::max(256);
assert_eq!(u256_max.floor_sqrt(), half);

// Test edge cases that use up the maximum number of iterations.
Expand Down Expand Up @@ -307,8 +270,26 @@ mod tests {
let r = BoxedUint::from(*e);
assert_eq!(l.floor_sqrt(), r);
assert_eq!(l.floor_sqrt_vartime(), r);
assert!(l.checked_sqrt().is_some().to_bool());
assert!(l.checked_sqrt_vartime().is_some());
assert_eq!(
CheckedSquareRoot::checked_sqrt(&l).into_option().as_ref(),
Some(&r)
);
assert_eq!(
CheckedSquareRoot::checked_sqrt_vartime(&l).as_ref(),
Some(&r)
);
let nz_l = l.as_nz_vartime().unwrap();
let nz_r = r.to_nz().unwrap();
assert_eq!(FloorSquareRoot::floor_sqrt(nz_l), nz_r);
assert_eq!(FloorSquareRoot::floor_sqrt_vartime(nz_l), nz_r);
assert_eq!(
CheckedSquareRoot::checked_sqrt(nz_l).into_option().as_ref(),
Some(&nz_r)
);
assert_eq!(
CheckedSquareRoot::checked_sqrt_vartime(nz_l).as_ref(),
Some(&nz_r)
);
}
}

Expand Down Expand Up @@ -362,8 +343,6 @@ mod tests {
#[cfg(feature = "rand_core")]
#[test]
fn fuzz() {
use crate::{CheckedSquareRoot, ConcatenatingSquare};

let mut rng = ChaCha8Rng::from_seed([7u8; 32]);
let rounds = if cfg!(miri) { 10 } else { 50 };
for _ in 0..rounds {
Expand All @@ -384,4 +363,19 @@ mod tests {
assert_eq!(s.concatenating_square().floor_sqrt_vartime(), s2);
}
}

#[test]
// test input from issue #1303
fn sqrt_edge() {
let sq = BoxedUint::from_be_hex(
"0000000000000AB88E41D05334F646EE1B715C614C81DCE1C46E2E021EF2E7FA45E5DD2CC670725E52B091A0E08B1E9E6C43262581226A7BFDE534F704D6DC6C1FD5C93B3AE4FB73C9D26EBC09DB83040C2CD157884532EB5578CEF2BA95E87C7F6E328F062BB20D3C95D122EB4D26727901A424AF14143EBCEFE5D5958E9EDEA73649F5C76932D491C4A51C370F51E6DBEB66D56D7FB51A",
1216,
).unwrap();
let expect = BoxedUint::from_be_hex(
"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000346375360711FDFC13BE3CF276FCC535C2E29164A842D3B71E2C7DFA60AE0E24CD0CC65AD7A3B626E03DAC8260D921AB8E16C6757D2CA8DEE5DFA62E5C1CB3C90A49B9F1A192498905",
1216,
).unwrap();
assert_eq!(sq.floor_sqrt(), expect);
assert_eq!(sq.floor_sqrt_vartime(), expect);
}
}
1 change: 1 addition & 0 deletions src/uint/ref_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod mul;
mod shl;
mod shr;
mod slice;
mod sqrt;
mod sub;

use crate::{Choice, Limb, NonZero, Odd, Uint, Word};
Expand Down
Loading
Loading