Skip to content

Commit 5d83b6c

Browse files
[3.13] gh-130230: Fix crash in pow() with only Decimal third argument (GH-130237) (GH-130246)
(cherry picked from commit b93b7e5) Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent fc1c9f8 commit 5d83b6c

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

Include/internal/pycore_typeobject.h

+1
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ extern PyObject * _PyType_GetMRO(PyTypeObject *type);
199199
extern PyObject* _PyType_GetSubclasses(PyTypeObject *);
200200
extern int _PyType_HasSubclasses(PyTypeObject *);
201201
PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *);
202+
PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef3(PyTypeObject *, PyTypeObject *, PyTypeObject *, PyModuleDef *);
202203

203204
// PyType_Ready() must be called if _PyType_IsReady() is false.
204205
// See also the Py_TPFLAGS_READY flag.

Lib/test/test_decimal.py

+9
Original file line numberDiff line numberDiff line change
@@ -4458,6 +4458,15 @@ def test_implicit_context(self):
44584458
self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True)
44594459
# three arg power
44604460
self.assertEqual(pow(Decimal(10), 2, 7), 2)
4461+
if self.decimal == C:
4462+
self.assertEqual(pow(10, Decimal(2), 7), 2)
4463+
self.assertEqual(pow(10, 2, Decimal(7)), 2)
4464+
else:
4465+
# XXX: Three-arg power doesn't use __rpow__.
4466+
self.assertRaises(TypeError, pow, 10, Decimal(2), 7)
4467+
# XXX: There is no special method to dispatch on the
4468+
# third arg of three-arg power.
4469+
self.assertRaises(TypeError, pow, 10, 2, Decimal(7))
44614470
# exp
44624471
self.assertEqual(Decimal("1.01").exp(), 3)
44634472
# is_normal
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix crash in :func:`pow` with only :class:`~decimal.Decimal` third argument.

Modules/_decimal/_decimal.c

+10-1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,15 @@ find_state_left_or_right(PyObject *left, PyObject *right)
140140
return get_module_state(mod);
141141
}
142142

143+
static inline decimal_state *
144+
find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus)
145+
{
146+
PyObject *mod = _PyType_GetModuleByDef3(Py_TYPE(left), Py_TYPE(right), Py_TYPE(modulus),
147+
&_decimal_module);
148+
assert(mod != NULL);
149+
return get_module_state(mod);
150+
}
151+
143152

144153
#if !defined(MPD_VERSION_HEX) || MPD_VERSION_HEX < 0x02050000
145154
#error "libmpdec version >= 2.5.0 required"
@@ -4305,7 +4314,7 @@ nm_mpd_qpow(PyObject *base, PyObject *exp, PyObject *mod)
43054314
PyObject *context;
43064315
uint32_t status = 0;
43074316

4308-
decimal_state *state = find_state_left_or_right(base, exp);
4317+
decimal_state *state = find_state_ternary(base, exp, mod);
43094318
CURRENT_CONTEXT(state, context);
43104319
CONVERT_BINOP(&a, &b, base, exp, context);
43114320

Objects/typeobject.c

+20
Original file line numberDiff line numberDiff line change
@@ -5038,6 +5038,26 @@ _PyType_GetModuleByDef2(PyTypeObject *left, PyTypeObject *right,
50385038
return module;
50395039
}
50405040

5041+
PyObject *
5042+
_PyType_GetModuleByDef3(PyTypeObject *left, PyTypeObject *right, PyTypeObject *third,
5043+
PyModuleDef *def)
5044+
{
5045+
PyObject *module = get_module_by_def(left, def);
5046+
if (module == NULL) {
5047+
module = get_module_by_def(right, def);
5048+
if (module == NULL) {
5049+
module = get_module_by_def(third, def);
5050+
if (module == NULL) {
5051+
PyErr_Format(
5052+
PyExc_TypeError,
5053+
"PyType_GetModuleByDef: No superclass of '%s', '%s' nor '%s' has "
5054+
"the given module", left->tp_name, right->tp_name, third->tp_name);
5055+
}
5056+
}
5057+
}
5058+
return module;
5059+
}
5060+
50415061
void *
50425062
PyObject_GetTypeData(PyObject *obj, PyTypeObject *cls)
50435063
{

0 commit comments

Comments
 (0)