From ab0268606dfb238b74dcaa0cc7624a0fe92d8fd5 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 11 Jan 2026 19:03:31 +0900 Subject: [PATCH] _bigint helpers --- README.md | 4 + src/math.rs | 4 + src/math/bigint.rs | 499 ++++++++++++++++++++++++++++++++++++++++++++ src/math/integer.rs | 47 +---- 4 files changed, 517 insertions(+), 37 deletions(-) create mode 100644 src/math/bigint.rs diff --git a/README.md b/README.md index a1320da..675785e 100644 --- a/README.md +++ b/README.md @@ -156,3 +156,7 @@ assert_eq!( - `pymath::math` - Real number math functions (Python's `math` module) - `pymath::cmath` - Complex number functions (Python's `cmath` module) - `pymath::m` - Direct libm bindings + +## Important Note + +This library guarantees **compatibility with Python's math module**, not mathematical correctness. diff --git a/src/math.rs b/src/math.rs index 9b40faf..2ea1e48 100644 --- a/src/math.rs +++ b/src/math.rs @@ -2,6 +2,8 @@ // Submodules mod aggregate; +#[cfg(feature = "_bigint")] +mod bigint; mod exponential; mod gamma; #[cfg(feature = "_bigint")] @@ -11,6 +13,8 @@ mod trigonometric; // Re-export from submodules pub use aggregate::{dist, fsum, prod, prod_int, sumprod, sumprod_int}; +#[cfg(feature = "_bigint")] +pub use bigint::{comb_bigint, ldexp_bigint, log_bigint, log2_bigint, log10_bigint, perm_bigint}; pub use exponential::{cbrt, exp, exp2, expm1, log, log1p, log2, log10, pow, sqrt}; pub use gamma::{erf, erfc, gamma, lgamma}; pub use misc::{ diff --git a/src/math/bigint.rs b/src/math/bigint.rs new file mode 100644 index 0000000..2ba4488 --- /dev/null +++ b/src/math/bigint.rs @@ -0,0 +1,499 @@ +//! BigInt helper functions for RustPython. +//! +//! These functions handle cases where Python integers exceed i64 range. +//! They are not part of Python's math.integer module but are needed +//! for RustPython's internal implementation. + +use super::integer::perm_comb_small; +#[cfg(feature = "malachite-bigint")] +pub(crate) use malachite_bigint::{BigInt, BigUint}; +#[cfg(feature = "num-bigint")] +pub(crate) use num_bigint::{BigInt, BigUint}; +use num_traits::{One, Signed, ToPrimitive}; + +// Perm/Comb for BigInt n + +/// Compute perm(n, k) where n is a BigInt and k fits in u64. +/// Uses divide-and-conquer: P(n, k) = P(n, j) * P(n-j, k-j) +/// +/// See: perm_comb in CPython mathmodule.c +pub fn perm_bigint(n: &BigInt, k: u64) -> BigUint { + if k == 0 { + return BigUint::one(); + } + if k == 1 { + return n.magnitude().clone(); + } + + // P(n, k) = P(n, j) * P(n-j, k-j) + let j = k / 2; + let a = perm_bigint(n, j); + let n_minus_j = n - BigInt::from(j); + let b = perm_bigint(&n_minus_j, k - j); + a * b +} + +/// Compute comb(n, k) where n is a BigInt and k fits in u64. +/// Uses divide-and-conquer: C(n, k) = C(n, j) * C(n-j, k-j) / C(k, j) +/// +/// See: perm_comb in CPython mathmodule.c +pub fn comb_bigint(n: &BigInt, k: u64) -> BigUint { + if k == 0 { + return BigUint::one(); + } + if k == 1 { + return n.magnitude().clone(); + } + + // C(n, k) = C(n, j) * C(n-j, k-j) / C(k, j) + let j = k / 2; + let a = comb_bigint(n, j); + let n_minus_j = n - BigInt::from(j); + let b = comb_bigint(&n_minus_j, k - j); + let numerator = a * b; + // C(k, j) using small version since k fits in u64 + let divisor = perm_comb_small(k, j, true); + numerator / divisor +} + +// Logarithm functions for BigInt + +/// Compute frexp-like decomposition for BigInt. +/// Returns (mantissa, exponent) where: +/// - mantissa is in [0.5, 1.0) for positive n +/// - n ~= mantissa * 2^exponent +/// +/// See: _PyLong_Frexp in CPython longobject.c +fn frexp_bigint(n: &BigInt) -> (f64, i64) { + let bits = n.bits(); + if bits == 0 { + return (0.0, 0); + } + + let bits = bits as i64; + + // For small integers that fit in f64 mantissa (53 bits) + if bits <= 53 { + let x = n.to_f64().unwrap(); + // frexp returns (m, e) where x = m * 2^e and 0.5 <= |m| < 1 + let mut e: i32 = 0; + let m = crate::m::frexp(x, &mut e); + return (m, e as i64); + } + + // For large integers, extract top ~53 bits + // Shift right to keep DBL_MANT_DIG + 2 = 55 bits for rounding + let shift = bits - 55; + let mantissa_int = n >> shift as u64; + let mut x = mantissa_int.to_f64().unwrap(); + + // x is now approximately n / 2^shift, with ~55 bits of precision + // Scale to [0.5, 1.0) range + // x is in [2^54, 2^55), so divide by 2^55 to get [0.5, 1.0) + x /= (1u64 << 55) as f64; + + // Adjust if rounding pushed us to 1.0 + if x == 1.0 { + x = 0.5; + return (x, bits + 1); + } + + (x, bits) +} + +/// Return the natural logarithm of a BigInt. +/// +/// Returns Err(EDOM) if n is not positive. +pub fn log_bigint(n: &BigInt, base: Option) -> crate::Result { + if !n.is_positive() { + return Err(crate::Error::EDOM); + } + + // Try direct conversion first + if let Some(x) = n.to_f64() { + if x.is_finite() { + return super::log(x, base); + } + } + + // Use frexp decomposition for large values + // n ~= x * 2^e, so log(n) = log(x) + log(2) * e + let (x, e) = frexp_bigint(n); + let log_n = crate::m::log(x) + std::f64::consts::LN_2 * (e as f64); + + match base { + None => Ok(log_n), + Some(b) => { + if b <= 0.0 || b == 1.0 { + return Err(crate::Error::EDOM); + } + Ok(log_n / crate::m::log(b)) + } + } +} + +/// Return the base-2 logarithm of a BigInt. +/// +/// Returns Err(EDOM) if n is not positive. +pub fn log2_bigint(n: &BigInt) -> crate::Result { + if !n.is_positive() { + return Err(crate::Error::EDOM); + } + + // Try direct conversion first + if let Some(x) = n.to_f64() { + if x.is_finite() { + return super::log2(x); + } + } + + // Use frexp decomposition for large values + // n ~= x * 2^e, so log2(n) = log2(x) + e + let (x, e) = frexp_bigint(n); + Ok(crate::m::log2(x) + (e as f64)) +} + +/// Return the base-10 logarithm of a BigInt. +/// +/// Returns Err(EDOM) if n is not positive. +pub fn log10_bigint(n: &BigInt) -> crate::Result { + if !n.is_positive() { + return Err(crate::Error::EDOM); + } + + // Try direct conversion first + if let Some(x) = n.to_f64() { + if x.is_finite() { + return super::log10(x); + } + } + + // Use frexp decomposition for large values + // n ~= x * 2^e, so log10(n) = log10(x) + log10(2) * e + let (x, e) = frexp_bigint(n); + Ok(crate::m::log10(x) + std::f64::consts::LOG10_2 * (e as f64)) +} + +/// Compute ldexp(x, exp) where exp is a BigInt. +/// +/// Returns x * 2^exp, handling BigInt exponent overflow: +/// - If x is 0, inf, or nan, returns x unchanged +/// - If exp overflows i32 positively, returns ERANGE (overflow) +/// - If exp overflows i32 negatively, returns signed zero (underflow) +pub fn ldexp_bigint(x: f64, exp: &BigInt) -> crate::Result { + // Special values are returned unchanged regardless of exponent + if x == 0.0 || !x.is_finite() { + return Ok(x); + } + + // Fast path: try i64 first (like CPython's PyLong_AsLongAndOverflow) + let exp_clamped: i64 = match exp.try_into() { + Ok(e) => e, + Err(_) => { + // Exponent overflows i64, clamp to i64::MIN/MAX + if exp.is_negative() { + i64::MIN + } else { + i64::MAX + } + } + }; + + // Check against i32 bounds + if exp_clamped > i32::MAX as i64 { + // overflow + Err(crate::Error::ERANGE) + } else if exp_clamped < i32::MIN as i64 { + // underflow to signed zero + Ok(if x.is_sign_negative() { -0.0 } else { 0.0 }) + } else { + super::ldexp(x, exp_clamped as i32) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pyo3::prelude::*; + + fn test_log_bigint_impl(n: &BigInt, base: Option) { + let rs = log_bigint(n, base); + crate::test::with_py_math(|py, math| { + // Convert BigInt to Python int via string using builtins.int() + let n_str = n.to_string(); + let builtins = pyo3::types::PyModule::import(py, "builtins").unwrap(); + let py_n = builtins + .getattr("int") + .unwrap() + .call1((n_str.as_str(),)) + .unwrap(); + let py_result = match base { + Some(b) => math.getattr("log").unwrap().call1((py_n, b)), + None => math.getattr("log").unwrap().call1((py_n,)), + }; + match py_result { + Ok(py_val) => { + let py_f: f64 = py_val.extract().unwrap(); + let rs_f = rs.unwrap(); + if py_f.is_nan() && rs_f.is_nan() { + return; + } + // Handle exact equality (e.g., log(1) = 0) + if py_f == rs_f { + return; + } + // Allow small relative error for large numbers + let rel_err = ((py_f - rs_f) / py_f).abs(); + assert!( + rel_err < 1e-10, + "log_bigint({n}, {base:?}): py={py_f} vs rs={rs_f}, rel_err={rel_err}" + ); + } + Err(_) => { + assert!( + rs.is_err(), + "log_bigint({n}, {base:?}): py raised error but rs={rs:?}" + ); + } + } + }); + } + + fn test_log2_bigint_impl(n: &BigInt) { + let rs = log2_bigint(n); + crate::test::with_py_math(|py, math| { + let n_str = n.to_string(); + let builtins = pyo3::types::PyModule::import(py, "builtins").unwrap(); + let py_n = builtins + .getattr("int") + .unwrap() + .call1((n_str.as_str(),)) + .unwrap(); + let py_result = math.getattr("log2").unwrap().call1((py_n,)); + match py_result { + Ok(py_val) => { + let py_f: f64 = py_val.extract().unwrap(); + let rs_f = rs.unwrap(); + if py_f.is_nan() && rs_f.is_nan() { + return; + } + if py_f == rs_f { + return; + } + let rel_err = ((py_f - rs_f) / py_f).abs(); + assert!( + rel_err < 1e-10, + "log2_bigint({n}): py={py_f} vs rs={rs_f}, rel_err={rel_err}" + ); + } + Err(_) => { + assert!( + rs.is_err(), + "log2_bigint({n}): py raised error but rs={rs:?}" + ); + } + } + }); + } + + fn test_log10_bigint_impl(n: &BigInt) { + let rs = log10_bigint(n); + crate::test::with_py_math(|py, math| { + let n_str = n.to_string(); + let builtins = pyo3::types::PyModule::import(py, "builtins").unwrap(); + let py_n = builtins + .getattr("int") + .unwrap() + .call1((n_str.as_str(),)) + .unwrap(); + let py_result = math.getattr("log10").unwrap().call1((py_n,)); + match py_result { + Ok(py_val) => { + let py_f: f64 = py_val.extract().unwrap(); + let rs_f = rs.unwrap(); + if py_f.is_nan() && rs_f.is_nan() { + return; + } + if py_f == rs_f { + return; + } + let rel_err = ((py_f - rs_f) / py_f).abs(); + assert!( + rel_err < 1e-10, + "log10_bigint({n}): py={py_f} vs rs={rs_f}, rel_err={rel_err}" + ); + } + Err(_) => { + assert!( + rs.is_err(), + "log10_bigint({n}): py raised error but rs={rs:?}" + ); + } + } + }); + } + + #[test] + fn edgetest_log_bigint() { + // Small values + for i in -10i64..=100 { + let n = BigInt::from(i); + test_log_bigint_impl(&n, None); + test_log_bigint_impl(&n, Some(2.0)); + test_log_bigint_impl(&n, Some(10.0)); + } + // Powers of 2 (important for log2) + for exp in 0..100u32 { + let n = BigInt::from(1u64) << exp; + test_log_bigint_impl(&n, None); + test_log_bigint_impl(&n, Some(2.0)); + } + // Large values + let ten = BigInt::from(10); + for exp in [100, 200, 500, 1000] { + let n = ten.pow(exp); + test_log_bigint_impl(&n, None); + test_log_bigint_impl(&n, Some(2.0)); + test_log_bigint_impl(&n, Some(10.0)); + } + } + + #[test] + fn edgetest_log2_bigint() { + for i in -10i64..=100 { + test_log2_bigint_impl(&BigInt::from(i)); + } + for exp in 0..100u32 { + test_log2_bigint_impl(&(BigInt::from(1u64) << exp)); + } + let ten = BigInt::from(10); + for exp in [100, 200, 500, 1000] { + test_log2_bigint_impl(&ten.pow(exp)); + } + } + + #[test] + fn edgetest_log10_bigint() { + for i in -10i64..=100 { + test_log10_bigint_impl(&BigInt::from(i)); + } + let ten = BigInt::from(10); + for exp in [10, 50, 100, 200, 500, 1000] { + test_log10_bigint_impl(&ten.pow(exp)); + } + } + + proptest::proptest! { + #[test] + fn proptest_log_bigint(n in 1i64..1_000_000i64) { + test_log_bigint_impl(&BigInt::from(n), None); + } + + #[test] + fn proptest_log_bigint_base2(n in 1i64..1_000_000i64) { + test_log_bigint_impl(&BigInt::from(n), Some(2.0)); + } + + #[test] + fn proptest_log_bigint_base10(n in 1i64..1_000_000i64) { + test_log_bigint_impl(&BigInt::from(n), Some(10.0)); + } + + #[test] + fn proptest_log2_bigint(n in 1i64..1_000_000i64) { + test_log2_bigint_impl(&BigInt::from(n)); + } + + #[test] + fn proptest_log10_bigint(n in 1i64..1_000_000i64) { + test_log10_bigint_impl(&BigInt::from(n)); + } + } + + // ldexp_bigint tests + + fn test_ldexp_bigint_impl(x: f64, exp: &BigInt) { + let rs = ldexp_bigint(x, exp); + crate::test::with_py_math(|py, math| { + let exp_str = exp.to_string(); + let builtins = pyo3::types::PyModule::import(py, "builtins").unwrap(); + let py_exp = builtins + .getattr("int") + .unwrap() + .call1((exp_str.as_str(),)) + .unwrap(); + let py_result = math.getattr("ldexp").unwrap().call1((x, py_exp)); + match py_result { + Ok(py_val) => { + let py_f: f64 = py_val.extract().unwrap(); + let rs_f = rs.unwrap(); + // Handle NaN + if py_f.is_nan() && rs_f.is_nan() { + return; + } + // Handle exact equality (including signed zeros and infinities) + if py_f == rs_f && py_f.is_sign_positive() == rs_f.is_sign_positive() { + return; + } + // Handle infinities + if py_f.is_infinite() && rs_f.is_infinite() { + assert_eq!( + py_f.is_sign_positive(), + rs_f.is_sign_positive(), + "ldexp_bigint({x}, {exp}): sign mismatch py={py_f} vs rs={rs_f}" + ); + return; + } + panic!("ldexp_bigint({x}, {exp}): py={py_f} vs rs={rs_f}"); + } + Err(_) => { + assert!( + rs.is_err(), + "ldexp_bigint({x}, {exp}): py raised error but rs={rs:?}" + ); + } + } + }); + } + + #[test] + fn edgetest_ldexp_bigint() { + let x_vals = [ + 0.0, + -0.0, + 1.0, + -1.0, + 0.5, + -0.5, + 2.0, + -2.0, + f64::INFINITY, + f64::NEG_INFINITY, + f64::NAN, + ]; + let exp_vals: Vec = vec![ + BigInt::from(0), + BigInt::from(1), + BigInt::from(-1), + BigInt::from(10), + BigInt::from(-10), + BigInt::from(100), + BigInt::from(-100), + BigInt::from(1000), + BigInt::from(-1000), + BigInt::from(i32::MAX), + BigInt::from(i32::MIN), + // Overflow cases + BigInt::from(i32::MAX as i64 + 1), + BigInt::from(i32::MIN as i64 - 1), + BigInt::from(10).pow(20), + -BigInt::from(10).pow(20), + ]; + + for &x in &x_vals { + for exp in &exp_vals { + test_ldexp_bigint_impl(x, exp); + } + } + } +} diff --git a/src/math/integer.rs b/src/math/integer.rs index 521997d..cf08eb9 100644 --- a/src/math/integer.rs +++ b/src/math/integer.rs @@ -3,11 +3,7 @@ //! Integer-related mathematical functions. //! This module requires either `num-bigint` or `malachite-bigint` feature. -#[cfg(feature = "malachite-bigint")] -use malachite_bigint::{BigInt, BigUint}; -#[cfg(feature = "num-bigint")] -use num_bigint::{BigInt, BigUint}; - +use super::bigint::{BigInt, BigUint}; use num_integer::Integer; use num_traits::{One, Signed, ToPrimitive, Zero}; @@ -16,13 +12,13 @@ use num_traits::{One, Signed, ToPrimitive, Zero}; /// gcd() with no arguments returns 0. /// gcd(a) returns abs(a). #[inline] -pub fn gcd(args: &[BigInt]) -> BigInt { +pub fn gcd(args: &[&BigInt]) -> BigInt { if args.is_empty() { return BigInt::zero(); } let mut result = args[0].abs(); for arg in &args[1..] { - result = result.gcd(arg); + result = result.gcd(*arg); if result.is_one() { return result; } @@ -35,7 +31,7 @@ pub fn gcd(args: &[BigInt]) -> BigInt { /// lcm() with no arguments returns 1. /// lcm(a) returns abs(a). #[inline] -pub fn lcm(args: &[BigInt]) -> BigInt { +pub fn lcm(args: &[&BigInt]) -> BigInt { if args.is_empty() { return BigInt::one(); } @@ -44,7 +40,7 @@ pub fn lcm(args: &[BigInt]) -> BigInt { if result.is_zero() || arg.is_zero() { return BigInt::zero(); } - let g = result.gcd(arg); + let g = result.gcd(*arg); result = result / &g * arg.abs(); } result @@ -587,7 +583,7 @@ const FAST_PERM_LIMITS: [u64; 21] = [ ]; /// Calculate C(n, k) or P(n, k) for n in the 63-bit range. -fn perm_comb_small(n: u64, k: u64, is_comb: bool) -> BigUint { +pub(super) fn perm_comb_small(n: u64, k: u64, is_comb: bool) -> BigUint { if k == 0 { return BigUint::one(); } @@ -654,31 +650,6 @@ fn perm_comb_small(n: u64, k: u64, is_comb: bool) -> BigUint { result } -/// Calculate P(n, k) or C(n, k) using recursive formulas for big n. -/// Reserved for future BigUint n support. -#[allow(dead_code)] -fn perm_comb(n: &BigUint, k: u64, is_comb: bool) -> BigUint { - if k == 0 { - return BigUint::one(); - } - if k == 1 { - return n.clone(); - } - - // P(n, k) = P(n, j) * P(n-j, k-j) - // C(n, k) = C(n, j) * C(n-j, k-j) / C(k, j) - let j = k / 2; - let a = perm_comb(n, j, is_comb); - let n_minus_j = n - BigUint::from(j); - let b = perm_comb(&n_minus_j, k - j, is_comb); - let mut result = a * b; - if is_comb { - let c = perm_comb_small(k, j, true); - result /= c; - } - result -} - /// Return the number of ways to choose k items from n items (n choose k). /// /// Returns Err(()) if n or k is negative. @@ -781,7 +752,8 @@ mod tests { fn test_gcd_impl(args: &[i64]) { use std::str::FromStr; let bigints: Vec = args.iter().map(|&x| BigInt::from(x)).collect(); - let rs = gcd(&bigints); + let refs: Vec<_> = bigints.iter().collect(); + let rs = gcd(&refs); crate::test::with_py_math(|py, math| { let py_args = pyo3::types::PyTuple::new(py, args).unwrap(); let py_result = math.getattr("gcd").unwrap().call1(py_args).unwrap(); @@ -794,7 +766,8 @@ mod tests { fn test_lcm_impl(args: &[i64]) { use std::str::FromStr; let bigints: Vec = args.iter().map(|&x| BigInt::from(x)).collect(); - let rs = lcm(&bigints); + let refs: Vec<_> = bigints.iter().collect(); + let rs = lcm(&refs); crate::test::with_py_math(|py, math| { let py_args = pyo3::types::PyTuple::new(py, args).unwrap(); let py_result = math.getattr("lcm").unwrap().call1(py_args).unwrap();