Skip to content

Commit ebb8b51

Browse files
authored
gh-91621: Fix typing.get_type_hints for collections.abc.Callable (#91656)
This mirrors logic in typing.get_args. The trickiness comes from how we flatten args in collections.abc.Callable, see https://bugs.python.org/issue42195
1 parent aff8c4f commit ebb8b51

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

Lib/test/test_typing.py

+11
Original file line numberDiff line numberDiff line change
@@ -4876,6 +4876,17 @@ def test_get_type_hints_typeddict(self):
48764876
'a': Annotated[Required[int], "a", "b", "c"]
48774877
})
48784878

4879+
def test_get_type_hints_collections_abc_callable(self):
4880+
# https://github.com/python/cpython/issues/91621
4881+
P = ParamSpec('P')
4882+
def f(x: collections.abc.Callable[[int], int]): ...
4883+
def g(x: collections.abc.Callable[..., int]): ...
4884+
def h(x: collections.abc.Callable[P, int]): ...
4885+
4886+
self.assertEqual(get_type_hints(f), {'x': collections.abc.Callable[[int], int]})
4887+
self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]})
4888+
self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]})
4889+
48794890

48804891
class GetUtilitiesTestCase(TestCase):
48814892
def test_get_origin(self):

Lib/typing.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,24 @@ def _is_param_expr(arg):
203203
(tuple, list, ParamSpec, _ConcatenateGenericAlias))
204204

205205

206+
def _should_unflatten_callable_args(typ, args):
207+
"""Internal helper for munging collections.abc.Callable's __args__.
208+
209+
The canonical representation for a Callable's __args__ flattens the
210+
argument types, see https://bugs.python.org/issue42195. For example:
211+
212+
collections.abc.Callable[[int, int], str].__args__ == (int, int, str)
213+
collections.abc.Callable[ParamSpec, str].__args__ == (ParamSpec, str)
214+
215+
As a result, if we need to reconstruct the Callable from its __args__,
216+
we need to unflatten it.
217+
"""
218+
return (
219+
typ.__origin__ is collections.abc.Callable
220+
and not (len(args) == 2 and _is_param_expr(args[0]))
221+
)
222+
223+
206224
def _type_repr(obj):
207225
"""Return the repr() of an object, special-casing types (internal helper).
208226
@@ -351,7 +369,10 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
351369
ForwardRef(arg) if isinstance(arg, str) else arg
352370
for arg in t.__args__
353371
)
354-
t = t.__origin__[args]
372+
if _should_unflatten_callable_args(t, args):
373+
t = t.__origin__[(args[:-1], args[-1])]
374+
else:
375+
t = t.__origin__[args]
355376
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
356377
if ev_args == t.__args__:
357378
return t
@@ -2361,8 +2382,7 @@ def get_args(tp):
23612382
return (tp.__origin__,) + tp.__metadata__
23622383
if isinstance(tp, (_GenericAlias, GenericAlias)):
23632384
res = tp.__args__
2364-
if (tp.__origin__ is collections.abc.Callable
2365-
and not (len(res) == 2 and _is_param_expr(res[0]))):
2385+
if _should_unflatten_callable_args(tp, res):
23662386
res = (list(res[:-1]), res[-1])
23672387
return res
23682388
if isinstance(tp, types.UnionType):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix :func:`typing.get_type_hints` for :class:`collections.abc.Callable`. Patch by Shantanu Jain.

0 commit comments

Comments
 (0)