Skip to content

Commit 0823323

Browse files
yeesiancopybara-github
authored andcommitted
fix: Use typing.TYPE_CHECKING to differentiate pytype checks and runtime imports for pydantic.
PiperOrigin-RevId: 713367981
1 parent 73fa981 commit 0823323

File tree

1 file changed

+76
-46
lines changed

1 file changed

+76
-46
lines changed

vertexai/generative_models/_generative_models.py

+76-46
Original file line numberDiff line numberDiff line change
@@ -88,53 +88,10 @@
8888
# These type defnitions are expanded to help the user see all the types
8989
ContentDict = Dict[str, Any]
9090
GenerationConfigDict = Dict[str, Any]
91-
try:
92-
# For Pydantic to resolve the forward references inside these aliases.
93-
from typing_extensions import TypeAliasType
9491

95-
PartsType = TypeAliasType(
96-
"PartsType",
97-
Union[
98-
str,
99-
"Image",
100-
"Part",
101-
List[Union[str, "Image", "Part"]],
102-
],
103-
)
104-
ContentsType = TypeAliasType(
105-
"ContentsType",
106-
Union[
107-
List["Content"],
108-
List[ContentDict],
109-
str,
110-
"Image",
111-
"Part",
112-
List[Union[str, "Image", "Part"]],
113-
],
114-
)
115-
GenerationConfigType = TypeAliasType(
116-
"GenerationConfigType",
117-
Union[
118-
"GenerationConfig",
119-
GenerationConfigDict,
120-
],
121-
)
122-
SafetySettingsType = TypeAliasType(
123-
"SafetySettingsType",
124-
Union[
125-
List["SafetySetting"],
126-
Dict[
127-
gapic_content_types.HarmCategory,
128-
gapic_content_types.SafetySetting.HarmBlockThreshold,
129-
],
130-
],
131-
)
132-
except (ImportError, RuntimeError) as e:
133-
from google.cloud.aiplatform import base
134-
135-
_LOGGER = base.Logger(__name__)
136-
_LOGGER.debug(f"Failed to import typing_extensions.TypeAliasType: {e}")
137-
# Use existing definitions if typing_extensions is not available.
92+
if TYPE_CHECKING:
93+
# Default to the current definitions if pytype is being used for type checks
94+
# because it does not support try-except for types.
13895
PartsType = Union[
13996
str,
14097
"Image",
@@ -160,6 +117,79 @@
160117
gapic_content_types.SafetySetting.HarmBlockThreshold,
161118
],
162119
]
120+
else:
121+
try:
122+
# For Pydantic to resolve the forward references inside these aliases.
123+
from typing_extensions import TypeAliasType
124+
125+
PartsType = TypeAliasType(
126+
"PartsType",
127+
Union[
128+
str,
129+
"Image",
130+
"Part",
131+
List[Union[str, "Image", "Part"]],
132+
],
133+
)
134+
ContentsType = TypeAliasType(
135+
"ContentsType",
136+
Union[
137+
List["Content"],
138+
List[ContentDict],
139+
str,
140+
"Image",
141+
"Part",
142+
List[Union[str, "Image", "Part"]],
143+
],
144+
)
145+
GenerationConfigType = TypeAliasType(
146+
"GenerationConfigType",
147+
Union[
148+
"GenerationConfig",
149+
GenerationConfigDict,
150+
],
151+
)
152+
SafetySettingsType = TypeAliasType(
153+
"SafetySettingsType",
154+
Union[
155+
List["SafetySetting"],
156+
Dict[
157+
gapic_content_types.HarmCategory,
158+
gapic_content_types.SafetySetting.HarmBlockThreshold,
159+
],
160+
],
161+
)
162+
except (ImportError, RuntimeError) as e:
163+
from google.cloud.aiplatform import base
164+
165+
_LOGGER = base.Logger(__name__)
166+
_LOGGER.debug(f"Failed to import typing_extensions.TypeAliasType: {e}")
167+
# Use existing definitions if typing_extensions is not available.
168+
PartsType = Union[
169+
str,
170+
"Image",
171+
"Part",
172+
List[Union[str, "Image", "Part"]],
173+
]
174+
ContentsType = Union[
175+
List["Content"],
176+
List[ContentDict],
177+
str,
178+
"Image",
179+
"Part",
180+
List[Union[str, "Image", "Part"]],
181+
]
182+
GenerationConfigType = Union[
183+
"GenerationConfig",
184+
GenerationConfigDict,
185+
]
186+
SafetySettingsType = Union[
187+
List["SafetySetting"],
188+
Dict[
189+
gapic_content_types.HarmCategory,
190+
gapic_content_types.SafetySetting.HarmBlockThreshold,
191+
],
192+
]
163193

164194

165195
def _reconcile_model_name(model_name: str, project: str, location: str) -> str:

0 commit comments

Comments
 (0)