Skip to content

Commit 7ff8071

Browse files
jaycee-li authored and copybara-github committed
feat: GenAI - Add cancel, delete, list methods in BatchPredictionJob
PiperOrigin-RevId: 633702194
1 parent 4d091c6 commit 7ff8071

File tree

2 files changed

+197
-10
lines changed

2 files changed

+197
-10
lines changed

tests/unit/vertexai/test_batch_prediction.py

+116-4
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,36 @@ def complete_bq_uri_mock():
8989

9090

9191
@pytest.fixture
92-
def get_batch_prediction_job_mock():
92+
def get_batch_prediction_job_with_bq_output_mock():
9393
with mock.patch.object(
9494
job_service_client.JobServiceClient, "get_batch_prediction_job"
9595
) as get_job_mock:
96-
get_job_mock.return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB
96+
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob(
97+
name=_TEST_BATCH_PREDICTION_JOB_NAME,
98+
display_name=_TEST_DISPLAY_NAME,
99+
model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
100+
state=_TEST_JOB_STATE_SUCCESS,
101+
output_info=gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
102+
bigquery_output_table=_TEST_BQ_OUTPUT_PREFIX
103+
),
104+
)
105+
yield get_job_mock
106+
107+
108+
@pytest.fixture
def get_batch_prediction_job_with_gcs_output_mock():
    """Patch ``JobServiceClient.get_batch_prediction_job`` with a GCS-output job.

    The mocked RPC returns a job in ``_TEST_JOB_STATE_SUCCESS`` whose
    ``output_info.gcs_output_directory`` is ``_TEST_GCS_OUTPUT_PREFIX``.

    Yields:
        The mock that replaces ``get_batch_prediction_job``.
    """
    # Build the canned response first; it does not depend on the patch.
    gcs_output_info = gca_batch_prediction_job_compat.BatchPredictionJob.OutputInfo(
        gcs_output_directory=_TEST_GCS_OUTPUT_PREFIX
    )
    succeeded_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_GEMINI_MODEL_RESOURCE_NAME,
        state=_TEST_JOB_STATE_SUCCESS,
        output_info=gcs_output_info,
    )
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "get_batch_prediction_job"
    )
    with patcher as mocked_get:
        mocked_get.return_value = succeeded_job
        yield mocked_get
98123

99124

@@ -120,6 +145,39 @@ def create_batch_prediction_job_mock():
120145
yield create_job_mock
121146

122147

148+
@pytest.fixture
def cancel_batch_prediction_job_mock():
    """Patch ``JobServiceClient.cancel_batch_prediction_job`` for one test.

    Yields:
        The mock that replaces the cancel RPC, so tests can assert on calls.
    """
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "cancel_batch_prediction_job"
    )
    with patcher as mocked_cancel:
        yield mocked_cancel
154+
155+
156+
@pytest.fixture
def delete_batch_prediction_job_mock():
    """Patch ``JobServiceClient.delete_batch_prediction_job`` for one test.

    Yields:
        The mock that replaces the delete RPC, so tests can assert on calls.
    """
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "delete_batch_prediction_job"
    )
    with patcher as mocked_delete:
        yield mocked_delete
162+
163+
164+
@pytest.fixture
def list_batch_prediction_jobs_mock():
    """Patch ``JobServiceClient.list_batch_prediction_jobs`` with two jobs.

    The mocked RPC returns ``_TEST_GAPIC_BATCH_PREDICTION_JOB`` followed by a
    second job whose model is ``_TEST_PALM_MODEL_RESOURCE_NAME``.

    Yields:
        The mock that replaces the list RPC.
    """
    # NOTE(review): the PaLM-model entry is presumably here so that the
    # GenAI-model filter in ``BatchPredictionJob.list`` has something to drop.
    palm_model_job = gca_batch_prediction_job_compat.BatchPredictionJob(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
        display_name=_TEST_DISPLAY_NAME,
        model=_TEST_PALM_MODEL_RESOURCE_NAME,
        state=_TEST_JOB_STATE_SUCCESS,
    )
    patcher = mock.patch.object(
        job_service_client.JobServiceClient, "list_batch_prediction_jobs"
    )
    with patcher as mocked_list:
        mocked_list.return_value = [
            _TEST_GAPIC_BATCH_PREDICTION_JOB,
            palm_model_job,
        ]
        yield mocked_list
179+
180+
123181
@pytest.mark.usefixtures(
124182
"google_auth_mock", "generate_display_name_mock", "complete_bq_uri_mock"
125183
)
@@ -138,10 +196,12 @@ def setup_method(self):
138196
def teardown_method(self):
139197
aiplatform_initializer.global_pool.shutdown(wait=True)
140198

141-
def test_init_batch_prediction_job(self, get_batch_prediction_job_mock):
199+
def test_init_batch_prediction_job(
200+
self, get_batch_prediction_job_with_gcs_output_mock
201+
):
142202
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
143203

144-
get_batch_prediction_job_mock.assert_called_once_with(
204+
get_batch_prediction_job_with_gcs_output_mock.assert_called_once_with(
145205
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY
146206
)
147207

@@ -157,6 +217,7 @@ def test_init_batch_prediction_job_invalid_model(self):
157217
):
158218
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID)
159219

