Skip to content

gh-125631: Enable setting persistent_id and persistent_load of pickler and unpickler #125752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions Lib/test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,84 @@ def persistent_load(subself, pid):
self.assertEqual(unpickler.load(), 'abc')
self.assertEqual(called, ['abc'])

def test_pickler_instance_attribute(self):
    # A plain function assigned to ``persistent_id`` on a pickler instance
    # must be invoked by dump(), and deleting the attribute afterwards must
    # bring back whatever the instance exposed before the assignment.
    def record_and_echo(obj):
        seen.append(obj)
        return obj

    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        seen = []
        stream = io.BytesIO()
        pickler = self.pickler(stream, proto)
        original_hook = pickler.persistent_id

        # Assignment is visible through normal attribute access.
        pickler.persistent_id = record_and_echo
        self.assertEqual(pickler.persistent_id, record_and_echo)

        # The hook is consulted for the dumped object and the data
        # still round-trips through the class's loads() helper.
        pickler.dump('abc')
        self.assertEqual(seen, ['abc'])
        self.assertEqual(self.loads(stream.getvalue()), 'abc')

        # Deletion restores the previous attribute value.
        del pickler.persistent_id
        self.assertEqual(pickler.persistent_id, original_hook)

def test_unpickler_instance_attribute(self):
    # A plain function assigned to ``persistent_load`` on an unpickler
    # instance must be invoked by load(), and deleting the attribute
    # afterwards must bring back whatever the instance exposed before.
    def record_and_echo(pid):
        seen.append(pid)
        return pid

    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        payload = self.dumps('abc', proto)
        unpickler = self.unpickler(io.BytesIO(payload))
        seen = []
        original_hook = unpickler.persistent_load

        # Assignment is visible through normal attribute access.
        unpickler.persistent_load = record_and_echo
        self.assertEqual(unpickler.persistent_load, record_and_echo)

        # The hook receives the persistent ID while loading.
        self.assertEqual(unpickler.load(), 'abc')
        self.assertEqual(seen, ['abc'])

        # Deletion restores the previous attribute value.
        del unpickler.persistent_load
        self.assertEqual(unpickler.persistent_load, original_hook)

def test_pickler_super_instance_attribute(self):
    # When a subclass defines persistent_id() *and* the instance attribute
    # is overridden, the instance attribute must win: the subclass method
    # raises if it is ever reached.  The replacement hook also checks that
    # the base-class persistent_id() returns None.
    class PersPickler(self.pickler):
        def persistent_id(subself, obj):
            raise AssertionError('should never be called')

        def _persistent_id(subself, obj):
            seen.append(obj)
            self.assertIsNone(super().persistent_id(obj))
            return obj

    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        seen = []
        stream = io.BytesIO()
        pickler = PersPickler(stream, proto)
        original_hook = pickler.persistent_id

        # Install the bound helper method as the instance-level hook.
        pickler.persistent_id = pickler._persistent_id
        self.assertEqual(pickler.persistent_id, pickler._persistent_id)

        pickler.dump('abc')
        self.assertEqual(seen, ['abc'])
        self.assertEqual(self.loads(stream.getvalue()), 'abc')

        # Deletion restores the previous attribute value.
        del pickler.persistent_id
        self.assertEqual(pickler.persistent_id, original_hook)

def test_unpickler_super_instance_attribute(self):
    # When a subclass defines persistent_load() *and* the instance
    # attribute is overridden, the instance attribute must win: the
    # subclass method raises if it is ever reached.  The replacement hook
    # also checks that the base-class persistent_load() raises
    # ``self.persistent_load_error``.
    class PersUnpickler(self.unpickler):
        def persistent_load(subself, pid):
            raise AssertionError('should never be called')

        def _persistent_load(subself, pid):
            seen.append(pid)
            with self.assertRaises(self.persistent_load_error):
                super().persistent_load(pid)
            return pid

    for proto in range(pickle.HIGHEST_PROTOCOL + 1):
        payload = self.dumps('abc', proto)
        unpickler = PersUnpickler(io.BytesIO(payload))
        seen = []
        original_hook = unpickler.persistent_load

        # Install the bound helper method as the instance-level hook.
        unpickler.persistent_load = unpickler._persistent_load
        self.assertEqual(unpickler.persistent_load,
                         unpickler._persistent_load)

        self.assertEqual(unpickler.load(), 'abc')
        self.assertEqual(seen, ['abc'])

        # Deletion restores the previous attribute value.
        del unpickler.persistent_load
        self.assertEqual(unpickler.persistent_load, original_hook)


