Skip to content

Commit 15a7986

Browse files
authored
Merge branch 'main' into owl-bot-copy
2 parents 70a666a + 4ce1d96 commit 15a7986

File tree

7 files changed

+185
-3
lines changed

7 files changed

+185
-3
lines changed

samples/model-builder/conftest.py

+29
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest.mock import patch
1717

1818
from google.cloud import aiplatform
19+
from google.cloud import aiplatform_v1beta1
1920
import vertexai
2021
from vertexai.resources import preview as preview_resources
2122
import pytest
@@ -368,6 +369,34 @@ def mock_run_custom_package_training_job(mock_custom_package_training_job):
368369
yield mock
369370

370371

372+
@pytest.fixture
373+
def mock_job_service_client_v1beta1():
374+
mock = MagicMock(aiplatform_v1beta1.JobServiceClient)
375+
yield mock
376+
377+
378+
@pytest.fixture
379+
def mock_get_job_service_client_v1beta1(mock_job_service_client_v1beta1):
380+
with patch.object(aiplatform_v1beta1, "JobServiceClient") as mock:
381+
mock.return_value = mock_job_service_client_v1beta1
382+
yield mock
383+
384+
385+
@pytest.fixture
386+
def mock_create_custom_job_v1beta1(mock_job_service_client_v1beta1):
387+
with patch.object(
388+
mock_job_service_client_v1beta1, "create_custom_job"
389+
) as mock:
390+
yield mock
391+
392+
393+
@pytest.fixture
394+
def mock_get_create_custom_job_request_v1beta1():
395+
with patch.object(aiplatform_v1beta1, "CreateCustomJobRequest") as mock:
396+
mock.return_value = mock_custom_job
397+
yield mock
398+
399+
371400
@pytest.fixture
372401
def mock_custom_job():
373402
mock = MagicMock(aiplatform.CustomJob)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_create_custom_job_psci_sample]
16+
from google.cloud import aiplatform
17+
from google.cloud import aiplatform_v1beta1
18+
19+
20+
def create_custom_job_psci_sample(
21+
project: str,
22+
location: str,
23+
bucket: str,
24+
display_name: str,
25+
machine_type: str,
26+
replica_count: int,
27+
image_uri: str,
28+
network_attachment_name: str,
29+
):
30+
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
31+
aiplatform.init(project=project, location=location, staging_bucket=bucket)
32+
33+
client_options = {"api_endpoint": f"{location}-aiplatform.googleapis.com"}
34+
35+
client = aiplatform_v1beta1.JobServiceClient(client_options=client_options)
36+
37+
request = aiplatform_v1beta1.CreateCustomJobRequest(
38+
parent=f"projects/{project}/locations/{location}",
39+
custom_job=aiplatform_v1beta1.CustomJob(
40+
display_name=display_name,
41+
job_spec=aiplatform_v1beta1.CustomJobSpec(
42+
worker_pool_specs=[
43+
aiplatform_v1beta1.WorkerPoolSpec(
44+
machine_spec=aiplatform_v1beta1.MachineSpec(
45+
machine_type=machine_type,
46+
),
47+
replica_count=replica_count,
48+
container_spec=aiplatform_v1beta1.ContainerSpec(
49+
image_uri=image_uri,
50+
),
51+
)
52+
],
53+
psc_interface_config=aiplatform_v1beta1.PscInterfaceConfig(
54+
network_attachment=network_attachment_name,
55+
),
56+
),
57+
),
58+
)
59+
60+
response = client.create_custom_job(request=request)
61+
62+
return response
63+
64+
65+
# [END aiplatform_sdk_create_custom_job_psci_sample]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import create_custom_job_psci_sample
16+
from google.cloud import aiplatform_v1beta1
17+
import test_constants as constants
18+
19+
20+
def test_create_custom_job_psci_sample(
21+
mock_sdk_init,
22+
mock_get_job_service_client_v1beta1,
23+
mock_get_create_custom_job_request_v1beta1,
24+
mock_create_custom_job_v1beta1,
25+
):
26+
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
27+
create_custom_job_psci_sample.create_custom_job_psci_sample(
28+
project=constants.PROJECT,
29+
location=constants.LOCATION,
30+
bucket=constants.STAGING_BUCKET,
31+
display_name=constants.DISPLAY_NAME,
32+
machine_type=constants.MACHINE_TYPE,
33+
replica_count=1,
34+
image_uri=constants.TRAIN_IMAGE,
35+
network_attachment_name=constants.NETWORK_ATTACHMENT_NAME,
36+
)
37+
38+
mock_sdk_init.assert_called_once_with(
39+
project=constants.PROJECT,
40+
location=constants.LOCATION,
41+
staging_bucket=constants.STAGING_BUCKET,
42+
)
43+
44+
mock_get_job_service_client_v1beta1.assert_called_once_with(
45+
client_options={
46+
"api_endpoint": f"{constants.LOCATION}-aiplatform.googleapis.com"
47+
}
48+
)
49+
50+
mock_get_create_custom_job_request_v1beta1.assert_called_once_with(
51+
parent=f"projects/{constants.PROJECT}/locations/{constants.LOCATION}",
52+
custom_job=aiplatform_v1beta1.CustomJob(
53+
display_name=constants.DISPLAY_NAME,
54+
job_spec=aiplatform_v1beta1.CustomJobSpec(
55+
worker_pool_specs=[
56+
aiplatform_v1beta1.WorkerPoolSpec(
57+
machine_spec=aiplatform_v1beta1.MachineSpec(
58+
machine_type=constants.MACHINE_TYPE,
59+
),
60+
replica_count=constants.REPLICA_COUNT,
61+
container_spec=aiplatform_v1beta1.ContainerSpec(
62+
image_uri=constants.TRAIN_IMAGE,
63+
),
64+
)
65+
],
66+
psc_interface_config=aiplatform_v1beta1.PscInterfaceConfig(
67+
network_attachment=constants.NETWORK_ATTACHMENT_NAME,
68+
),
69+
),
70+
),
71+
)
72+
73+
request = aiplatform_v1beta1.CreateCustomJobRequest(
74+
mock_get_create_custom_job_request_v1beta1.return_value
75+
)
76+
77+
mock_create_custom_job_v1beta1.assert_called_once_with(request=request)