220+
@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
160221
def test_submit_batch_prediction_job_with_gcs_input(
161222
self, create_batch_prediction_job_mock
162223
):
@@ -167,6 +228,15 @@ def test_submit_batch_prediction_job_with_gcs_input(
167228
)
168229

169230
assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
231+
assert job.state == _TEST_JOB_STATE_RUNNING
232+
assert not job.has_ended
233+
assert not job.has_succeeded
234+
235+
job.refresh()
236+
assert job.state == _TEST_JOB_STATE_SUCCESS
237+
assert job.has_ended
238+
assert job.has_succeeded
239+
assert job.output_location == _TEST_GCS_OUTPUT_PREFIX
170240

171241
expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
172242
display_name=_TEST_DISPLAY_NAME,
@@ -188,6 +258,7 @@ def test_submit_batch_prediction_job_with_gcs_input(
188258
timeout=None,
189259
)
190260

261+
@pytest.mark.usefixtures("get_batch_prediction_job_with_bq_output_mock")
191262
def test_submit_batch_prediction_job_with_bq_input(
192263
self, create_batch_prediction_job_mock
193264
):
@@ -198,6 +269,15 @@ def test_submit_batch_prediction_job_with_bq_input(
198269
)
199270

200271
assert job.gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
272+
assert job.state == _TEST_JOB_STATE_RUNNING
273+
assert not job.has_ended
274+
assert not job.has_succeeded
275+
276+
job.refresh()
277+
assert job.state == _TEST_JOB_STATE_SUCCESS
278+
assert job.has_ended
279+
assert job.has_succeeded
280+
assert job.output_location == _TEST_BQ_OUTPUT_PREFIX
201281

202282
expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
203283
display_name=_TEST_DISPLAY_NAME,
@@ -349,3 +429,35 @@ def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self):
349429
source_model=_TEST_GEMINI_MODEL_NAME,
350430
input_dataset=_TEST_GCS_INPUT_URI,
351431
)
432+
433+
@pytest.mark.usefixtures("create_batch_prediction_job_mock")
def test_cancel_batch_prediction_job(self, cancel_batch_prediction_job_mock):
    """Cancelling a submitted job calls the cancel RPC once with its name."""
    submit_kwargs = dict(
        source_model=_TEST_GEMINI_MODEL_NAME,
        input_dataset=_TEST_GCS_INPUT_URI,
        output_uri_prefix=_TEST_GCS_OUTPUT_PREFIX,
    )
    submitted_job = batch_prediction.BatchPredictionJob.submit(**submit_kwargs)

    submitted_job.cancel()

    # Exactly one RPC, addressed by the job's resource name.
    cancel_batch_prediction_job_mock.assert_called_once_with(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
    )
445+
446+
@pytest.mark.usefixtures("get_batch_prediction_job_with_gcs_output_mock")
def test_delete_batch_prediction_job(self, delete_batch_prediction_job_mock):
    """Deleting a job looked up by ID calls the delete RPC once with its name."""
    existing_job = batch_prediction.BatchPredictionJob(
        _TEST_BATCH_PREDICTION_JOB_ID
    )

    existing_job.delete()

    # Exactly one RPC, addressed by the job's resource name.
    delete_batch_prediction_job_mock.assert_called_once_with(
        name=_TEST_BATCH_PREDICTION_JOB_NAME,
    )
454+
455+
def test_list_batch_prediction_jobs(self, list_batch_prediction_jobs_mock):
    """Listing jobs returns the expected job and issues a single list RPC.

    The method was previously named ``tes_list_batch_prediction_jobs``; without
    the ``test_`` prefix pytest's default discovery never collected it, so the
    assertions below were never executed.
    """
    jobs = batch_prediction.BatchPredictionJob.list()

    # The list fixture returns two jobs, but only one is expected back from
    # ``list()``: the one equal to _TEST_GAPIC_BATCH_PREDICTION_JOB.
    assert len(jobs) == 1
    assert jobs[0].gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB

    list_batch_prediction_jobs_mock.assert_called_once_with(
        request={"parent": _TEST_PARENT}
    )

vertexai/batch_prediction/_batch_prediction.py

+81-6
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
from google.cloud.aiplatform import initializer as aiplatform_initializer
2424
from google.cloud.aiplatform import jobs
2525
from google.cloud.aiplatform import utils as aiplatform_utils
26+
from google.cloud.aiplatform_v1 import types as gca_types
2627
from vertexai import generative_models
2728

29+
from google.rpc import status_pb2
30+
2831

2932
_LOGGER = aiplatform_base.Logger(__name__)
3033

