Skip to content

Commit 7cbda03

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI - Tuning - Added support for BYOSA
PiperOrigin-RevId: 700223388
1 parent 598c931 commit 7cbda03

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

tests/unit/vertexai/tuning/test_tuning.py

+24
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,30 @@ def test_genai_tuning_service_encryption_spec(
250250
)
251251
assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"
252252

253+
@mock.patch.object(
254+
target=tuning.TuningJob,
255+
attribute="client_class",
256+
new=MockTuningJobClientWithOverride,
257+
)
258+
@pytest.mark.parametrize(
259+
"supervised_tuning",
260+
[supervised_tuning, preview_supervised_tuning],
261+
)
262+
def test_genai_tuning_service_service_account(
263+
self, supervised_tuning: supervised_tuning
264+
):
265+
"""Test that the service account propagates to the tuning job."""
266+
vertexai.init(service_account="[email protected]")
267+
268+
sft_tuning_job = supervised_tuning.train(
269+
source_model="gemini-1.0-pro-002",
270+
train_dataset="gs://some-bucket/some_dataset.jsonl",
271+
)
272+
assert (
273+
sft_tuning_job.service_account
274+
275+
)
276+
253277
@mock.patch.object(
254278
target=tuning.TuningJob,
255279
attribute="client_class",

vertexai/tuning/_tuning.py

+8
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def experiment(self) -> Optional[aiplatform.Experiment]:
107107
def state(self) -> gca_types.JobState:
108108
return self._gca_resource.state
109109

110+
@property
111+
def service_account(self) -> Optional[str]:
112+
self._assert_gca_resource_is_available()
113+
return self._gca_resource.service_account
114+
110115
@property
111116
def has_ended(self):
112117
return self.state in jobs._JOB_COMPLETE_STATES
@@ -204,6 +209,9 @@ def _create(
204209
gca_tuning_job.encryption_spec.kms_key_name = (
205210
aiplatform_initializer.global_config.encryption_spec_key_name
206211
)
212+
gca_tuning_job.service_account = (
213+
aiplatform_initializer.global_config.service_account
214+
)
207215

208216
tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic(
209217
gapic_resource=gca_tuning_job,

0 commit comments

Comments
 (0)