samples/model-builder/test_constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
MACHINE_TYPE = "n1-standard-4"
115115
ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
116116
ACCELERATOR_COUNT = 0
117+
NETWORK_ATTACHMENT_NAME = "network-attachment-name"
117118

118119
# Model constants
119120
MODEL_RESOURCE_NAME = f"{PARENT}/models/1234"

tests/unit/vertexai/test_feature_group.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@
9494
_TEST_FG1_FM_LIST,
9595
)
9696
from test_feature import feature_eq
97-
from test_feature_monitor import feature_monitor_eq
97+
from test_feature_monitor import (
98+
feature_monitor_eq,
99+
)
98100

99101

100102
pytestmark = pytest.mark.usefixtures("google_auth_mock")
@@ -948,7 +950,9 @@ def test_create_feature_monitor(
948950
)
949951

950952

951-
def test_list_feature_monitors(get_fg_mock, list_feature_monitors_mock):
953+
def test_list_feature_monitors(
954+
get_fg_mock, get_feature_monitor_mock, list_feature_monitors_mock
955+
):
952956
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
953957

954958
feature_monitors = FeatureGroup(_TEST_FG1_ID).list_feature_monitors()

tests/unit/vertexai/test_feature_monitor.py

+4
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ def feature_monitor_eq(
120120
assert feature_monitor_to_check.location == location
121121
assert feature_monitor_to_check.description == description
122122
assert feature_monitor_to_check.labels == labels
123+
assert feature_monitor_to_check.schedule_config == schedule_config
124+
assert (
125+
feature_monitor_to_check.feature_selection_configs == feature_selection_configs
126+
)
123127

124128

125129
def feature_monitor_job_eq(

vertexai/resources/preview/feature_store/feature_monitor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ def feature_selection_configs(self) -> List[Tuple[str, float]]:
127127
configs.append(
128128
(
129129
feature_config.feature_id,
130-
feature_config.threshold if feature_config.threshold else 0.3,
130+
feature_config.drift_threshold
131+
if feature_config.drift_threshold
132+
else 0.3,
131133
)
132134
)
133135
return configs

0 commit comments

Comments
 (0)