Skip to content

Commit eb651bc

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - Tuning - Added support for CMEK
PiperOrigin-RevId: 642696819
1 parent e832a8a commit eb651bc

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

tests/unit/vertexai/test_tuning.py

+21
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919

2020
import copy
2121
import datetime
22+
import importlib
2223
from typing import Dict, Iterable
2324
from unittest import mock
2425
import uuid
2526

27+
from google.cloud import aiplatform
2628
import vertexai
2729
from google.cloud.aiplatform import compat
2830
from google.cloud.aiplatform import initializer
@@ -150,6 +152,10 @@ class TestgenerativeModelTuning:
150152
"""Unit tests for generative model tuning."""
151153

152154
def setup_method(self):
155+
importlib.reload(initializer)
156+
importlib.reload(aiplatform)
157+
importlib.reload(vertexai)
158+
153159
vertexai.init(
154160
project=_TEST_PROJECT,
155161
location=_TEST_LOCATION,
@@ -197,3 +203,18 @@ def test_genai_tuning_service_supervised_tuning_tune_model(self):
197203
assert sft_tuning_job._experiment_name
198204
assert sft_tuning_job.tuned_model_name
199205
assert sft_tuning_job.tuned_model_endpoint_name
206+
207+
@mock.patch.object(
208+
target=tuning.TuningJob,
209+
attribute="client_class",
210+
new=MockTuningJobClientWithOverride,
211+
)
212+
def test_genai_tuning_service_encryption_spec(self):
213+
"""Test that the global encryption spec propagates to the tuning job."""
214+
vertexai.init(encryption_spec_key_name="test-key")
215+
216+
sft_tuning_job = supervised_tuning.train(
217+
source_model="gemini-1.0-pro-001",
218+
train_dataset="gs://some-bucket/some_dataset.jsonl",
219+
)
220+
assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"

vertexai/tuning/_tuning.py

+5
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def _create(
195195
else:
196196
raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")
197197

198+
if aiplatform_initializer.global_config.encryption_spec_key_name:
199+
gca_tuning_job.encryption_spec.kms_key_name = (
200+
aiplatform_initializer.global_config.encryption_spec_key_name
201+
)
202+
198203
tuning_job: TuningJob = cls._construct_sdk_resource_from_gapic(
199204
gapic_resource=gca_tuning_job,
200205
project=project,

0 commit comments

Comments
 (0)