Skip to content

Commit 94d838d

Browse files
Ark-kuncopybara-github
authored andcommitted
chore: GenAI - Enforce formatting in the vertexai module
Also, format all existing modules. PiperOrigin-RevId: 637170526
1 parent d47a5be commit 94d838d

File tree

13 files changed

+141
-143
lines changed

13 files changed

+141
-143
lines changed

noxfile.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
FLAKE8_VERSION = "flake8==6.1.0"
2929
BLACK_VERSION = "black==22.3.0"
3030
ISORT_VERSION = "isort==5.10.1"
31-
LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"]
31+
LINT_PATHS = ["docs", "google", "vertexai", "tests", "noxfile.py", "setup.py"]
3232

3333
DEFAULT_PYTHON_VERSION = "3.8"
3434

vertexai/extensions/_extensions.py

+25-28
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,20 @@ def __init__(self, extension_name: str):
9595
self.execution_api_client = initializer.global_config.create_client(
9696
client_class=aip_utils.ExtensionExecutionClientWithOverride,
9797
)
98-
self._gca_resource = self._get_gca_resource(
99-
resource_name=extension_name
100-
)
98+
self._gca_resource = self._get_gca_resource(resource_name=extension_name)
10199
self._api_spec = None
102100
self._operation_schemas = None
103101

