Commit c9c9f10

Ark-kun authored and copybara-github committed
feat: LLM - Added support for async prediction methods
PiperOrigin-RevId: 566381589
1 parent 41d341e commit c9c9f10
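
The commit adds awaitable counterparts to the synchronous prediction surface: `predict_async` on text generation models, `send_message_async` on chat sessions, and `get_embeddings_async` on embedding models, as the tests below exercise. A minimal usage sketch, assuming the `vertexai.language_models` import path of this era of the SDK (adjust to your version) and placeholder project/location values:

```python
import asyncio

from google.cloud import aiplatform
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel


async def main() -> None:
    # Placeholder values; substitute your own project and region.
    aiplatform.init(project="my-project", location="us-central1")

    text_model = TextGenerationModel.from_pretrained("text-bison@001")
    embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

    # The awaitable variants let independent calls overlap on one event loop.
    response, embeddings = await asyncio.gather(
        text_model.predict_async(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0.0,
        ),
        embedding_model.get_embeddings_async(["What is life?"]),
    )
    print(response.text)
    print(len(embeddings[0].values))  # 768-dimensional for gecko@001


asyncio.run(main())
```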

File tree

3 files changed: +369 -17 lines changed

tests/system/aiplatform/test_language_models.py (+76)
```diff
@@ -17,6 +17,9 @@
 
 # pylint: disable=protected-access, g-multiple-import
 
+import pytest
+
+
 from google.cloud import aiplatform
 from google.cloud.aiplatform.compat.types import (
     job_state as gca_job_state,
@@ -54,6 +57,22 @@ def test_text_generation(self):
             stop_sequences=["# %%"],
         ).text
 
+    @pytest.mark.asyncio
+    async def test_text_generation_model_predict_async(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextGenerationModel.from_pretrained("google/text-bison@001")
+
+        response = await model.predict_async(
+            "What is the best recipe for banana bread? Recipe:",
+            max_output_tokens=128,
+            temperature=0.0,
+            top_p=1.0,
+            top_k=5,
+            stop_sequences=["# %%"],
+        )
+        assert response.text
+
     def test_text_generation_streaming(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
@@ -107,6 +126,46 @@ def test_chat_on_chat_model(self):
         assert chat.message_history[2].content == message2
         assert chat.message_history[3].author == chat.MODEL_AUTHOR
 
+    @pytest.mark.asyncio
+    async def test_chat_model_async(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        chat_model = ChatModel.from_pretrained("google/chat-bison@001")
+        chat = chat_model.start_chat(
+            context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
+            examples=[
+                InputOutputTextPair(
+                    input_text="Who do you work for?",
+                    output_text="I work for Ned.",
+                ),
+                InputOutputTextPair(
+                    input_text="What do I like?",
+                    output_text="Ned likes watching movies.",
+                ),
+            ],
+            temperature=0.0,
+            stop_sequences=["# %%"],
+        )
+
+        message1 = "Are my favorite movies based on a book series?"
+        response1 = await chat.send_message_async(message1)
+        assert response1.text
+        assert len(chat.message_history) == 2
+        assert chat.message_history[0].author == chat.USER_AUTHOR
+        assert chat.message_history[0].content == message1
+        assert chat.message_history[1].author == chat.MODEL_AUTHOR
+
+        message2 = "When were these books published?"
+        response2 = await chat.send_message_async(
+            message2,
+            temperature=0.1,
+        )
+        assert response2.text
+        assert len(chat.message_history) == 4
+        assert chat.message_history[2].author == chat.USER_AUTHOR
+        assert chat.message_history[2].content == message2
+        assert chat.message_history[3].author == chat.MODEL_AUTHOR
+
     def test_chat_model_send_message_streaming(self):
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
 
@@ -161,6 +220,23 @@ def test_text_embedding(self):
         assert embeddings[1].statistics.token_count > 1000
         assert embeddings[1].statistics.truncated
 
+    @pytest.mark.asyncio
+    async def test_text_embedding_async(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = TextEmbeddingModel.from_pretrained("google/textembedding-gecko@001")
+        # One short text, one long text (to check truncation)
+        texts = ["What is life?", "What is life?" * 1000]
+        embeddings = await model.get_embeddings_async(texts)
+        assert len(embeddings) == 2
+        assert len(embeddings[0].values) == 768
+        assert embeddings[0].statistics.token_count > 0
+        assert not embeddings[0].statistics.truncated
+
+        assert len(embeddings[1].values) == 768
+        assert embeddings[1].statistics.token_count > 1000
+        assert embeddings[1].statistics.truncated
+
     def test_tuning(self, shared_state):
         """Test tuning, listing and loading models."""
         aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
```
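
These are live end-to-end tests; `@pytest.mark.asyncio` comes from the pytest-asyncio plugin, via the `import pytest` added at the top of the file. The chat case carries the real bookkeeping: each awaited `send_message_async` turn appends two entries to `message_history`, the user message and then the model reply. A condensed sketch of that invariant outside pytest, reusing the placeholder init values from the sketch above:

```python
import asyncio

from google.cloud import aiplatform
from vertexai.language_models import ChatModel


async def chat_demo() -> None:
    aiplatform.init(project="my-project", location="us-central1")  # placeholders

    chat = ChatModel.from_pretrained("chat-bison@001").start_chat(
        context="You are my personal assistant.",
    )
    questions = [
        "Are my favorite movies based on a book series?",
        "When were these books published?",
    ]
    for turn, question in enumerate(questions, start=1):
        response = await chat.send_message_async(question)
        print(response.text)
        # Each turn appends a (user, model) pair to the transcript.
        assert len(chat.message_history) == 2 * turn


asyncio.run(chat_demo())
```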

tests/unit/aiplatform/test_language_models.py (+47 -1)
```diff
@@ -40,7 +40,10 @@
     model_service_client,
     pipeline_service_client,
 )
-from google.cloud.aiplatform.compat.services import prediction_service_client
+from google.cloud.aiplatform.compat.services import (
+    prediction_service_client,
+    prediction_service_async_client,
+)
 from google.cloud.aiplatform.compat.types import (
     artifact as gca_artifact,
     prediction_service as gca_prediction_service,
@@ -1273,6 +1276,49 @@ def test_text_generation_ga(self):
         assert "topP" not in prediction_parameters
         assert "topK" not in prediction_parameters
 
+    @pytest.mark.asyncio
+    async def test_text_generation_async(self):
+        """Tests the text generation model."""
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _TEXT_BISON_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = language_models.TextGenerationModel.from_pretrained(
+                "text-bison@001"
+            )
+
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)
+
+        with mock.patch.object(
+            target=prediction_service_async_client.PredictionServiceAsyncClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ) as mock_predict:
+            response = await model.predict_async(
+                "What is the best recipe for banana bread? Recipe:",
+                max_output_tokens=128,
+                temperature=0.0,
+                top_p=1.0,
+                top_k=5,
+                stop_sequences=["\n"],
+            )
+
+        prediction_parameters = mock_predict.call_args[1]["parameters"]
+        assert prediction_parameters["maxDecodeSteps"] == 128
+        assert prediction_parameters["temperature"] == 0.0
+        assert prediction_parameters["topP"] == 1.0
+        assert prediction_parameters["topK"] == 5
+        assert prediction_parameters["stopSequences"] == ["\n"]
+        assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
+
     def test_text_generation_model_predict_streaming(self):
         """Tests the TextGenerationModel.predict_streaming method."""
         with mock.patch.object(
```
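
One detail worth noting: the test patches `PredictionServiceAsyncClient.predict` with a plain `return_value`, not a coroutine. That works because since Python 3.8, `mock.patch.object` substitutes an `AsyncMock` when the patched attribute is an `async def`, and awaiting an `AsyncMock` call yields its `return_value`. A self-contained sketch of that mechanism, with a hypothetical `FakeAsyncClient` standing in for the real client:

```python
import asyncio
from unittest import mock


class FakeAsyncClient:
    """Hypothetical stand-in for an async RPC client."""

    async def predict(self, request):
        raise RuntimeError("would hit the network")


async def main() -> None:
    with mock.patch.object(
        target=FakeAsyncClient,
        attribute="predict",
        return_value={"predictions": ["stubbed"]},
    ) as mock_predict:
        # patch.object sees an async def, installs an AsyncMock, and
        # awaiting the patched call yields return_value directly.
        result = await FakeAsyncClient().predict(request={"instances": []})

    assert result == {"predictions": ["stubbed"]}
    # Keyword arguments are inspectable, mirroring call_args[1] above.
    assert mock_predict.call_args[1]["request"] == {"instances": []}


asyncio.run(main())
```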