class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests, unittest.TestCase):

pickler_class = pickle._Pickler
Expand Down Expand Up @@ -373,7 +451,7 @@ class SizeofTests(unittest.TestCase):
check_sizeof = support.check_sizeof

def test_pickler(self):
basesize = support.calcobjsize('6P2n3i2n3i2P')
basesize = support.calcobjsize('7P2n3i2n3i2P')
p = _pickle.Pickler(io.BytesIO())
self.assertEqual(object.__sizeof__(p), basesize)
MT_size = struct.calcsize('3nP0n')
Expand All @@ -390,7 +468,7 @@ def test_pickler(self):
0) # Write buffer is cleared after every dump().

def test_unpickler(self):
basesize = support.calcobjsize('2P2nP 2P2n2i5P 2P3n8P2n2i')
basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n8P2n2i')
unpickler = _pickle.Unpickler
P = struct.calcsize('P') # Size of memo table entry.
n = struct.calcsize('n') # Size of mark table entry.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Restore ability to set :attr:`~pickle.Pickler.persistent_id` and
:attr:`~pickle.Unpickler.persistent_load` attributes of instances of the
:class:`!Pickler` and :class:`!Unpickler` classes in the :mod:`pickle`
module.
62 changes: 62 additions & 0 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ typedef struct PicklerObject {
objects to support self-referential objects
pickling. */
PyObject *persistent_id; /* persistent_id() method, can be NULL */
PyObject *persistent_id_attr; /* instance attribute, can be NULL */
PyObject *dispatch_table; /* private dispatch_table, can be NULL */
PyObject *reducer_override; /* hook for invoking user-defined callbacks
instead of save_global when pickling
Expand Down Expand Up @@ -655,6 +656,7 @@ typedef struct UnpicklerObject {
size_t memo_len; /* Number of objects in the memo */

PyObject *persistent_load; /* persistent_load() method, can be NULL. */
PyObject *persistent_load_attr; /* instance attribute, can be NULL. */

Py_buffer buffer;
char *input_buffer;
Expand Down Expand Up @@ -1108,6 +1110,7 @@ _Pickler_New(PickleState *st)

self->memo = memo;
self->persistent_id = NULL;
self->persistent_id_attr = NULL;
self->dispatch_table = NULL;
self->reducer_override = NULL;
self->write = NULL;
Expand Down Expand Up @@ -1606,6 +1609,7 @@ _Unpickler_New(PyObject *module)
self->memo_size = MEMO_SIZE;
self->memo_len = 0;
self->persistent_load = NULL;
self->persistent_load_attr = NULL;
memset(&self->buffer, 0, sizeof(Py_buffer));
self->input_buffer = NULL;
self->input_line = NULL;
Expand Down Expand Up @@ -5092,6 +5096,33 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored))
return -1;
}

static PyObject *
Pickler_getattr(PyObject *self, PyObject *name)
{
    /* tp_getattro hook.  If an instance-level "persistent_id" value has
       been stored in persistent_id_attr, return a new reference to it;
       every other lookup (including "persistent_id" when nothing is
       stored) is delegated to the generic attribute machinery. */
    PicklerObject *pickler = (PicklerObject *)self;

    if (pickler->persistent_id_attr != NULL
        && PyUnicode_Check(name)
        && PyUnicode_EqualToUTF8(name, "persistent_id"))
    {
        return Py_NewRef(pickler->persistent_id_attr);
    }

    return PyObject_GenericGetAttr(self, name);
}

