Skip to content

Commit 4eef230

Browse files
nileshspringmlholtskinnergcf-owl-bot[bot]
authored
feat: Add samples for the Tuning API reference doc (#11631)
* Add samples for the Tuning API reference doc: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#python --------- Co-authored-by: Holt Skinner <[email protected]> Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 4cbc8f2 commit 4eef230

File tree

2 files changed

+164
-0
lines changed

2 files changed

+164
-0
lines changed

generative_ai/gemini_tuning.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List
16+
17+
from vertexai.preview.tuning import sft
18+
19+
20+
def gemini_supervised_tuning(project_id: str) -> sft.SupervisedTuningJob:
21+
# [START generativeaionvertexai_supervised_tuning]
22+
23+
import time
24+
import vertexai
25+
from vertexai.preview.tuning import sft
26+
27+
# TODO(developer): Update and un-comment below lines
28+
# project_id = "PROJECT_ID"
29+
30+
vertexai.init(project=project_id, location="us-central1")
31+
32+
sft_tuning_job = sft.train(
33+
source_model="gemini-1.0-pro-002",
34+
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
35+
# The following parameters are optional
36+
validation_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_validation_data.jsonl",
37+
epochs=4,
38+
learning_rate_multiplier=1.0,
39+
tuned_model_display_name="tuned_gemini_pro",
40+
)
41+
42+
# Polling for job completion
43+
while not sft_tuning_job.has_ended:
44+
time.sleep(60)
45+
sft_tuning_job.refresh()
46+
47+
print(sft_tuning_job.tuned_model_name)
48+
print(sft_tuning_job.tuned_model_endpoint_name)
49+
print(sft_tuning_job.experiment)
50+
# [END generativeaionvertexai_supervised_tuning]
51+
52+
return sft_tuning_job
53+
54+
55+
def get_supervised_tuning_job(
56+
project_id: str, location: str, tuning_job_id: str
57+
) -> sft.SupervisedTuningJob:
58+
# [START generativeaionvertexai_get_supervised_tuning_job]
59+
import vertexai
60+
from vertexai.preview.tuning import sft
61+
62+
# TODO(developer): Update and un-comment below lines
63+
# project_id = "PROJECT_ID"
64+
# location = "us-central1"
65+
# tuning_job_id = "TUNING_JOB_ID"
66+
67+
vertexai.init(project=project_id, location="us-central1")
68+
69+
response = sft.SupervisedTuningJob(
70+
f"projects/{project_id}/locations/{location}/tuningJobs/{tuning_job_id}"
71+
)
72+
73+
print(response)
74+
# [END generativeaionvertexai_get_supervised_tuning_job]
75+
76+
return response
77+
78+
79+
def list_supervised_tuning_jobs(project_id: str) -> List[sft.SupervisedTuningJob]:
80+
# [START generativeaionvertexai_list_supervised_tuning_jobs]
81+
import vertexai
82+
from vertexai.preview.tuning import sft
83+
84+
# TODO(developer): Update and un-comment below lines
85+
# project_id = "PROJECT_ID"
86+
87+
vertexai.init(project=project_id, location="us-central1")
88+
89+
responses = sft.SupervisedTuningJob.list()
90+
91+
for response in responses:
92+
print(response)
93+
# [END generativeaionvertexai_list_supervised_tuning_jobs]
94+
95+
return responses
96+
97+
98+
def cancel_supervised_tuning_job(
99+
project_id: str, location: str, tuning_job_id: str
100+
) -> None:
101+
# [START generativeaionvertexai_cancel_supervised_tuning_job]
102+
import vertexai
103+
from vertexai.preview.tuning import sft
104+
105+
# TODO(developer): Update and un-comment below lines
106+
# project_id = "PROJECT_ID"
107+
# location = "us-central1"
108+
# tuning_job_id = "TUNING_JOB_ID"
109+
110+
vertexai.init(project=project_id, location="us-central1")
111+
112+
job = sft.SupervisedTuningJob(
113+
f"projects/{project_id}/locations/{location}/tuningJobs/{tuning_job_id}"
114+
)
115+
job.cancel()
116+
117+
# [END generativeaionvertexai_cancel_supervised_tuning_job]

generative_ai/gemini_tuning_test.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import pytest
18+
19+
import gemini_tuning
20+
21+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
22+
REGION = "us-central1"
23+
MODEL_ID = "gemini-1.5-pro-preview-0409"
24+
TUNING_JOB_ID = "4982013113894174720"
25+
26+
27+
@pytest.mark.skip(reason="Skip due to tuning taking a long time.")
28+
def test_supervised_tuning() -> None:
29+
response = gemini_tuning.gemini_supervised_tuning(PROJECT_ID)
30+
assert response
31+
32+
33+
def test_get_supervised_tuning_job() -> None:
34+
response = gemini_tuning.get_supervised_tuning_job(
35+
PROJECT_ID, REGION, TUNING_JOB_ID
36+
)
37+
assert response
38+
39+
40+
def test_list_supervised_tuning_jobs() -> None:
41+
response = gemini_tuning.list_supervised_tuning_jobs(PROJECT_ID)
42+
assert response
43+
44+
45+
@pytest.mark.skip(reason="Skip due to tuning taking a long time.")
46+
def test_cancel_supervised_tuning_job() -> None:
47+
gemini_tuning.cancel_supervised_tuning_job(PROJECT_ID, REGION, TUNING_JOB_ID)

0 commit comments

Comments
 (0)