Skip to content

Commit 80db7a0

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add support for reading requirements from a file.
PiperOrigin-RevId: 625068034
1 parent 67de093 commit 80db7a0

File tree

2 files changed

+62
-4
lines changed

2 files changed

+62
-4
lines changed

tests/unit/vertexai/test_reasoning_engines.py

+46
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,36 @@ def test_create_reasoning_engine(
294294
retry=_TEST_RETRY,
295295
)
296296

297+
def test_create_reasoning_engine_requirements_from_file(
298+
self,
299+
create_reasoning_engine_mock,
300+
cloud_storage_create_bucket_mock,
301+
tarfile_open_mock,
302+
cloudpickle_dump_mock,
303+
get_reasoning_engine_mock,
304+
):
305+
with mock.patch(
306+
"builtins.open",
307+
mock.mock_open(read_data="google-cloud-aiplatform==1.29.0"),
308+
) as mock_file:
309+
test_reasoning_engine = reasoning_engines.ReasoningEngine.create(
310+
self.test_app,
311+
reasoning_engine_name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
312+
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
313+
requirements="requirements.txt",
314+
)
315+
mock_file.assert_called_with("requirements.txt")
316+
# Manually set _gca_resource here to prevent the mocks from propagating.
317+
test_reasoning_engine._gca_resource = _TEST_REASONING_ENGINE_OBJ
318+
create_reasoning_engine_mock.assert_called_with(
319+
parent=_TEST_PARENT,
320+
reasoning_engine=test_reasoning_engine.gca_resource,
321+
)
322+
get_reasoning_engine_mock.assert_called_with(
323+
name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
324+
retry=_TEST_RETRY,
325+
)
326+
297327
def test_delete_after_create_reasoning_engine(
298328
self,
299329
create_reasoning_engine_mock,
@@ -407,6 +437,22 @@ def test_create_reasoning_engine_unsupported_sys_version(
407437
sys_version="2.6",
408438
)
409439

440+
def test_create_reasoning_engine_requirements_ioerror(
441+
self,
442+
create_reasoning_engine_mock,
443+
cloud_storage_create_bucket_mock,
444+
tarfile_open_mock,
445+
cloudpickle_dump_mock,
446+
get_reasoning_engine_mock,
447+
):
448+
with pytest.raises(IOError, match="Failed to read requirements"):
449+
reasoning_engines.ReasoningEngine.create(
450+
self.test_app,
451+
reasoning_engine_name=_TEST_REASONING_ENGINE_RESOURCE_NAME,
452+
display_name=_TEST_REASONING_ENGINE_DISPLAY_NAME,
453+
requirements="nonexistent_requirements.txt",
454+
)
455+
410456
def test_create_reasoning_engine_nonexistent_extra_packages(
411457
self,
412458
create_reasoning_engine_mock,

vertexai/reasoning_engines/_reasoning_engines.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import sys
2121
import tarfile
2222
import typing
23-
from typing import Optional, Protocol, Sequence
23+
from typing import Optional, Protocol, Sequence, Union
2424

2525
from google.cloud.aiplatform import base
2626
from google.cloud.aiplatform import initializer
@@ -84,7 +84,7 @@ def create(
8484
cls,
8585
reasoning_engine: Queryable,
8686
*,
87-
requirements: Optional[Sequence[str]] = None,
87+
requirements: Optional[Union[str, Sequence[str]]] = None,
8888
reasoning_engine_name: Optional[str] = None,
8989
display_name: Optional[str] = None,
9090
description: Optional[str] = None,
@@ -131,8 +131,10 @@ def create(
131131
Args:
132132
reasoning_engine (ReasoningEngineInterface):
133133
Required. The Reasoning Engine to be created.
134-
requirements (Sequence[str]):
135-
Optional. The set of PyPI dependencies needed.
134+
requirements (Union[str, Sequence[str]]):
135+
Optional. The set of PyPI dependencies needed. It can either be
136+
the path to a single file (requirements.txt), or an ordered list
137+
of strings corresponding to each line of the requirements file.
136138
reasoning_engine_name (str):
137139
Optional. A fully-qualified resource name or ID such as
138140
"projects/123/locations/us-central1/reasoningEngines/456" or
@@ -202,6 +204,16 @@ def create(
202204
"Invalid query signature. This might be due to a missing "
203205
"`self` argument in the reasoning_engine.query method."
204206
) from err
207+
if isinstance(requirements, str):
208+
try:
209+
_LOGGER.info(f"Reading requirements from {requirements=}")
210+
with open(requirements) as f:
211+
requirements = f.read().splitlines()
212+
_LOGGER.info(f"Read the following lines: {requirements}")
213+
except IOError as err:
214+
raise IOError(
215+
f"Failed to read requirements from {requirements=}"
216+
) from err
205217
requirements = requirements or []
206218
extra_packages = extra_packages or []
207219
for extra_package in extra_packages:

0 commit comments

Comments
 (0)