|
19 | 19 |
|
20 | 20 | import copy
|
21 | 21 | import datetime
|
| 22 | +import importlib |
22 | 23 | from typing import Dict, Iterable
|
23 | 24 | from unittest import mock
|
24 | 25 | import uuid
|
25 | 26 |
|
| 27 | +from google.cloud import aiplatform |
26 | 28 | import vertexai
|
27 | 29 | from google.cloud.aiplatform import compat
|
28 | 30 | from google.cloud.aiplatform import initializer
|
@@ -150,6 +152,10 @@ class TestgenerativeModelTuning:
|
150 | 152 | """Unit tests for generative model tuning."""
|
151 | 153 |
|
152 | 154 | def setup_method(self):
|
| 155 | + importlib.reload(initializer) |
| 156 | + importlib.reload(aiplatform) |
| 157 | + importlib.reload(vertexai) |
| 158 | + |
153 | 159 | vertexai.init(
|
154 | 160 | project=_TEST_PROJECT,
|
155 | 161 | location=_TEST_LOCATION,
|
@@ -197,3 +203,18 @@ def test_genai_tuning_service_supervised_tuning_tune_model(self):
|
197 | 203 | assert sft_tuning_job._experiment_name
|
198 | 204 | assert sft_tuning_job.tuned_model_name
|
199 | 205 | assert sft_tuning_job.tuned_model_endpoint_name
|
| 206 | + |
| 207 | + @mock.patch.object( |
| 208 | + target=tuning.TuningJob, |
| 209 | + attribute="client_class", |
| 210 | + new=MockTuningJobClientWithOverride, |
| 211 | + ) |
| 212 | + def test_genai_tuning_service_encryption_spec(self): |
| 213 | + """Test that the global encryption spec propagates to the tuning job.""" |
| 214 | + vertexai.init(encryption_spec_key_name="test-key") |
| 215 | + |
| 216 | + sft_tuning_job = supervised_tuning.train( |
| 217 | + source_model="gemini-1.0-pro-001", |
| 218 | + train_dataset="gs://some-bucket/some_dataset.jsonl", |
| 219 | + ) |
| 220 | + assert sft_tuning_job.encryption_spec.kms_key_name == "test-key" |
0 commit comments