Skip to content

Commit deba06b

Browse files
feat: add Service Account support to BatchPredictionJob
COPYBARA_INTEGRATE_REVIEW=#1872 from cymarechal-devoteam:feature/batch-prediction/service-account 4f015f3 PiperOrigin-RevId: 501301075
1 parent 369a0cc commit deba06b

File tree

6 files changed

+52
-6
lines changed

6 files changed

+52
-6
lines changed

README.rst

+4-3
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,11 @@ To create a batch prediction job:
359359
360360
batch_prediction_job = model.batch_predict(
361361
job_display_name='my-batch-prediction-job',
362-
instances_format='csv'
362+
instances_format='csv',
363363
machine_type='n1-standard-4',
364-
gcs_source=['gs://path/to/my/file.csv']
365-
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
364+
gcs_source=['gs://path/to/my/file.csv'],
365+
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/',
366+
service_account='[email protected]'
366367
)
367368
368369
You can also create a batch prediction job asynchronously by including the `sync=False` argument:

docs/README.rst

+4-3
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,11 @@ To create a batch prediction job:
284284
285285
batch_prediction_job = model.batch_predict(
286286
job_display_name='my-batch-prediction-job',
287-
instances_format='csv'
287+
instances_format='csv',
288288
machine_type='n1-standard-4',
289-
gcs_source=['gs://path/to/my/file.csv']
290-
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
289+
gcs_source=['gs://path/to/my/file.csv'],
290+
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/',
291+
service_account='[email protected]'
291292
)
292293
293294
You can also create a batch prediction job asynchronously by including the `sync=False` argument:

google/cloud/aiplatform/jobs.py

+7
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ def create(
403403
"aiplatform.model_monitoring.AlertConfig"
404404
] = None,
405405
analysis_instance_schema_uri: Optional[str] = None,
406+
service_account: Optional[str] = None,
406407
) -> "BatchPredictionJob":
407408
"""Create a batch prediction job.
408409
@@ -586,6 +587,9 @@ def create(
586587
and TFDV instance, this field can be used to override the schema.
587588
For models trained with Vertex AI, this field must be set as all the
588589
fields in predict instance formatted as string.
590+
service_account (str):
591+
Optional. Specifies the service account for workload run-as account.
592+
Users submitting jobs must have act-as permission on this run-as account.
589593
Returns:
590594
(jobs.BatchPredictionJob):
591595
Instantiated representation of the created batch prediction job.
@@ -745,6 +749,9 @@ def create(
745749
)
746750
gapic_batch_prediction_job.explanation_spec = explanation_spec
747751

752+
if service_account:
753+
gapic_batch_prediction_job.service_account = service_account
754+
748755
empty_batch_prediction_job = cls._empty_constructor(
749756
project=project,
750757
location=location,

google/cloud/aiplatform/models.py

+5
Original file line numberDiff line numberDiff line change
@@ -3511,6 +3511,7 @@ def batch_predict(
35113511
sync: bool = True,
35123512
create_request_timeout: Optional[float] = None,
35133513
batch_size: Optional[int] = None,
3514+
service_account: Optional[str] = None,
35143515
) -> jobs.BatchPredictionJob:
35153516
"""Creates a batch prediction job using this Model and outputs
35163517
prediction results to the provided destination prefix in the specified
@@ -3673,6 +3674,9 @@ def batch_predict(
36733674
but too high value will result in a whole batch not fitting in a machine's memory,
36743675
and the whole operation will fail.
36753676
The default value is 64.
3677+
service_account (str):
3678+
Optional. Specifies the service account for workload run-as account.
3679+
Users submitting jobs must have act-as permission on this run-as account.
36763680
36773681
Returns:
36783682
job (jobs.BatchPredictionJob):
@@ -3705,6 +3709,7 @@ def batch_predict(
37053709
encryption_spec_key_name=encryption_spec_key_name,
37063710
sync=sync,
37073711
create_request_timeout=create_request_timeout,
3712+
service_account=service_account,
37083713
)
37093714

37103715
@classmethod

tests/unit/aiplatform/test_jobs.py

+23
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
_TEST_BQ_JOB_ID = "123459876"
7777
_TEST_BQ_MAX_RESULTS = 100
7878
_TEST_GCS_BUCKET_NAME = "my-bucket"
79+
_TEST_SERVICE_ACCOUNT = "[email protected]"
80+
7981

8082
_TEST_BQ_PATH = f"bq://{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}"
8183
_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}"
@@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest(
719721
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
720722
sync=sync,
721723
create_request_timeout=None,
724+
service_account=_TEST_SERVICE_ACCOUNT,
722725
)
723726

