Skip to content

Commit 036d2d0

Browse files
Ark-kun and copybara-github
authored and committed
feat: GenAI - Added support for supervised fine-tuning
PiperOrigin-RevId: 621984253
1 parent a2778ba commit 036d2d0

File tree

6 files changed

+553
-3
lines changed

6 files changed

+553
-3
lines changed

tests/unit/vertexai/test_tuning.py

+189
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
"""Unit tests for generative model tuning."""
18+
# pylint: disable=protected-access,bad-continuation
19+
20+
import copy
21+
import datetime
22+
from typing import Dict, Iterable
23+
from unittest import mock
24+
import uuid
25+
26+
import vertexai
27+
from google.cloud.aiplatform import compat
28+
from google.cloud.aiplatform import initializer
29+
from google.cloud.aiplatform import utils as aiplatform_utils
30+
from google.cloud.aiplatform_v1.services import gen_ai_tuning_service
31+
from google.cloud.aiplatform_v1.types import job_state
32+
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job
33+
from vertexai.preview import tuning
34+
from vertexai.preview.tuning import sft as supervised_tuning
35+
36+
import pytest
37+
38+
from google.rpc import status_pb2
39+
40+
41+
_TEST_PROJECT = "test-project"
42+
_TEST_LOCATION = "us-central1"
43+
44+
45+
_global_tuning_jobs: Dict[str, gca_tuning_job.TuningJob] = {}
46+
47+
48+
class MockGenAiTuningServiceClient(gen_ai_tuning_service.GenAiTuningServiceClient):
    """In-memory fake of the GenAI tuning service client for unit tests.

    Jobs are kept in the module-level ``_global_tuning_jobs`` dict, keyed by
    resource name. Every ``get_tuning_job`` call returns a snapshot of the job
    and then advances the stored job by one lifecycle step
    (PENDING -> RUNNING -> SUCCEEDED, or PENDING -> FAILED when the training
    dataset URI contains ``"invalid_dataset"``).

    All public methods hand out deep copies, so callers can never alias or
    mutate the state held by the fake.
    """

    @property
    def _tuning_jobs(self) -> Dict[str, gca_tuning_job.TuningJob]:
        """Returns the shared store of jobs, keyed by resource name."""
        return _global_tuning_jobs

    def create_tuning_job(
        self,
        *,
        parent: str,
        tuning_job: gca_tuning_job.TuningJob,
        **_,
    ) -> gca_tuning_job.TuningJob:
        """Registers a new job in the PENDING state.

        Args:
            parent: Resource name prefix under which the job is created.
            tuning_job: Job description; it is copied, never modified.
            **_: Ignored extra keyword arguments.

        Returns:
            A snapshot of the stored job with ``name``, ``create_time``,
            ``update_time`` and ``state`` filled in.
        """
        stored_job = copy.deepcopy(tuning_job)
        resource_name = f"{parent}/tuningJobs/{uuid.uuid4().hex}"
        current_time = datetime.datetime.now(datetime.timezone.utc)
        stored_job.name = resource_name
        stored_job.create_time = current_time
        stored_job.update_time = current_time
        stored_job.state = job_state.JobState.JOB_STATE_PENDING
        self._tuning_jobs[resource_name] = stored_job
        # Return a snapshot rather than the stored object: later state
        # transitions must only become visible through get_tuning_job.
        return copy.deepcopy(stored_job)

    def _progress_tuning_job(self, name: str):
        """Advances the stored job ``name`` by one lifecycle step.

        Jobs that are neither PENDING nor RUNNING are left untouched.
        """
        tuning_job: gca_tuning_job.TuningJob = self._tuning_jobs[name]
        current_time = datetime.datetime.now(datetime.timezone.utc)
        if tuning_job.state == job_state.JobState.JOB_STATE_PENDING:
            if (
                "invalid_dataset"
                in tuning_job.supervised_tuning_spec.training_dataset_uri
            ):
                tuning_job.state = job_state.JobState.JOB_STATE_FAILED
                tuning_job.error = status_pb2.Status(
                    code=400, message="Invalid dataset."
                )
                # FAILED is a terminal state, so the job ends here.
                tuning_job.end_time = current_time
            else:
                tuning_job.state = job_state.JobState.JOB_STATE_RUNNING
            tuning_job.update_time = current_time
        elif tuning_job.state == job_state.JobState.JOB_STATE_RUNNING:
            parent = tuning_job.name.partition("/tuningJobs/")[0]
            tuning_job.state = job_state.JobState.JOB_STATE_SUCCEEDED
            experiment_id = uuid.uuid4().hex
            tuned_model_id = uuid.uuid4().hex
            tuned_model_endpoint_id = uuid.uuid4().hex
            tuning_job.experiment = (
                f"{parent}/metadataStores/default/contexts/{experiment_id}"
            )
            tuning_job.tuned_model = gca_tuning_job.TunedModel(
                model=f"{parent}/models/{tuned_model_id}",
                endpoint=f"{parent}/endpoints/{tuned_model_endpoint_id}",
            )
            tuning_job.end_time = current_time
            tuning_job.update_time = current_time

    def get_tuning_job(self, *, name: str, **_) -> gca_tuning_job.TuningJob:
        """Returns a snapshot of the job, then advances the stored job.

        The snapshot is taken *before* the transition, so a caller observes
        each state exactly once on successive calls.

        Raises:
            KeyError: If no job named ``name`` exists.
        """
        snapshot = copy.deepcopy(self._tuning_jobs[name])
        self._progress_tuning_job(name)
        return snapshot

    def list_tuning_jobs(
        self, *, parent: str, **_
    ) -> Iterable[gca_tuning_job.TuningJob]:
        """Returns snapshots of all jobs whose name starts with ``parent + "/"``.

        Listing does not advance any job.
        """
        prefix = parent + "/"
        return [
            copy.deepcopy(tuning_job)
            for name, tuning_job in self._tuning_jobs.items()
            if name.startswith(prefix)
        ]

    def cancel_tuning_job(self, *, name: str, **_) -> None:
        """Marks a PENDING or RUNNING job as CANCELLED.

        Raises:
            KeyError: If no job named ``name`` exists.
            AssertionError: If the job is not PENDING or RUNNING.
        """
        tuning_job = self._tuning_jobs[name]
        assert tuning_job.state in (
            job_state.JobState.JOB_STATE_RUNNING,
            job_state.JobState.JOB_STATE_PENDING,
        )
        current_time = datetime.datetime.now(datetime.timezone.utc)
        tuning_job.state = job_state.JobState.JOB_STATE_CANCELLED
        # CANCELLED is a terminal state, so record when the job ended.
        tuning_job.end_time = current_time
        tuning_job.update_time = current_time
