Skip to content

Commit 051e8eb

Browse files
authored
Add pickle support for tstrings (#75)
1 parent 280e90e commit 051e8eb

File tree

6 files changed

+132
-8
lines changed

6 files changed

+132
-8
lines changed

Lib/string/templatelib.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
11
"""Support for template string literals (t-strings)."""
22

33
from _templatelib import Interpolation, Template
4+
5+
__all__ = [
6+
"Interpolation",
7+
"Template",
8+
]
9+
10+
def _template_unpickle(*args):
11+
import itertools
12+
13+
if len(args) != 2:
14+
raise ValueError('Template expects tuple of length 2 to unpickle')
15+
16+
strings, interpolations = args
17+
parts = []
18+
for string, interpolation in itertools.zip_longest(strings, interpolations):
19+
if string is not None:
20+
parts.append(string)
21+
if interpolation is not None:
22+
parts.append(interpolation)
23+
return Template(*parts)

Lib/test/test_string/_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from string.templatelib import Interpolation
33

44

5-
class TStringTestCase(unittest.TestCase):
5+
class TStringBaseCase:
66
def assertTStringEqual(self, t, strings, interpolations):
77
"""Test template string literal equality.
88
@@ -23,7 +23,7 @@ def assertTStringEqual(self, t, strings, interpolations):
2323
continue
2424

2525
if len(exp) == 3:
26-
self.assertEqual((i.value, i.expression, i.conversion), exp)
26+
self.assertEqual((i.value, i.expression, i.conversion), exp)
2727
self.assertEqual(i.format_spec, '')
2828
continue
2929

Lib/test/test_string/test_templatelib.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
import pickle
2+
import unittest
13
from string.templatelib import Template, Interpolation
24

3-
from test.test_string._support import TStringTestCase, fstring
5+
from test.test_string._support import TStringBaseCase, fstring
46

57

6-
class TestTemplate(TStringTestCase):
8+
class TestTemplate(unittest.TestCase, TStringBaseCase):
9+
10+
def test_common(self):
11+
self.assertEqual(type(t'').__name__, 'Template')
12+
self.assertEqual(type(t'').__qualname__, 'Template')
13+
self.assertEqual(type(t'').__module__, 'string.templatelib')
14+
15+
a = 'a'
16+
i = t'{a}'.interpolations[0]
17+
self.assertEqual(type(i).__name__, 'Interpolation')
18+
self.assertEqual(type(i).__qualname__, 'Interpolation')
19+
self.assertEqual(type(i).__module__, 'string.templatelib')
720

821
def test_basic_creation(self):
922
# Simple t-string creation
@@ -53,3 +66,44 @@ def test_creation_interleaving(self):
5366
t, ('', '', ''), [('Maria', 'name'), ('Python', 'language')]
5467
)
5568
self.assertEqual(fstring(t), 'MariaPython')
69+
70+
def test_pickle_template(self):
71+
user = 'test'
72+
for template in (
73+
t'',
74+
t"No values",
75+
t'With inter {user}',
76+
t'With ! {user!r}',
77+
t'With format {1 / 0.3:.2f}',
78+
Template(),
79+
Template('a'),
80+
Template(Interpolation('Nikita', 'name', None, '')),
81+
Template('a', Interpolation('Nikita', 'name', 'r', '')),
82+
):
83+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
84+
with self.subTest(proto=proto, template=template):
85+
pickled = pickle.dumps(template, protocol=proto)
86+
unpickled = pickle.loads(pickled)
87+
88+
self.assertEqual(unpickled.values, template.values)
89+
self.assertEqual(fstring(unpickled), fstring(template))
90+
91+
def test_pickle_interpolation(self):
92+
for interpolation in (
93+
Interpolation('Nikita', 'name', None, ''),
94+
Interpolation('Nikita', 'name', 'r', ''),
95+
Interpolation(1/3, 'x', None, '.2f'),
96+
):
97+
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
98+
with self.subTest(proto=proto, interpolation=interpolation):
99+
pickled = pickle.dumps(interpolation, protocol=proto)
100+
unpickled = pickle.loads(pickled)
101+
102+
self.assertEqual(unpickled.value, interpolation.value)
103+
self.assertEqual(unpickled.expression, interpolation.expression)
104+
self.assertEqual(unpickled.conversion, interpolation.conversion)
105+
self.assertEqual(unpickled.format_spec, interpolation.format_spec)
106+
107+
108+
if __name__ == '__main__':
109+
unittest.main()

Lib/test/test_tstring.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import unittest
22

3-
from test.test_string._support import TStringTestCase, fstring
3+
from test.test_string._support import TStringBaseCase, fstring
44

55

6-
class TestTString(TStringTestCase):
6+
class TestTString(unittest.TestCase, TStringBaseCase):
77
def test_string_representation(self):
88
# Test __repr__
99
t = t"Hello"

Objects/interpolationobject.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,22 @@ static PyMemberDef interpolation_members[] = {
125125
{NULL}
126126
};
127127

128+
static PyObject*
129+
interpolation_reduce(PyObject *op, PyObject *Py_UNUSED(dummy))
130+
{
131+
interpolationobject *self = interpolationobject_CAST(op);
132+
return Py_BuildValue("(O(OOOO))", (PyObject *)Py_TYPE(op),
133+
self->value, self->expression,
134+
self->conversion, self->format_spec);
135+
}
136+
137+
static PyMethodDef interpolation_methods[] = {
138+
{"__reduce__", interpolation_reduce, METH_VARARGS,
139+
PyDoc_STR("__reduce__() -> (cls, state)")},
140+
141+
{NULL, NULL},
142+
};
143+
128144
PyTypeObject _PyInterpolation_Type = {
129145
PyVarObject_HEAD_INIT(NULL, 0)
130146
.tp_name = "string.templatelib.Interpolation",
@@ -139,6 +155,7 @@ PyTypeObject _PyInterpolation_Type = {
139155
.tp_free = PyObject_GC_Del,
140156
.tp_repr = interpolation_repr,
141157
.tp_members = interpolation_members,
158+
.tp_methods = interpolation_methods,
142159
.tp_traverse = interpolation_traverse,
143160
};
144161

Objects/templateobject.c

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ template_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
114114
last_was_str = 0;
115115
}
116116
else {
117-
PyErr_SetString(PyExc_TypeError, "Template.__new__ *args need to be of type 'str' or 'Interpolation'");
117+
PyErr_Format(
118+
PyExc_TypeError,
119+
"Template.__new__ *args need to be of type 'str' or 'Interpolation', got %T",
120+
item);
118121
return NULL;
119122
}
120123
}
@@ -418,14 +421,43 @@ static PyMemberDef template_members[] = {
418421
};
419422

420423
static PyGetSetDef template_getset[] = {
421-
{"values", template_values_get, NULL, "Values of interpolations", NULL},
424+
{"values", template_values_get, NULL,
425+
PyDoc_STR("Values of interpolations"), NULL},
422426
{NULL},
423427
};
424428

425429
static PySequenceMethods template_as_sequence = {
426430
.sq_concat = _PyTemplate_Concat,
427431
};
428432

433+
static PyObject*
434+
template_reduce(PyObject *op, PyObject *Py_UNUSED(dummy))
435+
{
436+
PyObject *mod = PyImport_ImportModule("string.templatelib");
437+
if (mod == NULL) {
438+
return NULL;
439+
}
440+
PyObject *func = PyObject_GetAttrString(mod, "_template_unpickle");
441+
Py_DECREF(mod);
442+
if (func == NULL) {
443+
return NULL;
444+
}
445+
446+
templateobject *self = templateobject_CAST(op);
447+
PyObject *result = Py_BuildValue("O(OO)",
448+
func,
449+
self->strings,
450+
self->interpolations);
451+
452+
Py_DECREF(func);
453+
return result;
454+
}
455+
456+
static PyMethodDef template_methods[] = {
457+
{"__reduce__", template_reduce, METH_NOARGS, NULL},
458+
{NULL, NULL},
459+
};
460+
429461
PyTypeObject _PyTemplate_Type = {
430462
PyVarObject_HEAD_INIT(NULL, 0)
431463
.tp_name = "string.templatelib.Template",
@@ -441,6 +473,7 @@ PyTypeObject _PyTemplate_Type = {
441473
.tp_free = PyObject_GC_Del,
442474
.tp_repr = template_repr,
443475
.tp_members = template_members,
476+
.tp_methods = template_methods,
444477
.tp_getset = template_getset,
445478
.tp_iter = template_iter,
446479
.tp_traverse = template_traverse,

0 commit comments

Comments
 (0)