Skip to content

Commit 05bb71f

Browse files
jaycee-li authored and copybara-github committed
feat: Support a list of GCS URIs in CustomPythonPackageTrainingJob
PiperOrigin-RevId: 503270789
1 parent 4415c10 commit 05bb71f

File tree

2 files changed

+93
-33
lines changed

2 files changed

+93
-33
lines changed

google/cloud/aiplatform/training_jobs.py

+38-30
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -5827,7 +5827,7 @@ def __init__(
58275827
self,
58285828
# TODO(b/223262536): Make display_name parameter fully optional in next major release
58295829
display_name: str,
5830-
python_package_gcs_uri: str,
5830+
python_package_gcs_uri: Union[str, List[str]],
58315831
python_module_name: str,
58325832
container_uri: str,
58335833
model_serving_container_image_uri: Optional[str] = None,
@@ -5891,53 +5891,56 @@ def __init__(
58915891
Args:
58925892
display_name (str):
58935893
Required. The user-defined name of this TrainingPipeline.
5894-
python_package_gcs_uri (str):
5895-
Required: GCS location of the training python package.
5894+
python_package_gcs_uri (Union[str, List[str]]):
5895+
Required. GCS location of the training python package.
5896+
Could be a string for a single package or a list of strings for
5897+
multiple packages.
58965898
python_module_name (str):
5897-
Required: The module name of the training python package.
5899+
Required. The module name of the training python package.
58985900
container_uri (str):
5899-
Required: Uri of the training container image in the GCR.
5901+
Required. Uri of the training container image in the GCR.
59005902
model_serving_container_image_uri (str):
5901-
If the training produces a managed Vertex AI Model, the URI of the
5902-
Model serving container suitable for serving the model produced by the
5903-
training script.
5903+
Optional. If the training produces a managed Vertex AI Model,
5904+
the URI of the model serving container suitable for serving the
5905+
model produced by the training script.
59045906
model_serving_container_predict_route (str):
5905-
If the training produces a managed Vertex AI Model, An HTTP path to
5906-
send prediction requests to the container, and which must be supported
5907-
by it. If not specified a default HTTP path will be used by Vertex AI.
5907+
Optional. If the training produces a managed Vertex AI Model,
5908+
an HTTP path to send prediction requests to the container,
5909+
and which must be supported by it. If not specified a default
5910+
HTTP path will be used by Vertex AI.
59085911
model_serving_container_health_route (str):
5909-
If the training produces a managed Vertex AI Model, an HTTP path to
5910-
send health check requests to the container, and which must be supported
5911-
by it. If not specified a standard HTTP path will be used by AI
5912-
Platform.
5912+
Optional. If the training produces a managed Vertex AI Model,
5913+
an HTTP path to send health check requests to the container,
5914+
and which must be supported by it. If not specified a standard
5915+
HTTP path will be used by AI Platform.
59135916
model_serving_container_command (Sequence[str]):
5914-
The command with which the container is run. Not executed within a
5917+
Optional. The command with which the container is run. Not executed within a
59155918
shell. The Docker image's ENTRYPOINT is used if this is not provided.
59165919
Variable references $(VAR_NAME) are expanded using the container's
59175920
environment. If a variable cannot be resolved, the reference in the
59185921
input string will be unchanged. The $(VAR_NAME) syntax can be escaped
59195922
with a double $$, ie: $$(VAR_NAME). Escaped references will never be
59205923
expanded, regardless of whether the variable exists or not.
59215924
model_serving_container_args (Sequence[str]):
5922-
The arguments to the command. The Docker image's CMD is used if this is
5923-
not provided. Variable references $(VAR_NAME) are expanded using the
5925+
Optional. The arguments to the command. The Docker image's CMD is used if this
5926+
is not provided. Variable references $(VAR_NAME) are expanded using the
59245927
container's environment. If a variable cannot be resolved, the reference
59255928
in the input string will be unchanged. The $(VAR_NAME) syntax can be
59265929
escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
59275930
never be expanded, regardless of whether the variable exists or not.
59285931
model_serving_container_environment_variables (Dict[str, str]):
5929-
The environment variables that are to be present in the container.
5932+
Optional. The environment variables that are to be present in the container.
59305933
Should be a dictionary where keys are environment variable names
59315934
and values are environment variable values for those names.
59325935
model_serving_container_ports (Sequence[int]):
5933-
Declaration of ports that are exposed by the container. This field is
5934-
primarily informational, it gives Vertex AI information about the
5935-
network connections the container uses. Listing or not a port here has
5936-
no impact on whether the port is actually exposed, any port listening on
5937-
the default "0.0.0.0" address inside a container will be accessible from
5938-
the network.
5936+
Optional. Declaration of ports that are exposed by the container.
5937+
This field is primarily informational, it gives Vertex AI information
5938+
about the network connections the container uses. Listing or not
5939+
a port here has no impact on whether the port is actually exposed,
5940+
any port listening on the default "0.0.0.0" address inside a
5941+
container will be accessible from the network.
59395942
model_description (str):
5940-
The description of the Model.
5943+
Optional. The description of the Model.
59415944
model_instance_schema_uri (str):
59425945
Optional. Points to a YAML file stored on Google Cloud
59435946
Storage describing the format of a single instance, which
@@ -6036,7 +6039,7 @@ def __init__(
60366039
60376040
Overrides encryption_spec_key_name set in aiplatform.init.
60386041
staging_bucket (str):
6039-
Bucket used to stage source and training artifacts. Overrides
6042+
Optional. Bucket used to stage source and training artifacts. Overrides
60406043
staging_bucket set in aiplatform.init.
60416044
"""
60426045
if not display_name:
@@ -6066,7 +6069,12 @@ def __init__(
60666069
staging_bucket=staging_bucket,
60676070
)
60686071

6069-
self._package_gcs_uri = python_package_gcs_uri
6072+
if isinstance(python_package_gcs_uri, str):
6073+
self._package_gcs_uri = [python_package_gcs_uri]
6074+
elif isinstance(python_package_gcs_uri, list):
6075+
self._package_gcs_uri = python_package_gcs_uri
6076+
else:
6077+
raise ValueError("'python_package_gcs_uri' must be a string or list.")
60706078
self._python_module = python_module_name
60716079

60726080
def run(
@@ -6668,7 +6676,7 @@ def _run(
66686676
spec["python_package_spec"] = {
66696677
"executor_image_uri": self._container_uri,
66706678
"python_module": self._python_module,
6671-
"package_uris": [self._package_gcs_uri],
6679+
"package_uris": self._package_gcs_uri,
66726680
}
66736681

66746682
if args:

tests/unit/aiplatform/test_training_jobs.py

+55-3
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
# Copyright 2022 Google LLC
3+
# Copyright 2023 Google LLC
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -161,6 +161,7 @@
161161
_TEST_MODEL_DESCRIPTION = "test description"
162162

163163
_TEST_OUTPUT_PYTHON_PACKAGE_PATH = "gs://test-staging-bucket/trainer.tar.gz"
164+
_TEST_PACKAGE_GCS_URIS = [_TEST_OUTPUT_PYTHON_PACKAGE_PATH] * 2
164165
_TEST_PYTHON_MODULE_NAME = "aiplatform.task"
165166

166167
_TEST_MODEL_NAME = f"projects/{_TEST_PROJECT}/locations/us-central1/models/{_TEST_ID}"
@@ -4987,13 +4988,18 @@ def teardown_method(self):
49874988
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
49884989
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
49894990
@pytest.mark.parametrize("sync", [True, False])
4991+
@pytest.mark.parametrize(
4992+
"python_package_gcs_uri",
4993+
[_TEST_OUTPUT_PYTHON_PACKAGE_PATH, _TEST_PACKAGE_GCS_URIS],
4994+
)
49904995
def test_run_call_pipeline_service_create_with_tabular_dataset(
49914996
self,
49924997
mock_pipeline_service_create,
49934998
mock_pipeline_service_get,
49944999
mock_tabular_dataset,
49955000
mock_model_service_get,
49965001
sync,
5002+
python_package_gcs_uri,
49975003
):
49985004
aiplatform.init(
49995005
project=_TEST_PROJECT,
@@ -5004,7 +5010,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
50045010
job = training_jobs.CustomPythonPackageTrainingJob(
50055011
display_name=_TEST_DISPLAY_NAME,
50065012
labels=_TEST_LABELS,
5007-
python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH,
5013+
python_package_gcs_uri=python_package_gcs_uri,
50085014
python_module_name=_TEST_PYTHON_MODULE_NAME,
50095015
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
50105016
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
@@ -5050,6 +5056,11 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
50505056
for key, value in _TEST_ENVIRONMENT_VARIABLES.items()
50515057
]
50525058

5059+
if isinstance(python_package_gcs_uri, str):
5060+
package_uris = [python_package_gcs_uri]
5061+
else:
5062+
package_uris = python_package_gcs_uri
5063+
50535064
true_worker_pool_spec = {
50545065
"replica_count": _TEST_REPLICA_COUNT,
50555066
"machine_spec": {
@@ -5064,7 +5075,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
50645075
"python_package_spec": {
50655076
"executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE,
50665077
"python_module": _TEST_PYTHON_MODULE_NAME,
5067-
"package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH],
5078+
"package_uris": package_uris,
50685079
"args": true_args,
50695080
"env": true_env,
50705081
},
@@ -5164,6 +5175,47 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
51645175

51655176
assert job.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
51665177

5178+
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
5179+
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
5180+
def test_custom_python_package_training_job_run_raises_with_wrong_package_uris(
5181+
self,
5182+
mock_pipeline_service_create,
5183+
mock_pipeline_service_get,
5184+
mock_tabular_dataset,
5185+
mock_model_service_get,
5186+
):
5187+
aiplatform.init(
5188+
project=_TEST_PROJECT,
5189+
staging_bucket=_TEST_BUCKET_NAME,
5190+
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
5191+
)
5192+
5193+
wrong_package_gcs_uri = {"package": _TEST_OUTPUT_PYTHON_PACKAGE_PATH}
5194+
5195+
with pytest.raises(ValueError) as e:
5196+
training_jobs.CustomPythonPackageTrainingJob(
5197+
display_name=_TEST_DISPLAY_NAME,
5198+
labels=_TEST_LABELS,
5199+
python_package_gcs_uri=wrong_package_gcs_uri,
5200+
python_module_name=_TEST_PYTHON_MODULE_NAME,
5201+
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
5202+
model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE,
5203+
model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE,
5204+
model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE,
5205+
model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND,
5206+
model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS,
5207+
model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES,
5208+
model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS,
5209+
model_description=_TEST_MODEL_DESCRIPTION,
5210+
model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI,
5211+
model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI,
5212+
model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI,
5213+
explanation_metadata=_TEST_EXPLANATION_METADATA,
5214+
explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
5215+
)
5216+
5217+
assert e.match("'python_package_gcs_uri' must be a string or list.")
5218+
51675219
@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
51685220
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
51695221
def test_custom_python_package_training_job_run_raises_with_impartial_explanation_spec(

0 commit comments

Comments
 (0)