Skip to content

Commit d11b8e6

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: Allow setting default service account
PiperOrigin-RevId: 559266585
1 parent 7eaa1d4 commit d11b8e6

15 files changed

+102
-7
lines changed

google/cloud/aiplatform/initializer.py

+15
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(self):
9898
self._credentials = None
9999
self._encryption_spec_key_name = None
100100
self._network = None
101+
self._service_account = None
101102

102103
def init(
103104
self,
@@ -113,6 +114,7 @@ def init(
113114
credentials: Optional[auth_credentials.Credentials] = None,
114115
encryption_spec_key_name: Optional[str] = None,
115116
network: Optional[str] = None,
117+
service_account: Optional[str] = None,
116118
):
117119
"""Updates common initialization parameters with provided options.
118120

@@ -155,6 +157,12 @@ def init(
155157
Private services access must already be configured for the network.
156158
If specified, all eligible jobs and resources created will be peered
157159
with this VPC.
160+
service_account (str):
161+
Optional. The service account used to launch jobs and deploy models.
162+
Jobs that use service_account: BatchPredictionJob, CustomJob,
163+
PipelineJob, HyperparameterTuningJob, CustomTrainingJob,
164+
CustomPythonPackageTrainingJob, CustomContainerTrainingJob,
165+
ModelEvaluationJob.
158166
Raises:
159167
ValueError:
160168
If experiment_description is provided but experiment is not.
@@ -194,6 +202,8 @@ def init(
194202
self._encryption_spec_key_name = encryption_spec_key_name
195203
if network is not None:
196204
self._network = network
205+
if service_account is not None:
206+
self._service_account = service_account
197207

198208
if experiment:
199209
metadata._experiment_tracker.set_experiment(
@@ -297,6 +307,11 @@ def network(self) -> Optional[str]:
297307
"""Default Compute Engine network to peer to, if provided."""
298308
return self._network
299309

310+
@property
311+
def service_account(self) -> Optional[str]:
312+
"""Default service account, if provided."""
313+
return self._service_account
314+
300315
@property
301316
def experiment_name(self) -> Optional[str]:
302317
"""Default experiment name, if provided."""

google/cloud/aiplatform/jobs.py

+5
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ def create(
761761
)
762762
gapic_batch_prediction_job.explanation_spec = explanation_spec
763763

764+
service_account = service_account or initializer.global_config.service_account
764765
if service_account:
765766
gapic_batch_prediction_job.service_account = service_account
766767

@@ -1693,6 +1694,7 @@ def run(
16931694
`restart_job_on_worker_restart` to False.
16941695
"""
16951696
network = network or initializer.global_config.network
1697+
service_account = service_account or initializer.global_config.service_account
16961698

16971699
self._run(
16981700
service_account=service_account,
@@ -1880,6 +1882,8 @@ def submit(
18801882
raise ValueError(
18811883
"'experiment' is required since you've enabled autolog in 'from_local_script'."
18821884
)
1885+
1886+
service_account = service_account or initializer.global_config.service_account
18831887
if service_account:
18841888
self._gca_resource.job_spec.service_account = service_account
18851889

@@ -2356,6 +2360,7 @@ def run(
23562360
`restart_job_on_worker_restart` to False.
23572361
"""
23582362
network = network or initializer.global_config.network
2363+
service_account = service_account or initializer.global_config.service_account
23592364

23602365
self._run(
23612366
service_account=service_account,

google/cloud/aiplatform/model_evaluation/model_evaluation_job.py

+1
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def submit(
278278
Returns:
279279
(ModelEvaluationJob): Instantiated represnetation of the model evaluation job.
280280
"""
281+
service_account = service_account or initializer.global_config.service_account
281282

282283
if isinstance(model_name, aiplatform.Model):
283284
model_resource_name = model_name.versioned_resource_name

google/cloud/aiplatform/models.py

+3
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,7 @@ def _deploy_call(
10961096
to the resource project.
10971097
Users deploying the Model must have the `iam.serviceAccounts.actAs`
10981098
permission on this service account.
1099+
If not specified, uses the service account set in aiplatform.init.
10991100
explanation_spec (aiplatform.explain.ExplanationSpec):
11001101
Optional. Specification of Model explanation.
11011102
metadata (Sequence[Tuple[str, str]]):
@@ -1120,6 +1121,8 @@ def _deploy_call(
11201121
is not 0 or 100.
11211122
"""
11221123

1124+
service_account = service_account or initializer.global_config.service_account
1125+
11231126
max_replica_count = max(min_replica_count, max_replica_count)
11241127

11251128
if bool(accelerator_type) != bool(accelerator_count):

google/cloud/aiplatform/pipeline_job_schedules.py

+1
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def _create(
226226
if max_concurrent_run_count:
227227
self._gca_resource.max_concurrent_run_count = max_concurrent_run_count
228228

229+
service_account = service_account or initializer.global_config.service_account
229230
network = network or initializer.global_config.network
230231

231232
if service_account:

google/cloud/aiplatform/pipeline_jobs.py

+1
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ def submit(
382382
current Experiment Run.
383383
"""
384384
network = network or initializer.global_config.network
385+
service_account = service_account or initializer.global_config.service_account
385386

386387
if service_account:
387388
self._gca_resource.service_account = service_account

google/cloud/aiplatform/training_jobs.py

+4
Original file line numberDiff line numberDiff line change
@@ -3223,6 +3223,7 @@ def run(
32233223
produce a Vertex AI Model.
32243224
"""
32253225
network = network or initializer.global_config.network
3226+
service_account = service_account or initializer.global_config.service_account
32263227

32273228
worker_pool_specs, managed_model = self._prepare_and_validate_run(
32283229
model_display_name=model_display_name,
@@ -4579,6 +4580,7 @@ def run(
45794580
were not provided in constructor.
45804581
"""
45814582
network = network or initializer.global_config.network
4583+
service_account = service_account or initializer.global_config.service_account
45824584

45834585
worker_pool_specs, managed_model = self._prepare_and_validate_run(
45844586
model_display_name=model_display_name,
@@ -7348,6 +7350,7 @@ def run(
73487350
service_account (str):
73497351
Specifies the service account for workload run-as account.
73507352
Users submitting jobs must have act-as permission on this run-as account.
7353+
If not specified, uses the service account set in aiplatform.init.
73517354
network (str):
73527355
The full name of the Compute Engine network to which the job
73537356
should be peered. For example, projects/12345/global/networks/myVPC.
@@ -7501,6 +7504,7 @@ def run(
75017504
produce a Vertex AI Model.
75027505
"""
75037506
network = network or initializer.global_config.network
7507+
service_account = service_account or initializer.global_config.service_account
75047508

75057509
worker_pool_specs, managed_model = self._prepare_and_validate_run(
75067510
model_display_name=model_display_name,

google/cloud/aiplatform/utils/gcs_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
217217
"""
218218
project = project or initializer.global_config.project
219219
location = location or initializer.global_config.location
220+
service_account = service_account or initializer.global_config.service_account
220221
credentials = credentials or initializer.global_config.credentials
221222

222223
output_artifacts_gcs_dir = (

samples/model-builder/init_sample.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def init_sample(
2525
staging_bucket: Optional[str] = None,
2626
credentials: Optional[auth_credentials.Credentials] = None,
2727
encryption_spec_key_name: Optional[str] = None,
28+
service_account: Optional[str] = None,
2829
):
2930

3031
from google.cloud import aiplatform
@@ -36,6 +37,7 @@ def init_sample(
3637
staging_bucket=staging_bucket,
3738
credentials=credentials,
3839
encryption_spec_key_name=encryption_spec_key_name,
40+
service_account=service_account,
3941
)
4042

4143

samples/model-builder/init_sample_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_init_sample(mock_sdk_init):
2626
staging_bucket=constants.STAGING_BUCKET,
2727
credentials=constants.CREDENTIALS,
2828
encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
29+
service_account=constants.SERVICE_ACCOUNT,
2930
)
3031

3132
mock_sdk_init.assert_called_once_with(
@@ -35,4 +36,5 @@ def test_init_sample(mock_sdk_init):
3536
staging_bucket=constants.STAGING_BUCKET,
3637
credentials=constants.CREDENTIALS,
3738
encryption_spec_key_name=constants.ENCRYPTION_SPEC_KEY_NAME,
39+
service_account=constants.SERVICE_ACCOUNT,
3840
)

tests/unit/aiplatform/test_custom_job.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
406406
location=_TEST_LOCATION,
407407
staging_bucket=_TEST_STAGING_BUCKET,
408408
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
409+
network=_TEST_NETWORK,
410+
service_account=_TEST_SERVICE_ACCOUNT,
409411
)
410412

411413
job = aiplatform.CustomJob(
@@ -416,8 +418,6 @@ def test_create_custom_job(self, create_custom_job_mock, get_custom_job_mock, sy
416418
)
417419

418420
job.run(
419-
service_account=_TEST_SERVICE_ACCOUNT,
420-
network=_TEST_NETWORK,
421421
timeout=_TEST_TIMEOUT,
422422
restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART,
423423
sync=sync,

tests/unit/aiplatform/test_initializer.py

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
_TEST_DESCRIPTION = "test-description"
4545
_TEST_STAGING_BUCKET = "test-bucket"
4646
_TEST_NETWORK = "projects/12345/global/networks/myVPC"
47+
_TEST_SERVICE_ACCOUNT = "[email protected]"
4748

4849
# tensorboard
4950
_TEST_TENSORBOARD_ID = "1028944691210842416"
@@ -105,6 +106,10 @@ def test_init_network_sets_network(self):
105106
initializer.global_config.init(network=_TEST_NETWORK)
106107
assert initializer.global_config.network == _TEST_NETWORK
107108

109+
def test_init_service_account_sets_service_account(self):
110+
initializer.global_config.init(service_account=_TEST_SERVICE_ACCOUNT)
111+
assert initializer.global_config.service_account == _TEST_SERVICE_ACCOUNT
112+
108113
@patch.object(_experiment_tracker, "set_experiment")
109114
def test_init_experiment_sets_experiment(self, set_experiment_mock):
110115
initializer.global_config.init(experiment=_TEST_EXPERIMENT)

tests/unit/aiplatform/test_models.py

+55
Original file line numberDiff line numberDiff line change
@@ -2215,6 +2215,61 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a
22152215
timeout=None,
22162216
)
22172217

2218+
@pytest.mark.parametrize("sync", [True, False])
2219+
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
2220+
def test_init_aiplatform_with_service_account_and_batch_predict_gcs_source_and_dest(
2221+
self, create_batch_prediction_job_mock, sync
2222+
):
2223+
aiplatform.init(
2224+
project=_TEST_PROJECT,
2225+
location=_TEST_LOCATION,
2226+
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
2227+
service_account=_TEST_SERVICE_ACCOUNT,
2228+
)
2229+
test_model = models.Model(_TEST_ID)
2230+
2231+
# Make SDK batch_predict method call
2232+
batch_prediction_job = test_model.batch_predict(
2233+
job_display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
2234+
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
2235+
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
2236+
sync=sync,
2237+
create_request_timeout=None,
2238+
)
2239+
2240+
if not sync:
2241+
batch_prediction_job.wait()
2242+
2243+
# Construct expected request
2244+
expected_gapic_batch_prediction_job = (
2245+
gca_batch_prediction_job.BatchPredictionJob(
2246+
display_name=_TEST_BATCH_PREDICTION_DISPLAY_NAME,
2247+
model=model_service_client.ModelServiceClient.model_path(
2248+
_TEST_PROJECT, _TEST_LOCATION, _TEST_ID
2249+
),
2250+
input_config=gca_batch_prediction_job.BatchPredictionJob.InputConfig(
2251+
instances_format="jsonl",
2252+
gcs_source=gca_io.GcsSource(
2253+
uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
2254+
),
2255+
),
2256+
output_config=gca_batch_prediction_job.BatchPredictionJob.OutputConfig(
2257+
gcs_destination=gca_io.GcsDestination(
2258+
output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
2259+
),
2260+
predictions_format="jsonl",
2261+
),
2262+
encryption_spec=_TEST_ENCRYPTION_SPEC,
2263+
service_account=_TEST_SERVICE_ACCOUNT,
2264+
)
2265+
)
2266+
2267+
create_batch_prediction_job_mock.assert_called_once_with(
2268+
parent=_TEST_PARENT,
2269+
batch_prediction_job=expected_gapic_batch_prediction_job,
2270+
timeout=None,
2271+
)
2272+
22182273
@pytest.mark.parametrize("sync", [True, False])
22192274
@pytest.mark.usefixtures("get_model_mock", "get_batch_prediction_job_mock")
22202275
def test_batch_predict_gcs_source_and_dest(

tests/unit/aiplatform/test_pipeline_jobs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,8 @@ def test_run_call_pipeline_service_create(
485485
staging_bucket=_TEST_GCS_BUCKET_NAME,
486486
location=_TEST_LOCATION,
487487
credentials=_TEST_CREDENTIALS,
488+
service_account=_TEST_SERVICE_ACCOUNT,
489+
network=_TEST_NETWORK,
488490
)
489491

490492
job = pipeline_jobs.PipelineJob(
@@ -497,8 +499,6 @@ def test_run_call_pipeline_service_create(
497499
)
498500

499501
job.run(
500-
service_account=_TEST_SERVICE_ACCOUNT,
501-
network=_TEST_NETWORK,
502502
sync=sync,
503503
create_request_timeout=None,
504504
)

tests/unit/aiplatform/test_training_jobs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
10561056
project=_TEST_PROJECT,
10571057
staging_bucket=_TEST_BUCKET_NAME,
10581058
credentials=_TEST_CREDENTIALS,
1059+
service_account=_TEST_SERVICE_ACCOUNT,
10591060
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
10601061
)
10611062

@@ -1082,7 +1083,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
10821083
model_from_job = job.run(
10831084
dataset=mock_tabular_dataset,
10841085
base_output_dir=_TEST_BASE_OUTPUT_DIR,
1085-
service_account=_TEST_SERVICE_ACCOUNT,
10861086
network=_TEST_NETWORK,
10871087
args=_TEST_RUN_ARGS,
10881088
environment_variables=_TEST_ENVIRONMENT_VARIABLES,
@@ -3181,6 +3181,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
31813181
aiplatform.init(
31823182
project=_TEST_PROJECT,
31833183
staging_bucket=_TEST_BUCKET_NAME,
3184+
service_account=_TEST_SERVICE_ACCOUNT,
31843185
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
31853186
)
31863187

@@ -3215,7 +3216,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
32153216
model_display_name=_TEST_MODEL_DISPLAY_NAME,
32163217
model_labels=_TEST_MODEL_LABELS,
32173218
predefined_split_column_name=_TEST_PREDEFINED_SPLIT_COLUMN_NAME,
3218-
service_account=_TEST_SERVICE_ACCOUNT,
32193219
tensorboard=_TEST_TENSORBOARD_RESOURCE_NAME,
32203220
sync=sync,
32213221
create_request_timeout=None,
@@ -5242,6 +5242,7 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
52425242
aiplatform.init(
52435243
project=_TEST_PROJECT,
52445244
staging_bucket=_TEST_BUCKET_NAME,
5245+
service_account=_TEST_SERVICE_ACCOUNT,
52455246
encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME,
52465247
)
52475248

@@ -5271,7 +5272,6 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
52715272
model_display_name=_TEST_MODEL_DISPLAY_NAME,
52725273
model_labels=_TEST_MODEL_LABELS,
52735274
base_output_dir=_TEST_BASE_OUTPUT_DIR,
5274-
service_account=_TEST_SERVICE_ACCOUNT,
52755275
network=_TEST_NETWORK,
52765276
args=_TEST_RUN_ARGS,
52775277
environment_variables=_TEST_ENVIRONMENT_VARIABLES,

0 commit comments

Comments
 (0)