Skip to content

Commit 0ae969d

Browse files
yinghsienwu authored and copybara-github committed
feat: Enable vertexai preview persistent cluster executor
PiperOrigin-RevId: 569371823
1 parent e1cedba commit 0ae969d

File tree

11 files changed

+444
-102
lines changed

11 files changed

+444
-102
lines changed

google/cloud/aiplatform/preview/jobs.py

+20
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,8 @@
2828
from google.cloud.aiplatform import utils
2929
from google.cloud.aiplatform.compat.types import (
3030
custom_job_v1beta1 as gca_custom_job_compat,
31+
job_state as gca_job_state,
32+
job_state_v1beta1 as gca_job_state_v1beta1,
3133
)
3234
from google.cloud.aiplatform.compat.types import (
3335
execution_v1beta1 as gcs_execution_compat,
@@ -42,6 +44,24 @@
4244

4345
_LOGGER = base.Logger(__name__)
4446
_DEFAULT_RETRY = retry.Retry()
47+
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
48+
_JOB_COMPLETE_STATES = (
49+
gca_job_state.JobState.JOB_STATE_SUCCEEDED,
50+
gca_job_state.JobState.JOB_STATE_FAILED,
51+
gca_job_state.JobState.JOB_STATE_CANCELLED,
52+
gca_job_state.JobState.JOB_STATE_PAUSED,
53+
gca_job_state_v1beta1.JobState.JOB_STATE_SUCCEEDED,
54+
gca_job_state_v1beta1.JobState.JOB_STATE_FAILED,
55+
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
56+
gca_job_state_v1beta1.JobState.JOB_STATE_PAUSED,
57+
)
58+
59+
_JOB_ERROR_STATES = (
60+
gca_job_state.JobState.JOB_STATE_FAILED,
61+
gca_job_state.JobState.JOB_STATE_CANCELLED,
62+
gca_job_state_v1beta1.JobState.JOB_STATE_FAILED,
63+
gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
64+
)
4565

4666

4767
class CustomJob(jobs.CustomJob):
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,94 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from typing import NamedTuple, Optional, Dict, Union
18+
19+
from google.cloud.aiplatform import utils
20+
from google.cloud.aiplatform.compat.types import (
21+
accelerator_type_v1beta1 as gca_accelerator_type_compat,
22+
)
23+
24+
25+
class _ResourcePool(NamedTuple):
26+
"""Specification container for Worker Pool specs used for distributed training.
27+
28+
Usage:
29+
30+
resource_pool = _ResourcePool(
31+
replica_count=1,
32+
machine_type='n1-standard-4',
33+
accelerator_count=1,
34+
accelerator_type='NVIDIA_TESLA_K80',
35+
boot_disk_type='pd-ssd',
36+
boot_disk_size_gb=100,
37+
)
38+
39+
Note that container and python package specs are not stored with this spec.
40+
"""
41+
42+
replica_count: int = 1
43+
machine_type: str = "n1-standard-4"
44+
accelerator_count: int = 0
45+
accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED"
46+
boot_disk_type: str = "pd-ssd"
47+
boot_disk_size_gb: int = 100
48+
49+
def _get_accelerator_type(self) -> Optional[str]:
50+
"""Validates accelerator_type and returns the name of the accelerator.
51+
52+
Returns:
53+
None if no accelerator is specified (ACCELERATOR_TYPE_UNSPECIFIED); otherwise the validated accelerator type name.
54+
55+
Raises:
56+
ValueError if accelerator type is invalid.
57+
"""
58+
59+
# Raises ValueError if invalid accelerator_type
60+
utils.validate_accelerator_type(self.accelerator_type)
61+
62+
accelerator_enum = getattr(
63+
gca_accelerator_type_compat.AcceleratorType, self.accelerator_type
64+
)
65+
66+
if (
67+
accelerator_enum
68+
!= gca_accelerator_type_compat.AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED
69+
):
70+
return self.accelerator_type
71+
72+
@property
73+
def spec_dict(self) -> Dict[str, Union[int, str, Dict[str, Union[int, str]]]]:
74+
"""Return specification as a Dict."""
75+
spec = {
76+
"machine_spec": {"machine_type": self.machine_type},
77+
"replica_count": self.replica_count,
78+
"disk_spec": {
79+
"boot_disk_type": self.boot_disk_type,
80+
"boot_disk_size_gb": self.boot_disk_size_gb,
81+
},
82+
}
83+
84+
accelerator_type = self._get_accelerator_type()
85+
if accelerator_type and self.accelerator_count:
86+
spec["machine_spec"]["accelerator_type"] = accelerator_type
87+
spec["machine_spec"]["accelerator_count"] = self.accelerator_count
88+
89+
return spec
90+
91+
@property
92+
def is_empty(self) -> bool:
93+
"""Returns True if replica_count <= 0, False otherwise."""
94+
return self.replica_count <= 0

tests/unit/vertexai/conftest.py

+18-3
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,13 @@
3838
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
3939
PersistentResourceServiceClient,
4040
)
41-
import constants as test_constants
4241
from pyfakefs import fake_filesystem_unittest
4342
import pytest
4443
import tensorflow.saved_model as tf_saved_model
44+
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
45+
PersistentResource,
46+
ResourcePool,
47+
)
4548