127+
128+
129+
class MockTuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
    """Versioned client wrapper that maps the v1 API to the mock tuning client."""

    _default_version = compat.V1
    _is_temporary = False
    # Only v1 is mapped here; a v1beta1 version of the tuning client does not
    # exist.
    _version_map = ((compat.V1, MockGenAiTuningServiceClient),)
137+
138+
139+
@pytest.mark.usefixtures("google_auth_mock")
class TestgenerativeModelTuning:
    """Unit tests for generative model tuning."""

    def setup_method(self):
        """Initializes the SDK with the test project and location."""
        vertexai.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )

    def teardown_method(self):
        """Shuts down the SDK thread pool and resets the fake service state."""
        initializer.global_pool.shutdown(wait=True)
        # The mock service keeps its jobs in a module-level dict; clear it so
        # jobs created by one test cannot leak into the next one.
        _global_tuning_jobs.clear()

    @mock.patch.object(
        target=tuning.TuningJob,
        attribute="client_class",
        new=MockTuningJobClientWithOverride,
    )
    def test_genai_tuning_service_supervised_tuning_tune_model(self):
        """Walks a supervised tuning job from creation to success."""
        sft_tuning_job = supervised_tuning.train(
            source_model="gemini-1.0-pro-001",
            train_dataset="gs://some-bucket/some_dataset.jsonl",
            # Optional:
            validation_dataset="gs://some-bucket/some_dataset.jsonl",
            epochs=300,
            learning_rate_multiplier=1.0,
        )
        assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING
        assert not sft_tuning_job.has_ended
        assert not sft_tuning_job.has_succeeded

        # Refreshing the job. The mock returns the state from *before* it
        # advances the job, so the first refresh still reports PENDING.
        sft_tuning_job.refresh()
        assert sft_tuning_job.state == job_state.JobState.JOB_STATE_PENDING
        assert not sft_tuning_job.has_ended
        assert not sft_tuning_job.has_succeeded

        # Refreshing the job
        sft_tuning_job.refresh()
        assert sft_tuning_job.state == job_state.JobState.JOB_STATE_RUNNING
        assert not sft_tuning_job.has_ended
        assert not sft_tuning_job.has_succeeded

        # Refreshing the job
        sft_tuning_job.refresh()
        assert sft_tuning_job.state == job_state.JobState.JOB_STATE_SUCCEEDED
        assert sft_tuning_job.has_ended
        assert sft_tuning_job.has_succeeded
        assert sft_tuning_job._experiment_name
        assert sft_tuning_job.tuned_model_name
        assert sft_tuning_job.tuned_model_endpoint_name

