Skip to content

Commit b3b5f47

Browse files
committed
gh-90345: Add math.integer.isqrt_rem()
1 parent f159419 commit b3b5f47

6 files changed

Lines changed: 135 additions & 14 deletions

File tree

Doc/library/math.integer.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ computed exactly and are integers.
6262
:trim:
6363

6464

65+
.. function:: isqrt_rem(n, /)
66+
67+
Return a pair of values (s,t) such that s=isqrt(n) and t=n-s*s.
68+
The remainder *t* will be zero, if *n* is a perfect square.
69+
70+
6571
.. function:: lcm(*integers)
6672

6773
Return the least common multiple of the specified integer arguments.

Doc/whatsnew/3.16.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ New modules
8686
Improved modules
8787
================
8888

89+
math.integer
90+
------------
91+
92+
* Add :func:`math.integer.isqrt_rem` to compute integer square root with
93+
a remainder.
94+
(Contributed by Sergey B Kirpichev in :gh:`90345`.)
95+
8996
os
9097
--
9198

Lib/test/test_math_integer.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from fractions import Fraction
33
import unittest
44
from test import support
5+
from math.integer import isqrt_rem
56

67

78
class IntSubclass(int):
@@ -249,6 +250,44 @@ def test_isqrt_huge(self, size):
249250
self.assertEqual(w.bit_length(), size // 2 + 1)
250251
self.assertEqual(w.bit_count(), 1)
251252

253+
def test_isqrt_rem(self):
254+
test_values = (
255+
list(range(1000))
256+
+ list(range(10**6 - 1000, 10**6 + 1000))
257+
+ [2**e + i for e in range(60, 200) for i in range(-40, 40)]
258+
+ [3**9999, 10**5001]
259+
)
260+
for value in test_values:
261+
with self.subTest(value=value):
262+
root, rem = isqrt_rem(value)
263+
self.assertIs(type(root), int)
264+
self.assertLessEqual(root*root, value)
265+
self.assertLess(value, (root+1)*(root+1))
266+
self.assertIs(type(rem), int)
267+
self.assertEqual(rem, value - root*root)
268+
269+
# Negative values
270+
with self.assertRaises(ValueError):
271+
isqrt_rem(-1)
272+
273+
# Integer-like things
274+
self.assertEqual(isqrt_rem(True), (1, 0))
275+
self.assertEqual(isqrt_rem(False), (0, 0))
276+
self.assertEqual(isqrt_rem(MyIndexable(1729)), (41, 48))
277+
278+
with self.assertRaises(ValueError):
279+
isqrt_rem(MyIndexable(-3))
280+
281+
# Non-integer-like things
282+
bad_values = [
283+
3.5, "a string", Decimal("3.5"), 3.5j,
284+
100.0, -4.0,
285+
]
286+
for value in bad_values:
287+
with self.subTest(value=value):
288+
with self.assertRaises(TypeError):
289+
isqrt_rem(value)
290+
252291
def test_perm(self):
253292
perm = self.module.perm
254293
factorial = self.module.factorial
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add :func:`math.integer.isqrt_rem`.

Modules/clinic/mathintegermodule.c.h

Lines changed: 10 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Modules/mathintegermodule.c

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -341,18 +341,8 @@ _approximate_isqrt(uint64_t n)
341341
return (u << 15) + (uint32_t)((n >> 17) / u);
342342
}
343343

344-
/*[clinic input]
345-
math.integer.isqrt
346-
347-
n: object
348-
/
349-
350-
Return the integer part of the square root of the input.
351-
[clinic start generated code]*/
352-
353344
static PyObject *
354-
math_integer_isqrt(PyObject *module, PyObject *n)
355-
/*[clinic end generated code: output=551031e41a0f5d9e input=921ddd9853133d8d]*/
345+
_isqrt_rem(PyObject *n, PyObject **rem)
356346
{
357347
int a_too_large, c_bit_length;
358348
int64_t c, d;
@@ -373,6 +363,9 @@ math_integer_isqrt(PyObject *module, PyObject *n)
373363
}
374364
if (_PyLong_IsZero((PyLongObject *)n)) {
375365
Py_DECREF(n);
366+
if (rem) {
367+
*rem = PyLong_FromLong(0);
368+
}
376369
return PyLong_FromLong(0);
377370
}
378371