static int
Pickler_setattr(PyObject *self, PyObject *name, PyObject *value)
{
    /* tp_setattro hook.  Assignments to (and deletions of) "persistent_id"
       are stored in the persistent_id_attr slot instead of going through
       the generic machinery; all other names are delegated to
       PyObject_GenericSetAttr().

       value == NULL means "delete".  Returns 0 on success, -1 with an
       exception set on failure. */
    if (PyUnicode_Check(name)
        && PyUnicode_EqualToUTF8(name, "persistent_id"))
    {
        PicklerObject *pickler = (PicklerObject *)self;

        /* Deleting an instance attribute that was never set must fail
           like any other attribute deletion, not silently succeed. */
        if (value == NULL && pickler->persistent_id_attr == NULL) {
            PyErr_Format(PyExc_AttributeError,
                         "'%.100s' object has no attribute '%U'",
                         Py_TYPE(self)->tp_name, name);
            return -1;
        }

        /* Take the new reference before dropping the old one;
           Py_XSETREF handles the NULL (delete) case. */
        Py_XINCREF(value);
        Py_XSETREF(pickler->persistent_id_attr, value);
        return 0;
    }

    return PyObject_GenericSetAttr(self, name, value);
}

static PyMemberDef Pickler_members[] = {
{"bin", Py_T_INT, offsetof(PicklerObject, bin)},
{"fast", Py_T_INT, offsetof(PicklerObject, fast)},
Expand All @@ -5107,6 +5138,8 @@ static PyGetSetDef Pickler_getsets[] = {

static PyType_Slot pickler_type_slots[] = {
{Py_tp_dealloc, Pickler_dealloc},
{Py_tp_getattro, Pickler_getattr},
{Py_tp_setattro, Pickler_setattr},
{Py_tp_methods, Pickler_methods},
{Py_tp_members, Pickler_members},
{Py_tp_getset, Pickler_getsets},
Expand Down Expand Up @@ -7566,6 +7599,33 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj, void *Py_UNUSED(ignored
return -1;
}

static PyObject *
Unpickler_getattr(PyObject *self, PyObject *name)
{
    /* tp_getattro hook.  If an instance-level "persistent_load" value has
       been stored in persistent_load_attr, return a new reference to it;
       every other lookup (including "persistent_load" when nothing is
       stored) is delegated to the generic attribute machinery. */
    UnpicklerObject *unpickler = (UnpicklerObject *)self;

    if (unpickler->persistent_load_attr != NULL
        && PyUnicode_Check(name)
        && PyUnicode_EqualToUTF8(name, "persistent_load"))
    {
        return Py_NewRef(unpickler->persistent_load_attr);
    }

    return PyObject_GenericGetAttr(self, name);
}

static int
Unpickler_setattr(PyObject *self, PyObject *name, PyObject *value)
{
    /* tp_setattro hook.  Assignments to (and deletions of)
       "persistent_load" are stored in the persistent_load_attr slot
       instead of going through the generic machinery; all other names are
       delegated to PyObject_GenericSetAttr().

       value == NULL means "delete".  Returns 0 on success, -1 with an
       exception set on failure. */
    if (PyUnicode_Check(name)
        && PyUnicode_EqualToUTF8(name, "persistent_load"))
    {
        UnpicklerObject *unpickler = (UnpicklerObject *)self;

        /* Deleting an instance attribute that was never set must fail
           like any other attribute deletion, not silently succeed. */
        if (value == NULL && unpickler->persistent_load_attr == NULL) {
            PyErr_Format(PyExc_AttributeError,
                         "'%.100s' object has no attribute '%U'",
                         Py_TYPE(self)->tp_name, name);
            return -1;
        }

        /* Take the new reference before dropping the old one;
           Py_XSETREF handles the NULL (delete) case. */
        Py_XINCREF(value);
        Py_XSETREF(unpickler->persistent_load_attr, value);
        return 0;
    }

    return PyObject_GenericSetAttr(self, name, value);
}

static PyGetSetDef Unpickler_getsets[] = {
{"memo", (getter)Unpickler_get_memo, (setter)Unpickler_set_memo},
{NULL}
Expand All @@ -7574,6 +7634,8 @@ static PyGetSetDef Unpickler_getsets[] = {
static PyType_Slot unpickler_type_slots[] = {
{Py_tp_dealloc, Unpickler_dealloc},
{Py_tp_doc, (char *)_pickle_Unpickler___init____doc__},
{Py_tp_getattro, Unpickler_getattr},
{Py_tp_setattro, Unpickler_setattr},
{Py_tp_traverse, Unpickler_traverse},
{Py_tp_clear, Unpickler_clear},
{Py_tp_methods, Unpickler_methods},
Expand Down
Loading