104102
@classmethod
105103
def create(
106-
cls,
107-
manifest: Union[_utils.JsonDict, types.ExtensionManifest],
108-
*,
109-
extension_name: Optional[str] = None,
110-
display_name: Optional[str] = None,
111-
description: Optional[str] = None,
112-
runtime_config: Optional[_RuntimeConfigOrJson] = None,
113-
):
104+
cls,
105+
manifest: Union[_utils.JsonDict, types.ExtensionManifest],
106+
*,
107+
extension_name: Optional[str] = None,
108+
display_name: Optional[str] = None,
109+
description: Optional[str] = None,
110+
runtime_config: Optional[_RuntimeConfigOrJson] = None,
111+
):
114112
"""Creates a new Extension.
115113
116114
Args:
@@ -150,7 +148,8 @@ def create(
150148
)
151149
if runtime_config:
152150
extension.runtime_config = _utils.to_proto(
153-
runtime_config, types.RuntimeConfig(),
151+
runtime_config,
152+
types.RuntimeConfig(),
154153
)
155154
operation_future = sdk_resource.api_client.import_extension(
156155
parent=initializer.global_config.common_location_path(),
@@ -169,10 +168,8 @@ def create(
169168
sdk_resource._gca_resource = sdk_resource._get_gca_resource(
170169
resource_name=created_extension.name
171170
)
172-
sdk_resource.execution_api_client = (
173-
initializer.global_config.create_client(
174-
client_class=aip_utils.ExtensionExecutionClientWithOverride,
175-
)
171+
sdk_resource.execution_api_client = initializer.global_config.create_client(
172+
client_class=aip_utils.ExtensionExecutionClientWithOverride,
176173
)
177174
sdk_resource._api_spec = None
178175
sdk_resource._operation_schemas = None
@@ -186,9 +183,7 @@ def resource_name(self) -> str:
186183
def api_spec(self) -> _utils.JsonDict:
187184
"""Returns the (Open)API Spec of the extension."""
188185
if self._api_spec is None:
189-
self._api_spec = _load_api_spec(
190-
self._gca_resource.manifest.api_spec
191-
)
186+
self._api_spec = _load_api_spec(self._gca_resource.manifest.api_spec)
192187
return self._api_spec
193188

194189
def operation_schemas(self) -> Sequence[_utils.JsonDict]:
@@ -201,11 +196,11 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
201196
return self._operation_schemas
202197

203198
def execute(
204-
self,
205-
operation_id: str,
206-
operation_params: Optional[_StructOrJson] = None,
207-
runtime_auth_config: Optional[_AuthConfigOrJson] = None,
208-
) -> Union[_utils.JsonDict, str]:
199+
self,
200+
operation_id: str,
201+
operation_params: Optional[_StructOrJson] = None,
202+
runtime_auth_config: Optional[_AuthConfigOrJson] = None,
203+
) -> Union[_utils.JsonDict, str]:
209204
"""Executes an operation of the extension with the specified params.
210205
211206
Args:
@@ -230,7 +225,8 @@ def execute(
230225
)
231226
if runtime_auth_config:
232227
request.runtime_auth_config = _utils.to_proto(
233-
runtime_auth_config, types.AuthConfig(),
228+
runtime_auth_config,
229+
types.AuthConfig(),
234230
)
235231
response = self.execution_api_client.execute_extension(request)
236232
return _try_parse_execution_response(response)
@@ -263,7 +259,8 @@ def from_hub(
263259
"""
264260
if runtime_config:
265261
runtime_config = _utils.to_proto(
266-
runtime_config, types.RuntimeConfig(),
262+
runtime_config,
263+
types.RuntimeConfig(),
267264
)
268265
if name == "code_interpreter":
269266
if runtime_config and not getattr(
@@ -301,8 +298,8 @@ def from_hub(
301298

302299

303300
def _try_parse_execution_response(
304-
response: types.ExecuteExtensionResponse
305-
) -> Union[_utils.JsonDict, str]:
301+
response: types.ExecuteExtensionResponse,
302+
) -> Union[_utils.JsonDict, str]:
306303
content: str = response.content
307304
try:
308305
content = json.loads(response.content)

vertexai/generative_models/_function_calling_utils.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -71,24 +71,23 @@ def _generate_json_schema_from_function_using_pydantic(
7171
name: (
7272
# 1. We infer the argument type here: use Any rather than None so
7373
# it will not try to auto-infer the type based on the default value.
74-
(
75-
param.annotation if param.annotation != inspect.Parameter.empty
76-
else Any
77-
),
74+
(param.annotation if param.annotation != inspect.Parameter.empty else Any),
7875
pydantic.Field(
7976
# 2. We do not support default values for now.
8077
default=(
81-
param.default if param.default != inspect.Parameter.empty
78+
param.default
79+
if param.default != inspect.Parameter.empty
8280
# ! Need to use Undefined instead of None
8381
else pydantic_fields.Undefined
8482
),
8583
# 3. We support user-provided descriptions.
8684
description=parameter_descriptions.get(name, None),
87-
)
85+
),
8886
)
8987
for name, param in defaults.items()
9088
# We do not support *args or **kwargs
91-
if param.kind in (
89+
if param.kind
90+
in (
9291
inspect.Parameter.POSITIONAL_OR_KEYWORD,
9392
inspect.Parameter.KEYWORD_ONLY,
9493
inspect.Parameter.POSITIONAL_ONLY,
@@ -105,10 +104,9 @@ def _generate_json_schema_from_function_using_pydantic(
105104
# * https://github.com/pydantic/pydantic/issues/1270
106105
# * https://stackoverflow.com/a/58841311
107106
# * https://github.com/pydantic/pydantic/discussions/4872
108-
if (
109-
typing.get_origin(annotation) is typing.Union
110-
and type(None) in typing.get_args(annotation)
111-
):
107+
if typing.get_origin(annotation) is typing.Union and type(
108+
None
109+
) in typing.get_args(annotation):
112110
# for "typing.Optional" arguments, function_arg might be a
113111
# dictionary like
114112
#
@@ -121,9 +119,12 @@ def _generate_json_schema_from_function_using_pydantic(
121119
property_schema["nullable"] = True
122120
# 6. Annotate required fields.
123121
function_schema["required"] = [
124-
k for k in defaults if (
122+
k
123+
for k in defaults
124+
if (
125125
defaults[k].default == inspect.Parameter.empty
126-
and defaults[k].kind in (
126+
and defaults[k].kind
127+
in (
127128
inspect.Parameter.POSITIONAL_OR_KEYWORD,
128129
inspect.Parameter.KEYWORD_ONLY,
129130
inspect.Parameter.POSITIONAL_ONLY,

vertexai/generative_models/_generative_models.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -1660,8 +1660,7 @@ def text(self) -> str:
16601660
raise ValueError(
16611661
"Response has no candidates (and thus no text)."
16621662
" The response is likely blocked by the safety filters.\n"
1663-
"Response:\n"
1664-
+ _dict_to_pretty_string(self.to_dict())
1663+
"Response:\n" + _dict_to_pretty_string(self.to_dict())
16651664
)
16661665
try:
16671666
return self.candidates[0].text
@@ -1671,8 +1670,7 @@ def text(self) -> str:
16711670
raise ValueError(
16721671
"Cannot get the response text.\n"
16731672
f"{e}\n"
1674-
"Response:\n"
1675-
+ _dict_to_pretty_string(self.to_dict())
1673+
"Response:\n" + _dict_to_pretty_string(self.to_dict())
16761674
) from e
16771675

16781676
@property
@@ -1754,8 +1752,7 @@ def text(self) -> str:
17541752
raise ValueError(
17551753
"Cannot get the Candidate text.\n"
17561754
f"{e}\n"
1757-
"Candidate:\n"
1758-
+ _dict_to_pretty_string(self.to_dict())
1755+
"Candidate:\n" + _dict_to_pretty_string(self.to_dict())
17591756
) from e
17601757

17611758
@property
@@ -1830,8 +1827,7 @@ def text(self) -> str:
18301827
raise ValueError(
18311828
"Response candidate content has no parts (and thus no text)."
18321829
" The candidate is likely blocked by the safety filters.\n"
1833-
"Content:\n"
1834-
+ _dict_to_pretty_string(self.to_dict())
1830+
"Content:\n" + _dict_to_pretty_string(self.to_dict())
18351831
)
18361832
return self.parts[0].text
18371833

@@ -1921,8 +1917,7 @@ def text(self) -> str:
19211917
if "text" not in self._raw_part:
19221918
raise AttributeError(
19231919
"Response candidate content part has no text.\n"
1924-
"Part:\n"
1925-
+ _dict_to_pretty_string(self.to_dict())
1920+
"Part:\n" + _dict_to_pretty_string(self.to_dict())
19261921
)
19271922
return self._raw_part.text
19281923

@@ -2023,8 +2018,7 @@ class GoogleSearchRetrieval:
20232018
"""
20242019

20252020
def __init__(self):
2026-
"""Initializes a Google Search Retrieval tool.
2027-
"""
2021+
"""Initializes a Google Search Retrieval tool."""
20282022
self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval()
20292023

20302024

@@ -2400,7 +2394,9 @@ def respond_to_model_response(
24002394
)
24012395
callable_function = None
24022396
for tool in tools:
2403-
new_callable_function = tool._callable_functions.get(function_call.name)
2397+
new_callable_function = tool._callable_functions.get(
2398+
function_call.name
2399+
)
24042400
if new_callable_function and callable_function:
24052401
raise ValueError(
24062402
"Multiple functions with the same name are not supported."

vertexai/language_models/_distillation.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,7 @@ def submit_distillation_pipeline_job(
105105
pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
106106
if evaluation_spec is not None:
107107
pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
108-
pipeline_arguments[
109-
"evaluation_interval"
110-
] = evaluation_spec.evaluation_interval
108+
pipeline_arguments["evaluation_interval"] = evaluation_spec.evaluation_interval
111109
pipeline_arguments[
112110
"enable_early_stopping"
113111
] = evaluation_spec.enable_early_stopping
@@ -126,8 +124,7 @@ def submit_distillation_pipeline_job(
126124
pipeline_arguments["max_context_length"] = max_context_length
127125
if model_display_name is None:
128126
model_display_name = (
129-
f"{student_short_model_id}"
130-
f" distilled from {teacher_short_model_id}"
127+
f"{student_short_model_id} distilled from {teacher_short_model_id}"
131128
)
132129
pipeline_arguments["model_display_name"] = model_display_name
133130
# # Not exposing these parameters:

vertexai/preview/extensions.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,4 @@
2020
Extension,
2121
)
2222

23-
__all__ = (
24-
"Extension",
25-
)
23+
__all__ = ("Extension",)

vertexai/preview/reasoning_engines/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
ReasoningEngine,
2222
)
2323
from vertexai.preview.reasoning_engines.templates.langchain import (
24-
LangchainAgent
24+
LangchainAgent,
2525
)
2626

2727
__all__ = (

0 commit comments

Comments
 (0)