@@ -37,7 +40,6 @@ class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus):
3740
_resource_noun = "batchPredictionJobs"
3841
_getter_method = "get_batch_prediction_job"
3942
_list_method = "list_batch_prediction_jobs"
40-
_cancel_method = "cancel_batch_prediction_job"
4143
_delete_method = "delete_batch_prediction_job"
4244
_job_type = "batch-predictions"
4345
_parse_resource_name_method = "parse_batch_prediction_job_path"
@@ -63,13 +65,46 @@ def __init__(self, batch_prediction_job_name: str):
6365
resource_name=batch_prediction_job_name
6466
)
6567
# TODO(b/338452508) Support tuned GenAI models.
66-
if not re.search(_GEMINI_MODEL_PATTERN, self._gca_resource.model):
68+
if not re.search(_GEMINI_MODEL_PATTERN, self.model_name):
6769
raise ValueError(
6870
f"BatchPredictionJob '{batch_prediction_job_name}' "
69-
f"runs with the model '{self._gca_resource.model}', "
71+
f"runs with the model '{self.model_name}', "
7072
"which is not a GenAI model."
7173
)
7274

75+
@property
def model_name(self) -> str:
    """Model recorded on the underlying batch prediction job resource."""
    resource = self._gca_resource
    return resource.model
79+
80+
@property
def state(self) -> gca_types.JobState:
    """Job state as stored on the locally cached job resource."""
    resource = self._gca_resource
    return resource.state
84+
85+
@property
def has_ended(self) -> bool:
    """Whether the job's state is one of ``jobs._JOB_COMPLETE_STATES``."""
    current_state = self.state
    return current_state in jobs._JOB_COMPLETE_STATES
89+
90+
@property
def has_succeeded(self) -> bool:
    """Whether the job's state equals ``JOB_STATE_SUCCEEDED``."""
    current_state = self.state
    return current_state == gca_types.JobState.JOB_STATE_SUCCEEDED
94+
95+
@property
def error(self) -> Optional[status_pb2.Status]:
    """Error details stored on the underlying job resource."""
    resource = self._gca_resource
    return resource.error
99+
100+
@property
def output_location(self) -> str:
    """Where the job wrote its output.

    Returns:
        The GCS output directory when it is set (truthy); otherwise the
        BigQuery output table from the job's ``output_info``.
    """
    gcs_directory = self._gca_resource.output_info.gcs_output_directory
    if gcs_directory:
        return gcs_directory
    # Fall back to BigQuery only when no GCS directory is recorded.
    return self._gca_resource.output_info.bigquery_output_table
107+
73108
@classmethod
74109
def submit(
75110
cls,
@@ -178,14 +213,54 @@ def submit(
178213
_LOGGER.log_create_complete(
179214
cls, job._gca_resource, "job", module_name="batch_prediction"
180215
)
181-
_LOGGER.info(
182-
"View Batch Prediction Job:\n%s" % aiplatform_job._dashboard_uri()
183-
)
216+
_LOGGER.info("View Batch Prediction Job:\n%s" % job._dashboard_uri())
184217

185218
return job
186219
finally:
187220
logging.getLogger("google.cloud.aiplatform.jobs").disabled = False
188221

222+
def refresh(self) -> "BatchPredictionJob":
    """Re-sync the cached job resource and return this same object.

    Returns:
        ``self``, so the call can be chained (e.g. ``job.refresh().state``).
    """
    # Delegates the actual fetch to the inherited sync helper.
    self._sync_gca_resource()
    return self
226+
227+
def cancel(self):
    """Request cancellation of this BatchPredictionJob.

    The request alone does not guarantee the job was cancelled. Call
    `job.refresh()` and then inspect `job.state` to check the outcome.
    """
    _LOGGER.log_action_start_against_resource("Cancelling", "run", self)
    cancel_rpc = self.api_client.cancel_batch_prediction_job
    cancel_rpc(name=self.resource_name)
235+
236+
def delete(self):
    """Delete this BatchPredictionJob resource.

    WARNING: the deletion is permanent.
    """
    # The inherited ``_delete`` helper performs the actual call.
    self._delete()
242+
243+
@classmethod
def list(cls, filter=None) -> List["BatchPredictionJob"]:
    """List BatchPredictionJob instances that run with GenAI models.

    Args:
        filter: Optional filter passed through unchanged to ``cls._list``.

    Returns:
        Whatever ``cls._list`` returns after applying the model-name
        predicate below as ``cls_filter``.
    """

    def _runs_genai_model(gca_resource):
        # Truthy (a match object) only when the job's model matches the
        # Gemini model pattern.
        return re.search(_GEMINI_MODEL_PATTERN, gca_resource.model)

    return cls._list(cls_filter=_runs_genai_model, filter=filter)
252+
253+
def _dashboard_uri(self) -> Optional[str]:
    """Build the Google Cloud console URL for viewing this job.

    Returns:
        A console URL assembled from the location, job type, job ID and
        project parsed out of ``self.resource_name``.
    """
    parsed = self._parse_resource_name(self.resource_name)
    region = parsed.pop("location")
    project_id = parsed.pop("project")
    # With location and project removed, the first remaining value is the
    # job identifier.
    job_id = [*parsed.values()][0]
    console_prefix = "https://console.cloud.google.com/ai/platform/locations/"
    return f"{console_prefix}{region}/{self._job_type}/{job_id}?project={project_id}"
263+
189264
@classmethod
190265
def _reconcile_model_name(cls, model_name: str) -> str:
191266
"""Reconciles model name to a publisher model resource name."""

0 commit comments

Comments
 (0)