@@ -2215,6 +2215,61 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
2215
2215
timeout=None,
2216
2216
)
2217
2217
2218
+ @pytest.mark.parametrize("sync", [True, False])
2219
+ @pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
2220
+ def test_init_aiplatform_with_service_account_and_batch_predict_gcs_source_and_dest(
2221
+ self, create_batch_prediction_job_mock, sync
2222
+ ):
2223
+ aiplatform.init(
2224
+ project=_TEST_PROJECT,
2225
+ location=_TEST_LOCATION,
2226
+ encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
2227
+ service_account=_TEST_SERVICE_ACCOUNT,
2228
+ )
2229
+ test_model = models.Model(_TEST_ID)
2230
+
2231
+ # Make SDK batch_predict method call
2232
+ batch_prediction_job = test_model.batch_predict(
2233
+ job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
2234
+ gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
2235
+ gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
2236
+ sync=sync,
2237
+ create_request_timeout=None,
2238
+ )
2239
+
2240
+ if not sync:
2241
+ batch_prediction_job.wait()
2242
+
2243
+ # Construct expected request
2244
+ expected_gapic_batch_prediction_job = (
2245
+ gca_batch_prediction_job.BatchPredictionJob(
2246
+ display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
2247
+ model=model_service_client.ModelServiceClient.model_path(
2248
+ _TEST_PROJECT, _TEST_LOCATION, _TEST_ID
2249
+ ),
2250
+ input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
2251
+ instances_format="jsonl",
2252
+ gcs_source=gca_io.GcsSource(
2253
+ uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
2254
+ ),
2255
+ ),
2256
+ output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
2257
+ gcs_destination=gca_io.GcsDestination(
2258
+ output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
2259
+ ),
2260
+ predictions_format="jsonl",
2261
+ ),
2262
+ encryption_spec=_TEST_ENCRYPTION_SPEC,
2263
+ service_account=_TEST_SERVICE_ACCOUNT,
2264
+ )
2265
+ )
2266
+
2267
+ create_batch_prediction_job_mock.assert_called_once_with(
2268
+ parent=_TEST_PARENT,
2269
+ batch_prediction_job=expected_gapic_batch_prediction_job,
2270
+ timeout=None,
2271
+ )
2272
+
2218
2273
@pytest.mark.parametrize("sync", [True, False])
2219
2274
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
2220
2275
def test_batch_predict_gcs_source_and_dest(
0 commit comments