@@ -341,18 +341,8 @@ _approximate_isqrt(uint64_t n)
341341 return (u << 15 ) + (uint32_t )((n >> 17 ) / u );
342342}
343343
344- /*[clinic input]
345- math.integer.isqrt
346-
347- n: object
348- /
349-
350- Return the integer part of the square root of the input.
351- [clinic start generated code]*/
352-
353344static PyObject *
354- math_integer_isqrt (PyObject * module , PyObject * n )
355- /*[clinic end generated code: output=551031e41a0f5d9e input=921ddd9853133d8d]*/
345+ _isqrt_rem (PyObject * n , PyObject * * rem )
356346{
357347 int a_too_large , c_bit_length ;
358348 int64_t c , d ;
@@ -373,6 +363,9 @@ math_integer_isqrt(PyObject *module, PyObject *n)
373363 }
374364 if (_PyLong_IsZero ((PyLongObject * )n )) {
375365 Py_DECREF (n );
366+ if (rem ) {
367+ * rem = PyLong_FromLong (0 );
368+ }
376369 return PyLong_FromLong (0 );
377370 }
378371
@@ -392,7 +385,15 @@ math_integer_isqrt(PyObject *module, PyObject *n)
392385 return NULL ;
393386 }
394387 u = _approximate_isqrt (m << 2 * shift ) >> shift ;
395- u -= (uint64_t )u * u > m ;
388+ uint64_t sq = (uint64_t )u * u ;
389+ u -= sq > m ;
390+ if (rem ) {
391+ if (sq > m ) {
392+ sq -= 2 * (uint64_t )u + 1 ;
393+ }
394+ m -= sq ;
395+ * rem = PyLong_FromUnsignedLongLong (m );
396+ }
396397 return PyLong_FromUnsignedLong (u );
397398 }
398399
@@ -460,14 +461,26 @@ math_integer_isqrt(PyObject *module, PyObject *n)
460461 goto error ;
461462 }
462463 a_too_large = PyObject_RichCompareBool (n , b , Py_LT );
463- Py_DECREF (b );
464464 if (a_too_large == -1 ) {
465+ Py_DECREF (b );
465466 goto error ;
466467 }
467468
468469 if (a_too_large ) {
470+ if (rem ) {
471+ Py_SETREF (b , PyNumber_Add (b , _PyLong_GetOne ()));
472+ Py_SETREF (b , PyNumber_Subtract (b , a ));
473+ Py_SETREF (b , PyNumber_Subtract (b , a ));
474+ }
469475 Py_SETREF (a , PyNumber_Subtract (a , _PyLong_GetOne ()));
470476 }
477+ if (!rem ) {
478+ Py_DECREF (b );
479+ }
480+ else {
481+ Py_SETREF (b , PyNumber_Subtract (n , b ));
482+ * rem = b ;
483+ }
471484 Py_DECREF (n );
472485 return a ;
473486
@@ -478,6 +491,51 @@ math_integer_isqrt(PyObject *module, PyObject *n)
478491}
479492
480493
494+ /*[clinic input]
495+ math.integer.isqrt
496+
497+ n: object
498+ /
499+
500+ Return the integer part of the square root of the input.
501+ [clinic start generated code]*/
502+
503+ static PyObject *
504+ math_integer_isqrt (PyObject * module , PyObject * n )
505+ /*[clinic end generated code: output=551031e41a0f5d9e input=921ddd9853133d8d]*/
506+ {
507+ return _isqrt_rem (n , NULL );
508+ }
509+
510+ /*[clinic input]
511+ math.integer.isqrt_rem
512+
513+ n: object
514+ /
515+
516+ Return a pair of values (s,t) such that s=isqrt(n) and t=n-s*s.
517+ [clinic start generated code]*/
518+
519+ static PyObject *
520+ math_integer_isqrt_rem (PyObject * module , PyObject * n )
521+ /*[clinic end generated code: output=b17d11479d08cdc4 input=7ed2dd870818d2bb]*/
522+ {
523+ PyObject * rem = NULL ;
524+ PyObject * root = _isqrt_rem (n , & rem );
525+
526+ if (root && rem ) {
527+ PyObject * tup = PyTuple_Pack (2 , root , rem );
528+
529+ Py_DECREF (root );
530+ Py_DECREF (rem );
531+ return tup ;
532+ }
533+ Py_XDECREF (root );
534+ Py_XDECREF (rem );
535+ return NULL ;
536+ }
537+
538+
481539static unsigned long
482540count_set_bits (unsigned long n )
483541{
@@ -1231,6 +1289,7 @@ static PyMethodDef math_integer_methods[] = {
12311289 MATH_INTEGER_FACTORIAL_METHODDEF
12321290 MATH_INTEGER_GCD_METHODDEF
12331291 MATH_INTEGER_ISQRT_METHODDEF
1292+ MATH_INTEGER_ISQRT_REM_METHODDEF
12341293 MATH_INTEGER_LCM_METHODDEF
12351294 MATH_INTEGER_PERM_METHODDEF
12361295 {NULL , NULL } /* sentinel */
0 commit comments