|
311 | 311 | ],
|
312 | 312 | }
|
313 | 313 |
|
| 314 | +_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING = { |
| 315 | + "safetyAttributes": [ |
| 316 | + { |
| 317 | + "scores": [], |
| 318 | + "categories": [], |
| 319 | + "blocked": False, |
| 320 | + }, |
| 321 | + { |
| 322 | + "scores": [0.1], |
| 323 | + "categories": ["Finance"], |
| 324 | + "blocked": True, |
| 325 | + }, |
| 326 | + ], |
| 327 | + "groundingMetadata": [ |
| 328 | + { |
| 329 | + "citations": [ |
| 330 | + { |
| 331 | + "startIndex": 1, |
| 332 | + "endIndex": 2, |
| 333 | + "url": "url1", |
| 334 | + } |
| 335 | + ] |
| 336 | + }, |
| 337 | + { |
| 338 | + "citations": [ |
| 339 | + { |
| 340 | + "startIndex": 3, |
| 341 | + "endIndex": 4, |
| 342 | + "url": "url2", |
| 343 | + } |
| 344 | + ] |
| 345 | + }, |
| 346 | + ], |
| 347 | + "candidates": [ |
| 348 | + { |
| 349 | + "author": "1", |
| 350 | + "content": "Chat response 2", |
| 351 | + }, |
| 352 | + { |
| 353 | + "author": "1", |
| 354 | + "content": "", |
| 355 | + }, |
| 356 | + ], |
| 357 | +} |
| 358 | + |
| 359 | +_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE = { |
| 360 | + "safetyAttributes": [ |
| 361 | + { |
| 362 | + "scores": [], |
| 363 | + "categories": [], |
| 364 | + "blocked": False, |
| 365 | + }, |
| 366 | + { |
| 367 | + "scores": [0.1], |
| 368 | + "categories": ["Finance"], |
| 369 | + "blocked": True, |
| 370 | + }, |
| 371 | + ], |
| 372 | + "groundingMetadata": [ |
| 373 | + None, |
| 374 | + None, |
| 375 | + ], |
| 376 | + "candidates": [ |
| 377 | + { |
| 378 | + "author": "1", |
| 379 | + "content": "Chat response 2", |
| 380 | + }, |
| 381 | + { |
| 382 | + "author": "1", |
| 383 | + "content": "", |
| 384 | + }, |
| 385 | + ], |
| 386 | +} |
| 387 | + |
| 388 | +_EXPECTED_PARSED_GROUNDING_METADATA_CHAT = { |
| 389 | + "citations": [ |
| 390 | + { |
| 391 | + "url": "url1", |
| 392 | + "start_index": 1, |
| 393 | + "end_index": 2, |
| 394 | + "title": None, |
| 395 | + "license": None, |
| 396 | + "publication_date": None, |
| 397 | + }, |
| 398 | + ], |
| 399 | +} |
| 400 | + |
| 401 | +_EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE = { |
| 402 | + "citations": [], |
| 403 | +} |
| 404 | + |
314 | 405 | _TEST_CHAT_PREDICTION_STREAMING = [
|
315 | 406 | {
|
316 | 407 | "candidates": [
|
@@ -2312,6 +2403,221 @@ def test_chat(self):
|
2312 | 2403 | assert prediction_parameters["topK"] == message_top_k
|
2313 | 2404 | assert prediction_parameters["topP"] == message_top_p
|
2314 | 2405 |
|
| 2406 | + gca_predict_response4 = gca_prediction_service.PredictResponse() |
| 2407 | + gca_predict_response4.predictions.append( |
| 2408 | + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING |
| 2409 | + ) |
| 2410 | + test_grounding_sources = [ |
| 2411 | + _TEST_GROUNDING_WEB_SEARCH, |
| 2412 | + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, |
| 2413 | + ] |
| 2414 | + datastore_path = ( |
| 2415 | + "projects/test-project/locations/global/" |
| 2416 | + "collections/default_collection/dataStores/test_datastore" |
| 2417 | + ) |
| 2418 | + expected_grounding_sources = [ |
| 2419 | + {"sources": [{"type": "WEB"}]}, |
| 2420 | + { |
| 2421 | + "sources": [ |
| 2422 | + { |
| 2423 | + "type": "ENTERPRISE", |
| 2424 | + "enterpriseDatastore": datastore_path, |
| 2425 | + } |
| 2426 | + ] |
| 2427 | + }, |
| 2428 | + ] |
| 2429 | + for test_grounding_source, expected_grounding_source in zip( |
| 2430 | + test_grounding_sources, expected_grounding_sources |
| 2431 | + ): |
| 2432 | + with mock.patch.object( |
| 2433 | + target=prediction_service_client.PredictionServiceClient, |
| 2434 | + attribute="predict", |
| 2435 | + return_value=gca_predict_response4, |
| 2436 | + ) as mock_predict4: |
| 2437 | + response = chat2.send_message( |
| 2438 | + "Are my favorite movies based on a book series?", |
| 2439 | + grounding_source=test_grounding_source, |
| 2440 | + ) |
| 2441 | + prediction_parameters = mock_predict4.call_args[1]["parameters"] |
| 2442 | + assert ( |
| 2443 | + prediction_parameters["groundingConfig"] |
| 2444 | + == expected_grounding_source |
| 2445 | + ) |
| 2446 | + assert ( |
| 2447 | + dataclasses.asdict(response.grounding_metadata) |
| 2448 | + == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT |
| 2449 | + ) |
| 2450 | + |
| 2451 | + gca_predict_response5 = gca_prediction_service.PredictResponse() |
| 2452 | + gca_predict_response5.predictions.append( |
| 2453 | + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE |
| 2454 | + ) |
| 2455 | + test_grounding_sources = [ |
| 2456 | + _TEST_GROUNDING_WEB_SEARCH, |
| 2457 | + _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE, |
| 2458 | + ] |
| 2459 | + datastore_path = ( |
| 2460 | + "projects/test-project/locations/global/" |
| 2461 | + "collections/default_collection/dataStores/test_datastore" |
| 2462 | + ) |
| 2463 | + expected_grounding_sources = [ |
| 2464 | + {"sources": [{"type": "WEB"}]}, |
| 2465 | + { |
| 2466 | + "sources": [ |
| 2467 | + { |
| 2468 | + "type": "ENTERPRISE", |
| 2469 | + "enterpriseDatastore": datastore_path, |
| 2470 | + } |
| 2471 | + ] |
| 2472 | + }, |
| 2473 | + ] |
| 2474 | + for test_grounding_source, expected_grounding_source in zip( |
| 2475 | + test_grounding_sources, expected_grounding_sources |
| 2476 | + ): |
| 2477 | + with mock.patch.object( |
| 2478 | + target=prediction_service_client.PredictionServiceClient, |
| 2479 | + attribute="predict", |
| 2480 | + return_value=gca_predict_response5, |
| 2481 | + ) as mock_predict5: |
| 2482 | + response = chat2.send_message( |
| 2483 | + "Are my favorite movies based on a book series?", |
| 2484 | + grounding_source=test_grounding_source, |
| 2485 | + ) |
| 2486 | + prediction_parameters = mock_predict5.call_args[1]["parameters"] |
| 2487 | + assert ( |
| 2488 | + prediction_parameters["groundingConfig"] |
| 2489 | + == expected_grounding_source |
| 2490 | + ) |
| 2491 | + assert ( |
| 2492 | + dataclasses.asdict(response.grounding_metadata) |
| 2493 | + == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE |
| 2494 | + ) |
| 2495 | + |
    @pytest.mark.asyncio
    async def test_chat_async(self):
        """Test the chat generation model async api."""
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
        )
        # Stub the model-garden lookup so from_pretrained() resolves the
        # publisher model without any real service call.
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _CHAT_BISON_PUBLISHER_MODEL_DICT
            ),
        ) as mock_get_publisher_model:
            model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")

        mock_get_publisher_model.assert_called_once_with(
            name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
        )
        # Session-level generation parameters; asserted below to be forwarded
        # verbatim on every predict request issued by the chat session.
        chat_temperature = 0.1
        chat_max_output_tokens = 100
        chat_top_k = 1
        chat_top_p = 0.1

        chat = model.start_chat(
            temperature=chat_temperature,
            max_output_tokens=chat_max_output_tokens,
            top_k=chat_top_k,
            top_p=chat_top_p,
        )

        # Case 1: the prediction carries grounding metadata with citations.
        gca_predict_response6 = gca_prediction_service.PredictResponse()
        gca_predict_response6.predictions.append(
            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING
        )
        test_grounding_sources = [
            _TEST_GROUNDING_WEB_SEARCH,
            _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
        ]
        datastore_path = (
            "projects/test-project/locations/global/"
            "collections/default_collection/dataStores/test_datastore"
        )
        # Wire-format "groundingConfig" expected for each grounding source,
        # in the same order as test_grounding_sources.
        expected_grounding_sources = [
            {"sources": [{"type": "WEB"}]},
            {
                "sources": [
                    {
                        "type": "ENTERPRISE",
                        "enterpriseDatastore": datastore_path,
                    }
                ]
            },
        ]
        for test_grounding_source, expected_grounding_source in zip(
            test_grounding_sources, expected_grounding_sources
        ):
            # Patch the async predict RPC to return the canned grounded payload.
            with mock.patch.object(
                target=prediction_service_async_client.PredictionServiceAsyncClient,
                attribute="predict",
                return_value=gca_predict_response6,
            ) as mock_predict6:
                response = await chat.send_message_async(
                    "Are my favorite movies based on a book series?",
                    grounding_source=test_grounding_source,
                )
                prediction_parameters = mock_predict6.call_args[1]["parameters"]
                # start_chat() parameters must reach the request unchanged
                # (max_output_tokens is sent as "maxDecodeSteps").
                assert prediction_parameters["temperature"] == chat_temperature
                assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
                assert prediction_parameters["topK"] == chat_top_k
                assert prediction_parameters["topP"] == chat_top_p
                assert (
                    prediction_parameters["groundingConfig"]
                    == expected_grounding_source
                )
                # Raw groundingMetadata is parsed into the snake_case dataclass.
                assert (
                    dataclasses.asdict(response.grounding_metadata)
                    == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT
                )

        # Case 2: groundingMetadata entries are None -> parsed citations empty.
        gca_predict_response7 = gca_prediction_service.PredictResponse()
        gca_predict_response7.predictions.append(
            _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE
        )
        test_grounding_sources = [
            _TEST_GROUNDING_WEB_SEARCH,
            _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
        ]
        datastore_path = (
            "projects/test-project/locations/global/"
            "collections/default_collection/dataStores/test_datastore"
        )
        expected_grounding_sources = [
            {"sources": [{"type": "WEB"}]},
            {
                "sources": [
                    {
                        "type": "ENTERPRISE",
                        "enterpriseDatastore": datastore_path,
                    }
                ]
            },
        ]
        for test_grounding_source, expected_grounding_source in zip(
            test_grounding_sources, expected_grounding_sources
        ):
            with mock.patch.object(
                target=prediction_service_async_client.PredictionServiceAsyncClient,
                attribute="predict",
                return_value=gca_predict_response7,
            ) as mock_predict7:
                response = await chat.send_message_async(
                    "Are my favorite movies based on a book series?",
                    grounding_source=test_grounding_source,
                )
                prediction_parameters = mock_predict7.call_args[1]["parameters"]
                assert (
                    prediction_parameters["groundingConfig"]
                    == expected_grounding_source
                )
                # None metadata still yields a metadata object, with no citations.
                assert (
                    dataclasses.asdict(response.grounding_metadata)
                    == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE
                )

2315 | 2621 | def test_chat_ga(self):
|
2316 | 2622 | """Tests the chat generation model."""
|
2317 | 2623 | aiplatform.init(
|
|
0 commit comments