Skip to content

Commit 4ce1d96

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: Add sample for custom training job with PSC-I through aiplatform_v1beta1.
PiperOrigin-RevId: 699196687
1 parent 88aaed1 commit 4ce1d96

File tree

4 files changed

+172
-0
lines changed

4 files changed

+172
-0
lines changed

samples/model-builder/conftest.py

+29
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest.mock import patch
1717

1818
from google.cloud import aiplatform
19+
from google.cloud import aiplatform_v1beta1
1920
import vertexai
2021
from vertexai.resources import preview as preview_resources
2122
import pytest
@@ -368,6 +369,34 @@ def mock_run_custom_package_training_job(mock_custom_package_training_job):
368369
yield mock
369370

370371

372+
@pytest.fixture
373+
def mock_job_service_client_v1beta1():
374+
mock = MagicMock(aiplatform_v1beta1.JobServiceClient)
375+
yield mock
376+
377+
378+
@pytest.fixture
379+
def mock_get_job_service_client_v1beta1(mock_job_service_client_v1beta1):
380+
with patch.object(aiplatform_v1beta1, "JobServiceClient") as mock:
381+
mock.return_value = mock_job_service_client_v1beta1
382+
yield mock
383+
384+
385+
@pytest.fixture
386+
def mock_create_custom_job_v1beta1(mock_job_service_client_v1beta1):
387+
with patch.object(
388+
mock_job_service_client_v1beta1, "create_custom_job"
389+
) as mock:
390+
yield mock
391+
392+
393+
@pytest.fixture
394+
def mock_get_create_custom_job_request_v1beta1():
395+
with patch.object(aiplatform_v1beta1, "CreateCustomJobRequest") as mock:
396+
mock.return_value = mock_custom_job
397+
yield mock
398+
399+
371400
@pytest.fixture
372401
def mock_custom_job():
373402
mock = MagicMock(aiplatform.CustomJob)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_create_custom_job_psci_sample]
16+
from google.cloud import aiplatform
17+
from google.cloud import aiplatform_v1beta1
18+
19+
20+
def create_custom_job_psci_sample(
21+
project: str,
22+
location: str,
23+
bucket: str,
24+
display_name: str,
25+
machine_type: str,
26+
replica_count: int,
27+
image_uri: str,
28+
network_attachment_name: str,
29+
):
30+
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
31+
aiplatform.init(project=project, location=location, staging_bucket=bucket)
32+
33+
client_options = {"api_endpoint": f"{location}-aiplatform.googleapis.com"}
34+
35+
client = aiplatform_v1beta1.JobServiceClient(client_options=client_options)
36+
37+
request = aiplatform_v1beta1.CreateCustomJobRequest(
38+
parent=f"projects/{project}/locations/{location}",
39+
custom_job=aiplatform_v1beta1.CustomJob(
40+
display_name=display_name,
41+
job_spec=aiplatform_v1beta1.CustomJobSpec(
42+
worker_pool_specs=[
43+
aiplatform_v1beta1.WorkerPoolSpec(
44+
machine_spec=aiplatform_v1beta1.MachineSpec(
45+
machine_type=machine_type,
46+
),
47+
replica_count=replica_count,
48+
container_spec=aiplatform_v1beta1.ContainerSpec(
49+
image_uri=image_uri,
50+
),
51+
)
52+
],
53+
psc_interface_config=aiplatform_v1beta1.PscInterfaceConfig(
54+
network_attachment=network_attachment_name,
55+
),
56+
),
57+
),
58+
)
59+
60+
response = client.create_custom_job(request=request)
61+
62+
return response
63+
64+
65+
# [END aiplatform_sdk_create_custom_job_psci_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import create_custom_job_psci_sample
16+
from google.cloud import aiplatform_v1beta1
17+
import test_constants as constants
18+
19+
20+
def test_create_custom_job_psci_sample(
21+
mock_sdk_init,
22+
mock_get_job_service_client_v1beta1,
23+
mock_get_create_custom_job_request_v1beta1,
24+
mock_create_custom_job_v1beta1,
25+
):
26+
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
27+
create_custom_job_psci_sample.create_custom_job_psci_sample(
28+
project=constants.PROJECT,
29+
location=constants.LOCATION,
30+
bucket=constants.STAGING_BUCKET,
31+
display_name=constants.DISPLAY_NAME,
32+
machine_type=constants.MACHINE_TYPE,
33+
replica_count=1,
34+
image_uri=constants.TRAIN_IMAGE,
35+
network_attachment_name=constants.NETWORK_ATTACHMENT_NAME,
36+
)
37+
38+
mock_sdk_init.assert_called_once_with(
39+
project=constants.PROJECT,
40+
location=constants.LOCATION,
41+
staging_bucket=constants.STAGING_BUCKET,
42+
)
43+
44+
mock_get_job_service_client_v1beta1.assert_called_once_with(
45+
client_options={
46+
"api_endpoint": f"{constants.LOCATION}-aiplatform.googleapis.com"
47+
}
48+
)
49+
50+
mock_get_create_custom_job_request_v1beta1.assert_called_once_with(
51+
parent=f"projects/{constants.PROJECT}/locations/{constants.LOCATION}",
52+
custom_job=aiplatform_v1beta1.CustomJob(
53+
display_name=constants.DISPLAY_NAME,
54+
job_spec=aiplatform_v1beta1.CustomJobSpec(
55+
worker_pool_specs=[
56+
aiplatform_v1beta1.WorkerPoolSpec(
57+
machine_spec=aiplatform_v1beta1.MachineSpec(
58+
machine_type=constants.MACHINE_TYPE,
59+
),
60+
replica_count=constants.REPLICA_COUNT,
61+
container_spec=aiplatform_v1beta1.ContainerSpec(
62+
image_uri=constants.TRAIN_IMAGE,
63+
),
64+
)
65+
],
66+
psc_interface_config=aiplatform_v1beta1.PscInterfaceConfig(
67+
network_attachment=constants.NETWORK_ATTACHMENT_NAME,
68+
),
69+
),
70+
),
71+
)
72+
73+
request = aiplatform_v1beta1.CreateCustomJobRequest(
74+
mock_get_create_custom_job_request_v1beta1.return_value
75+
)
76+
77+
mock_create_custom_job_v1beta1.assert_called_once_with(request=request)

samples/model-builder/test_constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
MACHINE_TYPE = "n1-standard-4"
115115
ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
116116
ACCELERATOR_COUNT = 0
117+
NETWORK_ATTACHMENT_NAME = "network-attachment-name"
117118

118119
# Model constants
119120
MODEL_RESOURCE_NAME = f"{PARENT}/models/1234"

0 commit comments

Comments
 (0)