diff --git a/src/uint/ref_type.rs b/src/uint/ref_type.rs index 1f0a501d..11db152c 100644 --- a/src/uint/ref_type.rs +++ b/src/uint/ref_type.rs @@ -30,7 +30,6 @@ use { /// This type contains a limb slice which can be borrowed from either a [`Uint`] or [`BoxedUint`] and /// thus provides an abstraction for writing shared implementations. #[repr(transparent)] -#[derive(PartialEq, Eq)] pub struct UintRef { /// Inner limb array. Stored from least significant to most significant. // TODO(tarcieri): maintain an invariant of at least one limb? diff --git a/src/uint/ref_type/cmp.rs b/src/uint/ref_type/cmp.rs index 29d96949..1334dd0c 100644 --- a/src/uint/ref_type/cmp.rs +++ b/src/uint/ref_type/cmp.rs @@ -5,7 +5,7 @@ use core::{cmp::Ordering, mem::transmute}; use super::UintRef; -use crate::{Choice, Limb, word}; +use crate::{Choice, CtEq, Limb, word}; impl UintRef { /// Returns the truthy value if `self` is odd or the falsy value otherwise. @@ -117,7 +117,7 @@ impl UintRef { Ordering::Equal } - /// Returns the truthy value if `self < rhs` and the falsy value otherwise. + /// Returns the truthy value if `lhs < rhs` and the falsy value otherwise. #[inline(always)] pub(crate) const fn lt(lhs: &Self, rhs: &Self) -> Choice { let overlap = if lhs.nlimbs() < rhs.nlimbs() { @@ -140,6 +140,26 @@ impl UintRef { } } +impl Eq for UintRef {} + +impl Ord for UintRef { + fn cmp(&self, other: &Self) -> Ordering { + Self::cmp(self, other) + } +} + +impl PartialOrd for UintRef { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for UintRef { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + #[cfg(test)] mod tests { use core::cmp::Ordering; @@ -149,40 +169,47 @@ mod tests { #[test] fn cmp() { + fn check(a: &UintRef, b: &UintRef, ord: Ordering) { + assert_eq!(UintRef::cmp(a, b), ord); + assert_eq!(UintRef::cmp_vartime(a, b), ord); + assert_eq!(a.cmp(b), ord); + if ord == Ordering::Equal { + assert_eq!(a, b); + } else { + assert_ne!(a, b); + } + } + let z_small = UintRef::new(&[Limb::ZERO, Limb::ZERO]); let z_big = UintRef::new(&[Limb::ZERO, Limb::ZERO, Limb::ZERO]); let a = UintRef::new(&[Limb::ZERO, Limb::ZERO, Limb::ONE]); let b = UintRef::new(&[Limb::ONE, Limb::ZERO]); - assert_eq!(UintRef::cmp(z_small, z_big), Ordering::Equal); - assert_eq!(UintRef::cmp(z_big, z_small), Ordering::Equal); - assert_eq!(UintRef::cmp(z_small, a), Ordering::Less); - assert_eq!(UintRef::cmp(z_big, a), Ordering::Less); - assert_eq!(UintRef::cmp(a, z_small), Ordering::Greater); - assert_eq!(UintRef::cmp(a, z_big), Ordering::Greater); - assert_eq!(UintRef::cmp(a, b), Ordering::Greater); - assert_eq!(UintRef::cmp(b, a), Ordering::Less); - - assert_eq!(UintRef::cmp_vartime(z_small, z_big), Ordering::Equal); - assert_eq!(UintRef::cmp_vartime(z_big, z_small), Ordering::Equal); - assert_eq!(UintRef::cmp_vartime(z_small, a), Ordering::Less); - assert_eq!(UintRef::cmp_vartime(z_big, a), Ordering::Less); - assert_eq!(UintRef::cmp_vartime(a, z_small), Ordering::Greater); - assert_eq!(UintRef::cmp_vartime(a, z_big), Ordering::Greater); - assert_eq!(UintRef::cmp_vartime(a, b), Ordering::Greater); - assert_eq!(UintRef::cmp_vartime(b, a), Ordering::Less); + check(z_small, z_big, Ordering::Equal); + check(z_big, z_small, Ordering::Equal); + check(z_small, a, Ordering::Less); + check(z_big, a, Ordering::Less); + check(a, z_small, Ordering::Greater); + check(a, z_big, Ordering::Greater); + check(a, b, Ordering::Greater); + check(b, a, Ordering::Less); } #[test] fn lt() { + fn check(a: &UintRef, b: &UintRef) { + assert!(UintRef::lt(a, b).to_bool_vartime()); + assert!(!UintRef::lt(b, a).to_bool_vartime()); + assert!(a < b); + assert!(b > a); + } + let lesser = UintRef::new(&[Limb::ZERO, Limb::ZERO, Limb::ZERO]); let greater = UintRef::new(&[Limb::ZERO, Limb::ONE, Limb::ZERO]); - assert!(UintRef::lt(lesser, greater).to_bool()); - assert!(!UintRef::lt(greater, lesser).to_bool()); + check(lesser, greater); let smaller = UintRef::new(&[Limb::ZERO, Limb::ZERO]); let bigger = UintRef::new(&[Limb::ZERO, Limb::ZERO, Limb::ONE]); - assert!(UintRef::lt(smaller, bigger).to_bool()); - assert!(!UintRef::lt(bigger, smaller).to_bool()); + check(smaller, bigger); } }