@@ -392,7 +385,15 @@ math_integer_isqrt(PyObject *module, PyObject *n)
392385
return NULL;
393386
}
394387
u = _approximate_isqrt(m << 2*shift) >> shift;
395-
u -= (uint64_t)u * u > m;
388+
uint64_t sq = (uint64_t)u * u;
389+
u -= sq > m;
390+
if (rem) {
391+
if (sq > m) {
392+
sq -= 2*(uint64_t)u + 1;
393+
}
394+
m -= sq;
395+
*rem = PyLong_FromUnsignedLongLong(m);
396+
}
396397
return PyLong_FromUnsignedLong(u);
397398
}
398399

@@ -460,14 +461,26 @@ math_integer_isqrt(PyObject *module, PyObject *n)
460461
goto error;
461462
}
462463
a_too_large = PyObject_RichCompareBool(n, b, Py_LT);
463-
Py_DECREF(b);
464464
if (a_too_large == -1) {
465+
Py_DECREF(b);
465466
goto error;
466467
}
467468

468469
if (a_too_large) {
470+
if (rem) {
471+
Py_SETREF(b, PyNumber_Add(b, _PyLong_GetOne()));
472+
Py_SETREF(b, PyNumber_Subtract(b, a));
473+
Py_SETREF(b, PyNumber_Subtract(b, a));
474+
}
469475
Py_SETREF(a, PyNumber_Subtract(a, _PyLong_GetOne()));
470476
}
477+
if (!rem) {
478+
Py_DECREF(b);
479+
}
480+
else {
481+
Py_SETREF(b, PyNumber_Subtract(n, b));
482+
*rem = b;
483+
}
471484
Py_DECREF(n);
472485
return a;
473486

@@ -478,6 +491,51 @@ math_integer_isqrt(PyObject *module, PyObject *n)
478491
}
479492

480493

494+
/*[clinic input]
495+
math.integer.isqrt
496+
497+
n: object
498+
/
499+
500+
Return the integer part of the square root of the input.
501+
[clinic start generated code]*/
502+
503+
static PyObject *
504+
math_integer_isqrt(PyObject *module, PyObject *n)
505+
/*[clinic end generated code: output=551031e41a0f5d9e input=921ddd9853133d8d]*/
506+
{
507+
return _isqrt_rem(n, NULL);
508+
}
509+
510+
/*[clinic input]
511+
math.integer.isqrt_rem
512+
513+
n: object
514+
/
515+
516+
Return a pair of values (s,t) such that s=isqrt(n) and t=n-s*s.
517+
[clinic start generated code]*/
518+
519+
static PyObject *
520+
math_integer_isqrt_rem(PyObject *module, PyObject *n)
521+
/*[clinic end generated code: output=b17d11479d08cdc4 input=7ed2dd870818d2bb]*/
522+
{
523+
PyObject *rem = NULL;
524+
PyObject *root = _isqrt_rem(n, &rem);
525+
526+
if (root && rem) {
527+
PyObject *tup = PyTuple_Pack(2, root, rem);
528+
529+
Py_DECREF(root);
530+
Py_DECREF(rem);
531+
return tup;
532+
}
533+
Py_XDECREF(root);
534+
Py_XDECREF(rem);
535+
return NULL;
536+
}
537+
538+
481539
static unsigned long
482540
count_set_bits(unsigned long n)
483541
{
@@ -1231,6 +1289,7 @@ static PyMethodDef math_integer_methods[] = {
12311289
MATH_INTEGER_FACTORIAL_METHODDEF
12321290
MATH_INTEGER_GCD_METHODDEF
12331291
MATH_INTEGER_ISQRT_METHODDEF
1292+
MATH_INTEGER_ISQRT_REM_METHODDEF
12341293
MATH_INTEGER_LCM_METHODDEF
12351294
MATH_INTEGER_PERM_METHODDEF
12361295
{NULL, NULL} /* sentinel */

0 commit comments

Comments
 (0)