Skip to content

Commit a8f4139

Browse files
committed
update pickle protocol
1 parent bf0fd98 commit a8f4139

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

Lib/functools.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,12 @@ def __get__(self, obj, objtype=None):
398398
return self
399399
return MethodType(self, obj)
400400

401+
def __reduce_ex__(self, protocol):
402+
if protocol >= 2:
403+
return self.__reduce__()
404+
return type(self), (self.func,), (self.func, self.args,
405+
dict(self.keywords) or None, self.__dict__ or None)
406+
401407
def __reduce__(self):
402408
return type(self), (self.func,), (self.func, self.args,
403409
self.keywords or None, self.__dict__ or None)

Lib/test/test_functools.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ class BadTuple(tuple):
6262
def __add__(self, other):
6363
return list(self) + list(other)
6464

65-
6665
class MyDict(frozendict):
6766
pass
6867

@@ -342,7 +341,7 @@ def test_pickle(self):
342341
with replaced_module('functools', self.module):
343342
f = self.partial(signature, ['asdf'], bar=[True])
344343
f.attr = []
345-
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
344+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
346345
f_copy = pickle.loads(pickle.dumps(f, proto))
347346
self.assertEqual(signature(f_copy), signature(f))
348347

@@ -470,7 +469,7 @@ def test_recursive_pickle(self):
470469
f = self.partial(capture)
471470
f.__setstate__((f, (), {}, {}))
472471
try:
473-
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
472+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
474473
# gh-117008: Small limit since pickle uses C stack memory
475474
with support.infinite_recursion(100):
476475
with self.assertRaises(RecursionError):
@@ -481,7 +480,7 @@ def test_recursive_pickle(self):
481480
f = self.partial(capture)
482481
f.__setstate__((capture, (f,), {}, {}))
483482
try:
484-
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
483+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
485484
f_copy = pickle.loads(pickle.dumps(f, proto))
486485
try:
487486
self.assertIs(f_copy.args[0], f_copy)
@@ -493,7 +492,7 @@ def test_recursive_pickle(self):
493492
f = self.partial(capture)
494493
f.__setstate__((capture, (), {'a': f}, {}))
495494
try:
496-
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
495+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
497496
f_copy = pickle.loads(pickle.dumps(f, proto))
498497
try:
499498
self.assertIs(f_copy.keywords['a'], f_copy)

Modules/_functoolsmodule.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,10 @@ partial_repr(PyObject *self)
777777
operation so we define a __setstate__ that replaces all the information
778778
about the partial. If we only replaced part of it someone would use
779779
it as a hook to do strange things.
780+
781+
Additionally, since frozendict does not work for pickle protocols 0 and 1,
782+
__reduce_ex__ creates a temporary dict for protocols 0 and 1 and calls
783+
__reduce__ for protocols 2+.
780784
*/
781785

782786
static PyObject *
@@ -788,6 +792,28 @@ partial_reduce(PyObject *self, PyObject *Py_UNUSED(args))
788792
pto->dict ? pto->dict : Py_None);
789793
}
790794

795+
static PyObject *
796+
partial_reduce_ex(PyObject *self, PyObject *args)
797+
{
798+
int64_t protocol;
799+
800+
if (!PyArg_ParseTuple(args, "l", &protocol)) {
801+
return NULL;
802+
}
803+
804+
if (protocol >= 2) {
805+
return partial_reduce(self, NULL);
806+
}
807+
808+
partialobject *pto = partialobject_CAST(self);
809+
PyObject *keywords_dict = PyObject_CallOneArg((PyObject*)&PyDict_Type, pto->kw);
810+
PyObject *result = Py_BuildValue("O(O)(OOOO)", Py_TYPE(pto), pto->fn, pto->fn,
811+
pto->args, keywords_dict,
812+
pto->dict ? pto->dict : Py_None);
813+
Py_DECREF(keywords_dict);
814+
return result;
815+
}
816+
791817
static PyObject *
792818
partial_setstate(PyObject *self, PyObject *state)
793819
{
@@ -878,6 +904,7 @@ partial_setstate(PyObject *self, PyObject *state)
878904

879905
static PyMethodDef partial_methods[] = {
880906
{"__reduce__", partial_reduce, METH_NOARGS},
907+
{"__reduce_ex__", partial_reduce_ex, METH_VARARGS},
881908
{"__setstate__", partial_setstate, METH_O},
882909
{"__class_getitem__", Py_GenericAlias,
883910
METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},

0 commit comments

Comments
 (0)