724727
batch_prediction_job.wait_for_resource_creation()
@@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest(
741744
),
742745
predictions_format="jsonl",
743746
),
747+
service_account=_TEST_SERVICE_ACCOUNT,
744748
)
745749

746750
create_batch_prediction_job_mock.assert_called_once_with(
@@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
766770
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
767771
sync=sync,
768772
create_request_timeout=180.0,
773+
service_account=_TEST_SERVICE_ACCOUNT,
769774
)
770775

771776
batch_prediction_job.wait_for_resource_creation()
@@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
788793
),
789794
predictions_format="jsonl",
790795
),
796+
service_account=_TEST_SERVICE_ACCOUNT,
791797
)
792798

793799
create_batch_prediction_job_mock.assert_called_once_with(
@@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
812818
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
813819
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
814820
sync=sync,
821+
service_account=_TEST_SERVICE_ACCOUNT,
815822
)
816823

817824
batch_prediction_job.wait_for_resource_creation()
@@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
834841
),
835842
predictions_format="jsonl",
836843
),
844+
service_account=_TEST_SERVICE_ACCOUNT,
837845
)
838846

839847
create_batch_prediction_job_mock.assert_called_once_with(
@@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
855863
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
856864
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
857865
sync=False,
866+
service_account=_TEST_SERVICE_ACCOUNT,
858867
)
859868

860869
batch_prediction_job.wait_for_resource_creation()
@@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest(
881890
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
882891
sync=sync,
883892
create_request_timeout=None,
893+
service_account=_TEST_SERVICE_ACCOUNT,
884894
)
885895

886896
batch_prediction_job.wait_for_resource_creation()
@@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest(
908918
),
909919
predictions_format="bigquery",
910920
),
921+
service_account=_TEST_SERVICE_ACCOUNT,
911922
)
912923

913924
create_batch_prediction_job_mock.assert_called_once_with(
@@ -946,6 +957,7 @@ def test_batch_predict_with_all_args(
946957
sync=sync,
947958
create_request_timeout=None,
948959
batch_size=_TEST_BATCH_SIZE,
960+
service_account=_TEST_SERVICE_ACCOUNT,
949961
)
950962

951963
batch_prediction_job.wait_for_resource_creation()
@@ -986,6 +998,7 @@ def test_batch_predict_with_all_args(
986998
parameters=_TEST_EXPLANATION_PARAMETERS,
987999
),
9881000
labels=_TEST_LABEL,
1001+
service_account=_TEST_SERVICE_ACCOUNT,
9891002
)
9901003

9911004
create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
@@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
10471060
model_monitoring_objective_config=mm_obj_cfg,
10481061
model_monitoring_alert_config=mm_alert_cfg,
10491062
analysis_instance_schema_uri="",
1063+
service_account=_TEST_SERVICE_ACCOUNT,
10501064
)
10511065

10521066
batch_prediction_job.wait_for_resource_creation()
@@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
10861100
generate_explanation=True,
10871101
model_monitoring_config=_TEST_MODEL_MONITORING_CFG,
10881102
labels=_TEST_LABEL,
1103+
service_account=_TEST_SERVICE_ACCOUNT,
10891104
)
10901105
create_batch_prediction_job_v1beta1_mock.assert_called_once_with(
10911106
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
@@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self):
11031118
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
11041119
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
11051120
sync=False,
1121+
service_account=_TEST_SERVICE_ACCOUNT,
11061122
)
11071123

11081124
with pytest.raises(RuntimeError) as e:
@@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
11431159
model_name=_TEST_MODEL_NAME,
11441160
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
11451161
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
1162+
service_account=_TEST_SERVICE_ACCOUNT,
11461163
)
11471164

