Commit 8191035

Ark-kun authored and copybara-github committed
chore: [LLM] Added system tests for tuning
The tests cover tuning as well as listing and loading the tuned models.

PiperOrigin-RevId: 540433943
1 parent d4d8613 commit 8191035

File tree

1 file changed: +63 -0 lines changed


tests/system/aiplatform/test_language_models.py

Lines changed: 63 additions & 0 deletions
@@ -81,3 +81,66 @@ def test_text_embedding(self):
         for embedding in embeddings:
             vector = embedding.values
             assert len(vector) == 768
+
+    def test_tuning(self, shared_state):
+        """Test tuning, listing and loading models."""
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextGenerationModel.from_pretrained("google/text-bison@001")
+
+        import pandas
+
+        training_data = pandas.DataFrame(
+            data=[
+                {"input_text": "Input 0", "output_text": "Output 0"},
+                {"input_text": "Input 1", "output_text": "Output 1"},
+                {"input_text": "Input 2", "output_text": "Output 2"},
+                {"input_text": "Input 3", "output_text": "Output 3"},
+                {"input_text": "Input 4", "output_text": "Output 4"},
+                {"input_text": "Input 5", "output_text": "Output 5"},
+                {"input_text": "Input 6", "output_text": "Output 6"},
+                {"input_text": "Input 7", "output_text": "Output 7"},
+                {"input_text": "Input 8", "output_text": "Output 8"},
+                {"input_text": "Input 9", "output_text": "Output 9"},
+            ]
+        )
+
+        model.tune_model(
+            training_data=training_data,
+            train_steps=1,
+            tuning_job_location="europe-west4",
+            tuned_model_location="us-central1",
+        )
+        # According to the Pipelines design, external resources created by a pipeline
+        # must not be modified or deleted. Otherwise caching will break next pipeline runs.
+        shared_state.setdefault("resources", [])
+        shared_state["resources"].append(model._endpoint)
+        shared_state["resources"].extend(
+            aiplatform.Model(model_name=deployed_model.model)
+            for deployed_model in model._endpoint.list_models()
+        )
+        # Deleting the Endpoint is a little less bad since the LLM SDK will recreate it, but it's not advised for the same reason.
+
+        response = model.predict(
+            "What is the best recipe for banana bread? Recipe:",
+            max_output_tokens=128,
+            temperature=0,
+            top_p=1,
+            top_k=5,
+        )
+        assert response.text
+
+        tuned_model_names = model.list_tuned_model_names()
+        assert tuned_model_names
+        tuned_model_name = tuned_model_names[0]
+
+        tuned_model = TextGenerationModel.get_tuned_model(tuned_model_name)
+
+        tuned_model_response = tuned_model.predict(
+            "What is the best recipe for banana bread? Recipe:",
+            max_output_tokens=128,
+            temperature=0,
+            top_p=1,
+            top_k=5,
+        )
+        assert tuned_model_response.text
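
For reference, below is a minimal standalone sketch of the same flow outside the pytest harness. The SDK calls (from_pretrained, tune_model, list_tuned_model_names, get_tuned_model, predict) mirror the test above; the import path vertexai.preview.language_models, the placeholder project ID, and the toy dataset are assumptions not taken from this commit, so adjust them to your environment and installed SDK version. Note that tune_model launches a Vertex AI tuning pipeline and can take a while to complete.

# Standalone sketch of the tuning flow exercised by the test above.
# Assumptions (not part of the commit): the vertexai.preview.language_models
# import path and the "my-project" project ID are placeholders.
import pandas
from google.cloud import aiplatform
from vertexai.preview.language_models import TextGenerationModel

aiplatform.init(project="my-project", location="us-central1")  # placeholder project

model = TextGenerationModel.from_pretrained("google/text-bison@001")

# Ten tiny input/output pairs, mirroring the test's toy dataset.
training_data = pandas.DataFrame(
    data=[
        {"input_text": f"Input {i}", "output_text": f"Output {i}"} for i in range(10)
    ]
)

# Launch a tuning pipeline; as in the test, tuning runs in europe-west4 and
# the tuned model is deployed in us-central1.
model.tune_model(
    training_data=training_data,
    train_steps=1,
    tuning_job_location="europe-west4",
    tuned_model_location="us-central1",
)

# List previously tuned models for this base model and load one of them.
tuned_model_names = model.list_tuned_model_names()
tuned_model = TextGenerationModel.get_tuned_model(tuned_model_names[0])

print(
    tuned_model.predict(
        "What is the best recipe for banana bread? Recipe:",
        max_output_tokens=128,
        temperature=0,
        top_p=1,
        top_k=5,
    ).text
)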
