Skip to content

Commit d4667f2

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: enable grounding to ChatModel send_message and send_message_async methods
PiperOrigin-RevId: 579999652
1 parent eaf4420 commit d4667f2

File tree

3 files changed

+355
-19
lines changed

3 files changed

+355
-19
lines changed

tests/system/aiplatform/test_language_models.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ def test_preview_text_embedding_top_level_from_pretrained(self):
124124

125125
def test_chat_on_chat_model(self):
126126
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
127-
128127
chat_model = ChatModel.from_pretrained("google/chat-bison@001")
128+
grounding_source = language_models.GroundingSource.WebSearch()
129129
chat = chat_model.start_chat(
130130
context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
131131
examples=[
@@ -143,19 +143,23 @@ def test_chat_on_chat_model(self):
143143
)
144144

145145
message1 = "Are my favorite movies based on a book series?"
146-
response1 = chat.send_message(message1)
146+
response1 = chat.send_message(
147+
message1,
148+
grounding_source=grounding_source,
149+
)
147150
assert response1.text
151+
assert response1.grounding_metadata
148152
assert len(chat.message_history) == 2
149153
assert chat.message_history[0].author == chat.USER_AUTHOR
150154
assert chat.message_history[0].content == message1
151155
assert chat.message_history[1].author == chat.MODEL_AUTHOR
152156

153157
message2 = "When were these books published?"
154158
response2 = chat.send_message(
155-
message2,
156-
temperature=0.1,
159+
message2, temperature=0.1, grounding_source=grounding_source
157160
)
158161
assert response2.text
162+
assert response2.grounding_metadata
159163
assert len(chat.message_history) == 4
160164
assert chat.message_history[2].author == chat.USER_AUTHOR
161165
assert chat.message_history[2].content == message2
@@ -189,6 +193,7 @@ async def test_chat_model_async(self):
189193
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
190194

191195
chat_model = ChatModel.from_pretrained("google/chat-bison@001")
196+
grounding_source = language_models.GroundingSource.WebSearch()
192197
chat = chat_model.start_chat(
193198
context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
194199
examples=[
@@ -206,8 +211,12 @@ async def test_chat_model_async(self):
206211
)
207212

208213
message1 = "Are my favorite movies based on a book series?"
209-
response1 = await chat.send_message_async(message1)
214+
response1 = await chat.send_message_async(
215+
message1,
216+
grounding_source=grounding_source,
217+
)
210218
assert response1.text
219+
assert response1.grounding_metadata
211220
assert len(chat.message_history) == 2
212221
assert chat.message_history[0].author == chat.USER_AUTHOR
213222
assert chat.message_history[0].content == message1
@@ -217,8 +226,10 @@ async def test_chat_model_async(self):
217226
response2 = await chat.send_message_async(
218227
message2,
219228
temperature=0.1,
229+
grounding_source=grounding_source,
220230
)
221231
assert response2.text
232+
assert response2.grounding_metadata
222233
assert len(chat.message_history) == 4
223234
assert chat.message_history[2].author == chat.USER_AUTHOR
224235
assert chat.message_history[2].content == message2

tests/unit/aiplatform/test_language_models.py

+306
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,97 @@
311311
],
312312
}
313313

