Skip to content

Commit 7e0c26d

Browse files
malfetpytorchmergebot
authored andcommitted
[JIT] Allow tuple and list generics (#98703)
As in Python-3.9+ `Dict`, `List`, and `Tuple` from `typing` module are deprecated in favor of their `builtins` counterparts, see [PEP 585](https://peps.python.org/pep-0585/) Test plan: Run: ``` import torch from typing import Union @torch.jit.script def to_tuple(v: Union[int, tuple[int, int]]) -> tuple[int, int]: """Converts int or tuple to tuple of ints.""" if torch.jit.isinstance(v, int): return v, v else: return v print(to_tuple(1), to_tuple((3, 4))) ``` It's almost impossible to add test to an existing CI, as test script will not be parseable by Python-3.8, which is a oldest supported Python version Fixes #98521 Pull Request resolved: #98703 Approved by: https://github.com/kit1980
1 parent 2400cb1 commit 7e0c26d

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

torch/_jit_internal.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def get_source(self, fn):
7979
loader = SourceLoader()
8080

8181

82+
IS_PY39_PLUS = sys.version_info >= (3, 9)
83+
84+
8285
def createResolutionCallbackFromEnv(lookup_base):
8386
"""
8487
Creates a resolution callback that will look up qualified names in an
@@ -339,7 +342,7 @@ def get_annotation_str(annotation):
339342
return ".".join([get_annotation_str(annotation.value), annotation.attr])
340343
elif isinstance(annotation, ast.Subscript):
341344
# In Python3.9+ subscript indicies are not wrapped in ast.Index
342-
subscript_slice = annotation.slice if sys.version_info >= (3, 9) else annotation.slice.value # type: ignore[attr-defined]
345+
subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value # type: ignore[attr-defined]
343346
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
344347
elif isinstance(annotation, ast.Tuple):
345348
return ",".join([get_annotation_str(elt) for elt in annotation.elts])
@@ -983,10 +986,11 @@ def is_tuple(ann) -> bool:
983986
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
984987
if not hasattr(ann, "__module__"):
985988
return False
986-
return ann.__module__ == "typing" and (
987-
getattr(ann, "__origin__", None) is Tuple
988-
or getattr(ann, "__origin__", None) is tuple
989-
)
989+
990+
ann_origin = getattr(ann, "__origin__", None)
991+
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
992+
return True
993+
return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
990994

991995

992996
def is_list(ann) -> bool:
@@ -995,10 +999,11 @@ def is_list(ann) -> bool:
995999

9961000
if not hasattr(ann, "__module__"):
9971001
return False
998-
return ann.__module__ == "typing" and (
999-
getattr(ann, "__origin__", None) is List
1000-
or getattr(ann, "__origin__", None) is list
1001-
)
1002+
1003+
ann_origin = getattr(ann, "__origin__", None)
1004+
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
1005+
return True
1006+
return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
10021007

10031008

10041009
def is_dict(ann) -> bool:
@@ -1007,10 +1012,11 @@ def is_dict(ann) -> bool:
10071012

10081013
if not hasattr(ann, "__module__"):
10091014
return False
1010-
return ann.__module__ == "typing" and (
1011-
getattr(ann, "__origin__", None) is Dict
1012-
or getattr(ann, "__origin__", None) is dict
1013-
)
1015+
1016+
ann_origin = getattr(ann, "__origin__", None)
1017+
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
1018+
return True
1019+
return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
10141020

10151021

10161022
def is_union(ann):

0 commit comments

Comments
 (0)