File tree 2 files changed +32
-0
lines changed
tests/unit/vertexai/tuning
2 files changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -250,6 +250,30 @@ def test_genai_tuning_service_encryption_spec(
250
250
)
251
251
assert sft_tuning_job .encryption_spec .kms_key_name == "test-key"
252
252
253
+ @mock .patch .object (
254
+ target = tuning .TuningJob ,
255
+ attribute = "client_class" ,
256
+ new = MockTuningJobClientWithOverride ,
257
+ )
258
+ @pytest .mark .parametrize (
259
+ "supervised_tuning" ,
260
+ [supervised_tuning , preview_supervised_tuning ],
261
+ )
262
+ def test_genai_tuning_service_service_account (
263
+ self , supervised_tuning : supervised_tuning
264
+ ):
265
+ """Test that the service account propagates to the tuning job."""
266
+ vertexai .
init (
service_account = "[email protected] " )
267
+
268
+ sft_tuning_job = supervised_tuning .train (
269
+ source_model = "gemini-1.0-pro-002" ,
270
+ train_dataset = "gs://some-bucket/some_dataset.jsonl" ,
271
+ )
272
+ assert (
273
+ sft_tuning_job .service_account
274
+
275
+ )
276
+
253
277
@mock .patch .object (
254
278
target = tuning .TuningJob ,
255
279
attribute = "client_class" ,
Original file line number Diff line number Diff line change @@ -107,6 +107,11 @@ def experiment(self) -> Optional[aiplatform.Experiment]:
107
107
def state (self ) -> gca_types .JobState :
108
108
return self ._gca_resource .state
109
109
110
+ @property
111
+ def service_account (self ) -> Optional [str ]:
112
+ self ._assert_gca_resource_is_available ()
113
+ return self ._gca_resource .service_account
114
+
110
115
@property
111
116
def has_ended (self ):
112
117
return self .state in jobs ._JOB_COMPLETE_STATES
@@ -204,6 +209,9 @@ def _create(
204
209
gca_tuning_job .encryption_spec .kms_key_name = (
205
210
aiplatform_initializer .global_config .encryption_spec_key_name
206
211
)
212
+ gca_tuning_job .service_account = (
213
+ aiplatform_initializer .global_config .service_account
214
+ )
207
215
208
216
tuning_job : TuningJob = cls ._construct_sdk_resource_from_gapic (
209
217
gapic_resource = gca_tuning_job ,
You can’t perform that action at this time.
0 commit comments