4649

4750
_TEST_PROJECT = "test-project"
@@ -83,6 +86,18 @@
8386
labels={"trained_by_vertex_ai": "true"},
8487
)
8588

89+
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource()
90+
resource_pool = ResourcePool()
91+
resource_pool.machine_spec.machine_type = "n1-standard-4"
92+
resource_pool.replica_count = 1
93+
resource_pool.disk_spec.boot_disk_type = "pd-ssd"
94+
resource_pool.disk_spec.boot_disk_size_gb = 100
95+
_TEST_REQUEST_RUNNING_DEFAULT.resource_pools = [resource_pool]
96+
97+
98+
_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource()
99+
_TEST_PERSISTENT_RESOURCE_RUNNING.state = "RUNNING"
100+
86101

87102
@pytest.fixture(scope="module")
88103
def google_auth_mock():
@@ -264,7 +279,7 @@ def persistent_resource_running_mock():
264279
"get_persistent_resource",
265280
) as persistent_resource_running_mock:
266281
persistent_resource_running_mock.return_value = (
267-
test_constants._TEST_PERSISTENT_RESOURCE_RUNNING
282+
_TEST_PERSISTENT_RESOURCE_RUNNING
268283
)
269284
yield persistent_resource_running_mock
270285

@@ -287,7 +302,7 @@ def create_persistent_resource_default_mock():
287302
) as create_persistent_resource_default_mock:
288303
create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation)
289304
create_persistent_resource_lro_mock.result.return_value = (
290-
test_constants._TEST_REQUEST_RUNNING_DEFAULT
305+
_TEST_REQUEST_RUNNING_DEFAULT
291306
)
292307
create_persistent_resource_default_mock.return_value = (
293308
create_persistent_resource_lro_mock

tests/unit/vertexai/constants.py

-45
This file was deleted.

tests/unit/vertexai/test_persistent_resource_util.py

+67-34
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,15 @@
1919
from google.api_core import operation as ga_operation
2020
from google.cloud import aiplatform
2121
import vertexai
22+
from vertexai.preview.developer import remote_specs
2223
from google.cloud.aiplatform_v1beta1.services.persistent_resource_service import (
2324
PersistentResourceServiceClient,
2425
)
2526
from google.cloud.aiplatform_v1beta1.types import persistent_resource_service
27+
from google.cloud.aiplatform_v1beta1.types.machine_resources import DiskSpec
28+
from google.cloud.aiplatform_v1beta1.types.machine_resources import (
29+
MachineSpec,
30+
)
2631
from google.cloud.aiplatform_v1beta1.types.persistent_resource import (
2732
PersistentResource,
2833
)
@@ -48,55 +53,64 @@
4853
_TEST_PERSISTENT_RESOURCE_ERROR = PersistentResource()
4954
_TEST_PERSISTENT_RESOURCE_ERROR.state = "ERROR"
5055

51-
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource()
52-
resource_pool = ResourcePool()
53-
resource_pool.machine_spec.machine_type = "n1-standard-4"
54-
resource_pool.replica_count = 1
55-
resource_pool.disk_spec.boot_disk_type = "pd-ssd"
56-
resource_pool.disk_spec.boot_disk_size_gb = 100
57-
_TEST_REQUEST_RUNNING_DEFAULT.resource_pools = [resource_pool]
58-
56+
resource_pool_0 = ResourcePool(
57+
machine_spec=MachineSpec(machine_type="n1-standard-4"),
58+
disk_spec=DiskSpec(
59+
boot_disk_type="pd-ssd",
60+
boot_disk_size_gb=100,
61+
),
62+
replica_count=1,
63+
)
64+
resource_pool_1 = ResourcePool(
65+
machine_spec=MachineSpec(
66+
machine_type="n1-standard-8",
67+
accelerator_type="NVIDIA_TESLA_T4",
68+
accelerator_count=1,
69+
),
70+
disk_spec=DiskSpec(
71+
boot_disk_type="pd-ssd",
72+
boot_disk_size_gb=100,
73+
),
74+
replica_count=2,
75+
)
76+
_TEST_REQUEST_RUNNING_DEFAULT = PersistentResource(
77+
resource_pools=[resource_pool_0],
78+
)
79+
_TEST_REQUEST_RUNNING_CUSTOM = PersistentResource(
80+
resource_pools=[resource_pool_0, resource_pool_1],
81+
)
5982

6083
_TEST_PERSISTENT_RESOURCE_RUNNING = PersistentResource()
6184
_TEST_PERSISTENT_RESOURCE_RUNNING.state = "RUNNING"
6285

63-
64-
@pytest.fixture
65-
def persistent_resource_running_mock():
66-
with mock.patch.object(
67-
PersistentResourceServiceClient,
68-
"get_persistent_resource",
69-
) as persistent_resource_running_mock:
70-
persistent_resource_running_mock.return_value = (
71-
_TEST_PERSISTENT_RESOURCE_RUNNING
72-
)
73-
yield persistent_resource_running_mock
74-
75-
76-
@pytest.fixture
77-
def persistent_resource_exception_mock():
78-
with mock.patch.object(
79-
PersistentResourceServiceClient,
80-
"get_persistent_resource",
81-
) as persistent_resource_exception_mock:
82-
persistent_resource_exception_mock.side_effect = Exception
83-
yield persistent_resource_exception_mock
86+
# user-configured remote_specs.ResourcePool
87+
remote_specs_resource_pool_0 = remote_specs.ResourcePool(replica_count=1)
88+
remote_specs_resource_pool_1 = remote_specs.ResourcePool(
89+
machine_type="n1-standard-8",
90+
replica_count=2,
91+
accelerator_type="NVIDIA_TESLA_T4",
92+
accelerator_count=1,
93+
)
94+
_TEST_CUSTOM_RESOURCE_POOLS = [
95+
remote_specs_resource_pool_0,
96+
remote_specs_resource_pool_1,
97+
]
8498

8599

86100
@pytest.fixture
87-
def create_persistent_resource_default_mock():
101+
def create_persistent_resource_custom_mock():
88102
with mock.patch.object(
89103
PersistentResourceServiceClient,
90104
"create_persistent_resource",
91-
) as create_persistent_resource_default_mock:
105+
) as create_persistent_resource_custom_mock:
92106
create_persistent_resource_lro_mock = mock.Mock(ga_operation.Operation)
93107
create_persistent_resource_lro_mock.result.return_value = (
94-
_TEST_REQUEST_RUNNING_DEFAULT
108+
_TEST_REQUEST_RUNNING_CUSTOM
95109
)
96-
create_persistent_resource_default_mock.return_value = (
110+
create_persistent_resource_custom_mock.return_value = (
97111
create_persistent_resource_lro_mock
98112
)
99-
yield create_persistent_resource_default_mock
113+
yield create_persistent_resource_custom_mock
100114

101115

102116
@pytest.fixture
@@ -180,6 +194,25 @@ def test_create_persistent_resource_default_success(
180194
request,
181195
)
182196

197+
@pytest.mark.usefixtures("persistent_resource_running_mock")
198+
def test_create_persistent_resource_custom_success(
199+
self, create_persistent_resource_custom_mock
200+
):
201+
persistent_resource_util.create_persistent_resource(
202+
cluster_resource_name=_TEST_CLUSTER_RESOURCE_NAME,
203+
resource_pools=_TEST_CUSTOM_RESOURCE_POOLS,
204+
)
205+
206+
request = persistent_resource_service.CreatePersistentResourceRequest(
207+
parent=_TEST_PARENT,
208+
persistent_resource=_TEST_REQUEST_RUNNING_CUSTOM,
209+
persistent_resource_id=_TEST_CLUSTER_NAME,
210+
)
211+
212+
create_persistent_resource_custom_mock.assert_called_with(
213+
request,
214+
)
215+
183216
@pytest.mark.usefixtures("create_persistent_resource_exception_mock")
184217
def test_create_ray_cluster_state_error(self):
185218
with pytest.raises(ValueError) as e:

0 commit comments

Comments
 (0)