@@ -89,11 +89,36 @@ def complete_bq_uri_mock():
89
89
90
90
91
91
@pytest .fixture
92
- def get_batch_prediction_job_mock ():
92
+ def get_batch_prediction_job_with_bq_output_mock ():
93
93
with mock .patch .object (
94
94
job_service_client .JobServiceClient , "get_batch_prediction_job"
95
95
) as get_job_mock :
96
- get_job_mock .return_value = _TEST_GAPIC_BATCH_PREDICTION_JOB
96
+ get_job_mock .return_value = gca_batch_prediction_job_compat .BatchPredictionJob (
97
+ name = _TEST_BATCH_PREDICTION_JOB_NAME ,
98
+ display_name = _TEST_DISPLAY_NAME ,
99
+ model = _TEST_GEMINI_MODEL_RESOURCE_NAME ,
100
+ state = _TEST_JOB_STATE_SUCCESS ,
101
+ output_info = gca_batch_prediction_job_compat .BatchPredictionJob .OutputInfo (
102
+ bigquery_output_table = _TEST_BQ_OUTPUT_PREFIX
103
+ ),
104
+ )
105
+ yield get_job_mock
106
+
107
+
108
+ @pytest .fixture
109
+ def get_batch_prediction_job_with_gcs_output_mock ():
110
+ with mock .patch .object (
111
+ job_service_client .JobServiceClient , "get_batch_prediction_job"
112
+ ) as get_job_mock :
113
+ get_job_mock .return_value = gca_batch_prediction_job_compat .BatchPredictionJob (
114
+ name = _TEST_BATCH_PREDICTION_JOB_NAME ,
115
+ display_name = _TEST_DISPLAY_NAME ,
116
+ model = _TEST_GEMINI_MODEL_RESOURCE_NAME ,
117
+ state = _TEST_JOB_STATE_SUCCESS ,
118
+ output_info = gca_batch_prediction_job_compat .BatchPredictionJob .OutputInfo (
119
+ gcs_output_directory = _TEST_GCS_OUTPUT_PREFIX
120
+ ),
121
+ )
97
122
yield get_job_mock
98
123
99
124
@@ -120,6 +145,39 @@ def create_batch_prediction_job_mock():
120
145
yield create_job_mock
121
146
122
147
148
+ @pytest .fixture
149
+ def cancel_batch_prediction_job_mock ():
150
+ with mock .patch .object (
151
+ job_service_client .JobServiceClient , "cancel_batch_prediction_job"
152
+ ) as cancel_job_mock :
153
+ yield cancel_job_mock
154
+
155
+
156
+ @pytest .fixture
157
+ def delete_batch_prediction_job_mock ():
158
+ with mock .patch .object (
159
+ job_service_client .JobServiceClient , "delete_batch_prediction_job"
160
+ ) as delete_job_mock :
161
+ yield delete_job_mock
162
+
163
+
164
+ @pytest .fixture
165
+ def list_batch_prediction_jobs_mock ():
166
+ with mock .patch .object (
167
+ job_service_client .JobServiceClient , "list_batch_prediction_jobs"
168
+ ) as list_jobs_mock :
169
+ list_jobs_mock .return_value = [
170
+ _TEST_GAPIC_BATCH_PREDICTION_JOB ,
171
+ gca_batch_prediction_job_compat .BatchPredictionJob (
172
+ name = _TEST_BATCH_PREDICTION_JOB_NAME ,
173
+ display_name = _TEST_DISPLAY_NAME ,
174
+ model = _TEST_PALM_MODEL_RESOURCE_NAME ,
175
+ state = _TEST_JOB_STATE_SUCCESS ,
176
+ ),
177
+ ]
178
+ yield list_jobs_mock
179
+
180
+
123
181
@pytest .mark .usefixtures (
124
182
"google_auth_mock" , "generate_display_name_mock" , "complete_bq_uri_mock"
125
183
)
@@ -138,10 +196,12 @@ def setup_method(self):
138
196
def teardown_method (self ):
139
197
aiplatform_initializer .global_pool .shutdown (wait = True )
140
198
141
- def test_init_batch_prediction_job (self , get_batch_prediction_job_mock ):
199
+ def test_init_batch_prediction_job (
200
+ self , get_batch_prediction_job_with_gcs_output_mock
201
+ ):
142
202
batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
143
203
144
- get_batch_prediction_job_mock .assert_called_once_with (
204
+ get_batch_prediction_job_with_gcs_output_mock .assert_called_once_with (
145
205
name = _TEST_BATCH_PREDICTION_JOB_NAME , retry = aiplatform_base ._DEFAULT_RETRY
146
206
)
147
207
@@ -157,6 +217,7 @@ def test_init_batch_prediction_job_invalid_model(self):
157
217
):
158
218
batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
159
219
220
+ @pytest .mark .usefixtures ("get_batch_prediction_job_with_gcs_output_mock" )
160
221
def test_submit_batch_prediction_job_with_gcs_input (
161
222
self , create_batch_prediction_job_mock
162
223
):
@@ -167,6 +228,15 @@ def test_submit_batch_prediction_job_with_gcs_input(
167
228
)
168
229
169
230
assert job .gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
231
+ assert job .state == _TEST_JOB_STATE_RUNNING
232
+ assert not job .has_ended
233
+ assert not job .has_succeeded
234
+
235
+ job .refresh ()
236
+ assert job .state == _TEST_JOB_STATE_SUCCESS
237
+ assert job .has_ended
238
+ assert job .has_succeeded
239
+ assert job .output_location == _TEST_GCS_OUTPUT_PREFIX
170
240
171
241
expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat .BatchPredictionJob (
172
242
display_name = _TEST_DISPLAY_NAME ,
@@ -188,6 +258,7 @@ def test_submit_batch_prediction_job_with_gcs_input(
188
258
timeout = None ,
189
259
)
190
260
261
+ @pytest .mark .usefixtures ("get_batch_prediction_job_with_bq_output_mock" )
191
262
def test_submit_batch_prediction_job_with_bq_input (
192
263
self , create_batch_prediction_job_mock
193
264
):
@@ -198,6 +269,15 @@ def test_submit_batch_prediction_job_with_bq_input(
198
269
)
199
270
200
271
assert job .gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
272
+ assert job .state == _TEST_JOB_STATE_RUNNING
273
+ assert not job .has_ended
274
+ assert not job .has_succeeded
275
+
276
+ job .refresh ()
277
+ assert job .state == _TEST_JOB_STATE_SUCCESS
278
+ assert job .has_ended
279
+ assert job .has_succeeded
280
+ assert job .output_location == _TEST_BQ_OUTPUT_PREFIX
201
281
202
282
expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat .BatchPredictionJob (
203
283
display_name = _TEST_DISPLAY_NAME ,
@@ -349,3 +429,35 @@ def test_submit_batch_prediction_job_without_output_uri_prefix_and_bucket(self):
349
429
source_model = _TEST_GEMINI_MODEL_NAME ,
350
430
input_dataset = _TEST_GCS_INPUT_URI ,
351
431
)
432
+
433
+ @pytest .mark .usefixtures ("create_batch_prediction_job_mock" )
434
+ def test_cancel_batch_prediction_job (self , cancel_batch_prediction_job_mock ):
435
+ job = batch_prediction .BatchPredictionJob .submit (
436
+ source_model = _TEST_GEMINI_MODEL_NAME ,
437
+ input_dataset = _TEST_GCS_INPUT_URI ,
438
+ output_uri_prefix = _TEST_GCS_OUTPUT_PREFIX ,
439
+ )
440
+ job .cancel ()
441
+
442
+ cancel_batch_prediction_job_mock .assert_called_once_with (
443
+ name = _TEST_BATCH_PREDICTION_JOB_NAME ,
444
+ )
445
+
446
+ @pytest .mark .usefixtures ("get_batch_prediction_job_with_gcs_output_mock" )
447
+ def test_delete_batch_prediction_job (self , delete_batch_prediction_job_mock ):
448
+ job = batch_prediction .BatchPredictionJob (_TEST_BATCH_PREDICTION_JOB_ID )
449
+ job .delete ()
450
+
451
+ delete_batch_prediction_job_mock .assert_called_once_with (
452
+ name = _TEST_BATCH_PREDICTION_JOB_NAME ,
453
+ )
454
+
455
+ def tes_list_batch_prediction_jobs (self , list_batch_prediction_jobs_mock ):
456
+ jobs = batch_prediction .BatchPredictionJob .list ()
457
+
458
+ assert len (jobs ) == 1
459
+ assert jobs [0 ].gca_resource == _TEST_GAPIC_BATCH_PREDICTION_JOB
460
+
461
+ list_batch_prediction_jobs_mock .assert_called_once_with (
462
+ request = {"parent" : _TEST_PARENT }
463
+ )
0 commit comments