vertexai/generative_models/_generative_models.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,16 @@ def __init__(
145145
146146
Args:
147147
model_name: Model Garden model resource name.
148+
Alternatively, a tuned model endpoint resource name can be provided.
148149
generation_config: Default generation config to use in generate_content.
149150
safety_settings: Default safety settings to use in generate_content.
150151
tools: Default tools to use in generate_content.
151152
system_instruction: Default system instruction to use in generate_content.
152153
Note: Only text should be used in parts.
153154
Content of each part will become a separate paragraph.
154155
"""
156+
if not model_name:
157+
raise ValueError("model_name must not be empty")
155158
if "/" not in model_name:
156159
model_name = "publishers/google/models/" + model_name
157160
if model_name.startswith("models/"):
@@ -160,10 +163,13 @@ def __init__(
160163
project = aiplatform_initializer.global_config.project
161164
location = aiplatform_initializer.global_config.location
162165

166+
if model_name.startswith("publishers/"):
167+
prediction_resource_name = f"projects/{project}/locations/{location}/{model_name}"
168+
else:
169+
prediction_resource_name = model_name
170+
163171
self._model_name = model_name
164-
self._prediction_resource_name = (
165-
f"projects/{project}/locations/{location}/{model_name}"
166-
)
172+
self._prediction_resource_name = prediction_resource_name
167173
self._generation_config = generation_config
168174
self._safety_settings = safety_settings
169175
self._tools = tools

vertexai/preview/tuning/__init__.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Classes for tuning models."""
16+
17+
# We just want to re-export certain classes
18+
# pylint: disable=g-multiple-import,g-importing-member
19+
from vertexai.tuning._tuning import TuningJob
20+
21+
__all__ = [
22+
"TuningJob",
23+
]

vertexai/preview/tuning/sft.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Classes for supervised tuning."""
16+
17+
# We just want to re-export certain classes
18+
# pylint: disable=g-multiple-import,g-importing-member
19+
from vertexai.tuning._supervised_tuning import (
20+
train,
21+
SupervisedTuningJob,
22+
)
23+
24+
__all__ = [
25+
"train",
26+
"SupervisedTuningJob",
27+
]

vertexai/tuning/_supervised_tuning.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
from typing import Optional, Union
17+
18+
from google.cloud.aiplatform_v1.types import tuning_job as gca_tuning_job_types
19+
20+
from vertexai import generative_models
21+
from vertexai.tuning import _tuning
22+
23+
24+
def train(
    *,
    source_model: Union[str, generative_models.GenerativeModel],
    train_dataset: str,
    validation_dataset: Optional[str] = None,
    tuned_model_display_name: Optional[str] = None,
    epochs: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
) -> "SupervisedTuningJob":
    """Tunes a model using supervised training.

    Args:
        source_model: Model to tune. Either a model name, e.g.,
            "gemini-1.0-pro" or "gemini-1.0-pro-001", or a `GenerativeModel`
            instance, in which case the last path segment of its prediction
            resource name is used as the model name.
        train_dataset: Cloud Storage path to file containing training dataset for tuning.
            The dataset should be in JSONL format.
        validation_dataset: Cloud Storage path to file containing validation dataset for tuning.
            The dataset should be in JSONL format.
        tuned_model_display_name: The display name of the
            [TunedModel][google.cloud.aiplatform.v1.Model]. The name can
            be up to 128 characters long and can consist of any UTF-8 characters.
        epochs: Number of training epochs for this tuning job.
        learning_rate_multiplier: Learning rate multiplier for tuning.

    Returns:
        A `SupervisedTuningJob` object.
    """
    hyper_parameters = gca_tuning_job_types.SupervisedHyperParameters(
        epoch_count=epochs,
        learning_rate_multiplier=learning_rate_multiplier,
    )
    tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
        training_dataset_uri=train_dataset,
        validation_dataset_uri=validation_dataset,
        hyper_parameters=hyper_parameters,
    )

    base_model = source_model
    if isinstance(base_model, generative_models.GenerativeModel):
        # Keep only the trailing segment of the model's resource name.
        # pylint: disable-next=protected-access
        resource_name = base_model._prediction_resource_name
        base_model = resource_name.rpartition("/")[-1]

    # pylint: disable-next=protected-access
    return SupervisedTuningJob._create(
        base_model=base_model,
        tuning_spec=tuning_spec,
        tuned_model_display_name=tuned_model_display_name,
    )
68+
69+
70+
class SupervisedTuningJob(_tuning.TuningJob):
    """Tuning job type returned by `train`; adds nothing to `TuningJob`."""

0 commit comments

Comments
 (0)