11481165
assert e.match(regexp=r"source")
@@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
11591176
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
11601177
bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX,
11611178
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
1179+
service_account=_TEST_SERVICE_ACCOUNT,
11621180
)
11631181

11641182
assert e.match(regexp=r"source")
@@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self):
11731191
model_name=_TEST_MODEL_NAME,
11741192
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
11751193
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
1194+
service_account=_TEST_SERVICE_ACCOUNT,
11761195
)
11771196

11781197
assert e.match(regexp=r"destination")
@@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self):
11891208
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
11901209
instances_format="wrong",
11911210
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
1211+
service_account=_TEST_SERVICE_ACCOUNT,
11921212
)
11931213

11941214
assert e.match(regexp=r"accepted instances format")
@@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self):
12051225
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
12061226
predictions_format="wrong",
12071227
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
1228+
service_account=_TEST_SERVICE_ACCOUNT,
12081229
)
12091230

12101231
assert e.match(regexp=r"accepted prediction format")
@@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model(
12221243
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
12231244
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
12241245
sync=True,
1246+
service_account=_TEST_SERVICE_ACCOUNT,
12251247
)
12261248
assert (
12271249
create_batch_prediction_job_mock.call_args_list[0][1][
@@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model(
12371259
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
12381260
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
12391261
sync=True,
1262+
service_account=_TEST_SERVICE_ACCOUNT,
12401263
)
12411264
assert (
12421265
create_batch_prediction_job_mock.call_args_list[0][1][

tests/unit/aiplatform/test_models.py

+9
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
16441644
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
16451645
sync=sync,
16461646
create_request_timeout=None,
1647+
service_account=_TEST_SERVICE_ACCOUNT,
16471648
)
16481649

16491650
if not sync:
@@ -1669,6 +1670,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
16691670
predictions_format="jsonl",
16701671
),
16711672
encryption_spec=_TEST_ENCRYPTION_SPEC,
1673+
service_account=_TEST_SERVICE_ACCOUNT,
16721674
)
16731675
)
16741676

@@ -1693,6 +1695,7 @@ def test_batch_predict_gcs_source_and_dest(
16931695
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
16941696
sync=sync,
16951697
create_request_timeout=None,
1698+
service_account=_TEST_SERVICE_ACCOUNT,
16961699
)
16971700

16981701
if not sync:
@@ -1711,6 +1714,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock
17111714
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
17121715
sync=sync,
17131716
create_request_timeout=None,
1717+
service_account=_TEST_SERVICE_ACCOUNT,
17141718
)
17151719

17161720
if not sync:
@@ -1733,6 +1737,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock
17331737
),
17341738
predictions_format="jsonl",
17351739
),
1740+
service_account=_TEST_SERVICE_ACCOUNT,
17361741
)
17371742
)
17381743

@@ -1757,6 +1762,7 @@ def test_batch_predict_gcs_source_bq_dest(
17571762
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX,
17581763
sync=sync,
17591764
create_request_timeout=None,
1765+
service_account=_TEST_SERVICE_ACCOUNT,
17601766
)
17611767

17621768
if not sync:
@@ -1781,6 +1787,7 @@ def test_batch_predict_gcs_source_bq_dest(
17811787
),
17821788
predictions_format="bigquery",
17831789
),
1790+
service_account=_TEST_SERVICE_ACCOUNT,
17841791
)
17851792
)
17861793

@@ -1817,6 +1824,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
18171824
sync=sync,
18181825
create_request_timeout=None,
18191826
batch_size=_TEST_BATCH_SIZE,
1827+
service_account=_TEST_SERVICE_ACCOUNT,
18201828
)
18211829

18221830
if not sync:
@@ -1857,6 +1865,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn
18571865
),
18581866
labels=_TEST_LABEL,
18591867
encryption_spec=_TEST_ENCRYPTION_SPEC,
1868+
service_account=_TEST_SERVICE_ACCOUNT,
18601869
)
18611870

18621871
create_batch_prediction_job_mock.assert_called_once_with(

0 commit comments

Comments
 (0)