314+
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING = {
315+
"safetyAttributes": [
316+
{
317+
"scores": [],
318+
"categories": [],
319+
"blocked": False,
320+
},
321+
{
322+
"scores": [0.1],
323+
"categories": ["Finance"],
324+
"blocked": True,
325+
},
326+
],
327+
"groundingMetadata": [
328+
{
329+
"citations": [
330+
{
331+
"startIndex": 1,
332+
"endIndex": 2,
333+
"url": "url1",
334+
}
335+
]
336+
},
337+
{
338+
"citations": [
339+
{
340+
"startIndex": 3,
341+
"endIndex": 4,
342+
"url": "url2",
343+
}
344+
]
345+
},
346+
],
347+
"candidates": [
348+
{
349+
"author": "1",
350+
"content": "Chat response 2",
351+
},
352+
{
353+
"author": "1",
354+
"content": "",
355+
},
356+
],
357+
}
358+
359+
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE = {
360+
"safetyAttributes": [
361+
{
362+
"scores": [],
363+
"categories": [],
364+
"blocked": False,
365+
},
366+
{
367+
"scores": [0.1],
368+
"categories": ["Finance"],
369+
"blocked": True,
370+
},
371+
],
372+
"groundingMetadata": [
373+
None,
374+
None,
375+
],
376+
"candidates": [
377+
{
378+
"author": "1",
379+
"content": "Chat response 2",
380+
},
381+
{
382+
"author": "1",
383+
"content": "",
384+
},
385+
],
386+
}
387+
388+
_EXPECTED_PARSED_GROUNDING_METADATA_CHAT = {
389+
"citations": [
390+
{
391+
"url": "url1",
392+
"start_index": 1,
393+
"end_index": 2,
394+
"title": None,
395+
"license": None,
396+
"publication_date": None,
397+
},
398+
],
399+
}
400+
401+
_EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE = {
402+
"citations": [],
403+
}
404+
314405
_TEST_CHAT_PREDICTION_STREAMING = [
315406
{
316407
"candidates": [
@@ -2312,6 +2403,221 @@ def test_chat(self):
23122403
assert prediction_parameters["topK"] == message_top_k
23132404
assert prediction_parameters["topP"] == message_top_p
23142405

2406+
gca_predict_response4 = gca_prediction_service.PredictResponse()
2407+
gca_predict_response4.predictions.append(
2408+
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING
2409+
)
2410+
test_grounding_sources = [
2411+
_TEST_GROUNDING_WEB_SEARCH,
2412+
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
2413+
]
2414+
datastore_path = (
2415+
"projects/test-project/locations/global/"
2416+
"collections/default_collection/dataStores/test_datastore"
2417+
)
2418+
expected_grounding_sources = [
2419+
{"sources": [{"type": "WEB"}]},
2420+
{
2421+
"sources": [
2422+
{
2423+
"type": "ENTERPRISE",
2424+
"enterpriseDatastore": datastore_path,
2425+
}
2426+
]
2427+
},
2428+
]
2429+
for test_grounding_source, expected_grounding_source in zip(
2430+
test_grounding_sources, expected_grounding_sources
2431+
):
2432+
with mock.patch.object(
2433+
target=prediction_service_client.PredictionServiceClient,
2434+
attribute="predict",
2435+
return_value=gca_predict_response4,
2436+
) as mock_predict4:
2437+
response = chat2.send_message(
2438+
"Are my favorite movies based on a book series?",
2439+
grounding_source=test_grounding_source,
2440+
)
2441+
prediction_parameters = mock_predict4.call_args[1]["parameters"]
2442+
assert (
2443+
prediction_parameters["groundingConfig"]
2444+
== expected_grounding_source
2445+
)
2446+
assert (
2447+
dataclasses.asdict(response.grounding_metadata)
2448+
== _EXPECTED_PARSED_GROUNDING_METADATA_CHAT
2449+
)
2450+
2451+
gca_predict_response5 = gca_prediction_service.PredictResponse()
2452+
gca_predict_response5.predictions.append(
2453+
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE
2454+
)
2455+
test_grounding_sources = [
2456+
_TEST_GROUNDING_WEB_SEARCH,
2457+
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
2458+
]
2459+
datastore_path = (
2460+
"projects/test-project/locations/global/"
2461+
"collections/default_collection/dataStores/test_datastore"
2462+
)
2463+
expected_grounding_sources = [
2464+
{"sources": [{"type": "WEB"}]},
2465+
{
2466+
"sources": [
2467+
{
2468+
"type": "ENTERPRISE",
2469+
"enterpriseDatastore": datastore_path,
2470+
}
2471+
]
2472+
},
2473+
]
2474+
for test_grounding_source, expected_grounding_source in zip(
2475+
test_grounding_sources, expected_grounding_sources
2476+
):
2477+
with mock.patch.object(
2478+
target=prediction_service_client.PredictionServiceClient,
2479+
attribute="predict",
2480+
return_value=gca_predict_response5,
2481+
) as mock_predict5:
2482+
response = chat2.send_message(
2483+
"Are my favorite movies based on a book series?",
2484+
grounding_source=test_grounding_source,
2485+
)
2486+
prediction_parameters = mock_predict5.call_args[1]["parameters"]
2487+
assert (
2488+
prediction_parameters["groundingConfig"]
2489+
== expected_grounding_source
2490+
)
2491+
assert (
2492+
dataclasses.asdict(response.grounding_metadata)
2493+
== _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE
2494+
)
2495+
2496+
@pytest.mark.asyncio
2497+
async def test_chat_async(self):
2498+
"""Test the chat generation model async api."""
2499+
aiplatform.init(
2500+
project=_TEST_PROJECT,
2501+
location=_TEST_LOCATION,
2502+
)
2503+
with mock.patch.object(
2504+
target=model_garden_service_client.ModelGardenServiceClient,
2505+
attribute="get_publisher_model",
2506+
return_value=gca_publisher_model.PublisherModel(
2507+
_CHAT_BISON_PUBLISHER_MODEL_DICT
2508+
),
2509+
) as mock_get_publisher_model:
2510+
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
2511+
2512+
mock_get_publisher_model.assert_called_once_with(
2513+
name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
2514+
)
2515+
chat_temperature = 0.1
2516+
chat_max_output_tokens = 100
2517+
chat_top_k = 1
2518+
chat_top_p = 0.1
2519+
2520+
chat = model.start_chat(
2521+
temperature=chat_temperature,
2522+
max_output_tokens=chat_max_output_tokens,
2523+
top_k=chat_top_k,
2524+
top_p=chat_top_p,
2525+
)
2526+
2527+
gca_predict_response6 = gca_prediction_service.PredictResponse()
2528+
gca_predict_response6.predictions.append(
2529+
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING
2530+
)
2531+
test_grounding_sources = [
2532+
_TEST_GROUNDING_WEB_SEARCH,
2533+
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
2534+
]
2535+
datastore_path = (
2536+
"projects/test-project/locations/global/"
2537+
"collections/default_collection/dataStores/test_datastore"
2538+
)
2539+
expected_grounding_sources = [
2540+
{"sources": [{"type": "WEB"}]},
2541+
{
2542+
"sources": [
2543+
{
2544+
"type": "ENTERPRISE",
2545+
"enterpriseDatastore": datastore_path,
2546+
}
2547+
]
2548+
},
2549+
]
2550+
for test_grounding_source, expected_grounding_source in zip(
2551+
test_grounding_sources, expected_grounding_sources
2552+
):
2553+
with mock.patch.object(
2554+
target=prediction_service_async_client.PredictionServiceAsyncClient,
2555+
attribute="predict",
2556+
return_value=gca_predict_response6,
2557+
) as mock_predict6:
2558+
response = await chat.send_message_async(
2559+
"Are my favorite movies based on a book series?",
2560+
grounding_source=test_grounding_source,
2561+
)
2562+
prediction_parameters = mock_predict6.call_args[1]["parameters"]
2563+
assert prediction_parameters["temperature"] == chat_temperature
2564+
assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
2565+
assert prediction_parameters["topK"] == chat_top_k
2566+
assert prediction_parameters["topP"] == chat_top_p
2567+
assert (
2568+
prediction_parameters["groundingConfig"]
2569+
== expected_grounding_source
2570+
)
2571+
assert (
2572+
dataclasses.asdict(response.grounding_metadata)
2573+
== _EXPECTED_PARSED_GROUNDING_METADATA_CHAT
2574+
)
2575+
2576+
gca_predict_response7 = gca_prediction_service.PredictResponse()
2577+
gca_predict_response7.predictions.append(
2578+
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE
2579+
)
2580+
test_grounding_sources = [
2581+
_TEST_GROUNDING_WEB_SEARCH,
2582+
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
2583+
]
2584+
datastore_path = (
2585+
"projects/test-project/locations/global/"
2586+
"collections/default_collection/dataStores/test_datastore"
2587+
)
2588+
expected_grounding_sources = [
2589+
{"sources": [{"type": "WEB"}]},
2590+
{
2591+
"sources": [
2592+
{
2593+
"type": "ENTERPRISE",
2594+
"enterpriseDatastore": datastore_path,
2595+
}
2596+
]
2597+
},
2598+
]
2599+
for test_grounding_source, expected_grounding_source in zip(
2600+
test_grounding_sources, expected_grounding_sources
2601+
):
2602+
with mock.patch.object(
2603+
target=prediction_service_async_client.PredictionServiceAsyncClient,
2604+
attribute="predict",
2605+
return_value=gca_predict_response7,
2606+
) as mock_predict7:
2607+
response = await chat.send_message_async(
2608+
"Are my favorite movies based on a book series?",
2609+
grounding_source=test_grounding_source,
2610+
)
2611+
prediction_parameters = mock_predict7.call_args[1]["parameters"]
2612+
assert (
2613+
prediction_parameters["groundingConfig"]
2614+
== expected_grounding_source
2615+
)
2616+
assert (
2617+
dataclasses.asdict(response.grounding_metadata)
2618+
== _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE
2619+
)
2620+
23152621
def test_chat_ga(self):
23162622
"""Tests the chat generation model."""
23172623
aiplatform.init(

0 commit comments

Comments
 (0)