|
88 | 88 | "name": "publishers/google/models/chat-bison",
|
89 | 89 | "version_id": "001",
|
90 | 90 | "open_source_category": "PROPRIETARY",
|
91 |
| - "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW, |
| 91 | + "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA, |
92 | 92 | "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
|
93 | 93 | "predict_schemata": {
|
94 | 94 | "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml",
|
@@ -792,6 +792,139 @@ def test_chat(self):
|
792 | 792 | gca_predict_response2 = gca_prediction_service.PredictResponse()
|
793 | 793 | gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
|
794 | 794 |
|
| 795 | + with mock.patch.object( |
| 796 | + target=prediction_service_client.PredictionServiceClient, |
| 797 | + attribute="predict", |
| 798 | + return_value=gca_predict_response2, |
| 799 | + ): |
| 800 | + message_text2 = "When were these books published?" |
| 801 | + expected_response2 = _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0][ |
| 802 | + "content" |
| 803 | + ] |
| 804 | + response = chat.send_message(message_text2, temperature=0.1) |
| 805 | + assert response.text == expected_response2 |
| 806 | + assert len(chat.message_history) == 6 |
| 807 | + assert chat.message_history[4].author == chat.USER_AUTHOR |
| 808 | + assert chat.message_history[4].content == message_text2 |
| 809 | + assert chat.message_history[5].author == chat.MODEL_AUTHOR |
| 810 | + assert chat.message_history[5].content == expected_response2 |
| 811 | + |
| 812 | + # Validating the parameters |
| 813 | + chat_temperature = 0.1 |
| 814 | + chat_max_output_tokens = 100 |
| 815 | + chat_top_k = 1 |
| 816 | + chat_top_p = 0.1 |
| 817 | + message_temperature = 0.2 |
| 818 | + message_max_output_tokens = 200 |
| 819 | + message_top_k = 2 |
| 820 | + message_top_p = 0.2 |
| 821 | + |
| 822 | + chat2 = model.start_chat( |
| 823 | + temperature=chat_temperature, |
| 824 | + max_output_tokens=chat_max_output_tokens, |
| 825 | + top_k=chat_top_k, |
| 826 | + top_p=chat_top_p, |
| 827 | + ) |
| 828 | + |
| 829 | + gca_predict_response3 = gca_prediction_service.PredictResponse() |
| 830 | + gca_predict_response3.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1) |
| 831 | + |
| 832 | + with mock.patch.object( |
| 833 | + target=prediction_service_client.PredictionServiceClient, |
| 834 | + attribute="predict", |
| 835 | + return_value=gca_predict_response3, |
| 836 | + ) as mock_predict3: |
| 837 | + chat2.send_message("Are my favorite movies based on a book series?") |
| 838 | + prediction_parameters = mock_predict3.call_args[1]["parameters"] |
| 839 | + assert prediction_parameters["temperature"] == chat_temperature |
| 840 | + assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens |
| 841 | + assert prediction_parameters["topK"] == chat_top_k |
| 842 | + assert prediction_parameters["topP"] == chat_top_p |
| 843 | + |
| 844 | + chat2.send_message( |
| 845 | + "Are my favorite movies based on a book series?", |
| 846 | + temperature=message_temperature, |
| 847 | + max_output_tokens=message_max_output_tokens, |
| 848 | + top_k=message_top_k, |
| 849 | + top_p=message_top_p, |
| 850 | + ) |
| 851 | + prediction_parameters = mock_predict3.call_args[1]["parameters"] |
| 852 | + assert prediction_parameters["temperature"] == message_temperature |
| 853 | + assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens |
| 854 | + assert prediction_parameters["topK"] == message_top_k |
| 855 | + assert prediction_parameters["topP"] == message_top_p |
| 856 | + |
| 857 | + def test_chat_ga(self): |
| 858 | + """Tests the chat generation model.""" |
| 859 | + aiplatform.init( |
| 860 | + project=_TEST_PROJECT, |
| 861 | + location=_TEST_LOCATION, |
| 862 | + ) |
| 863 | + with mock.patch.object( |
| 864 | + target=model_garden_service_client.ModelGardenServiceClient, |
| 865 | + attribute="get_publisher_model", |
| 866 | + return_value=gca_publisher_model.PublisherModel( |
| 867 | + _CHAT_BISON_PUBLISHER_MODEL_DICT |
| 868 | + ), |
| 869 | + ) as mock_get_publisher_model: |
| 870 | + model = language_models.ChatModel.from_pretrained("chat-bison@001") |
| 871 | + |
| 872 | + mock_get_publisher_model.assert_called_once_with( |
| 873 | + name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY |
| 874 | + ) |
| 875 | + |
| 876 | + chat = model.start_chat( |
| 877 | + context=""" |
| 878 | + My name is Ned. |
| 879 | + You are my personal assistant. |
| 880 | + My favorite movies are Lord of the Rings and Hobbit. |
| 881 | + """, |
| 882 | + examples=[ |
| 883 | + language_models.InputOutputTextPair( |
| 884 | + input_text="Who do you work for?", |
| 885 | + output_text="I work for Ned.", |
| 886 | + ), |
| 887 | + language_models.InputOutputTextPair( |
| 888 | + input_text="What do I like?", |
| 889 | + output_text="Ned likes watching movies.", |
| 890 | + ), |
| 891 | + ], |
| 892 | + message_history=[ |
| 893 | + language_models.ChatMessage( |
| 894 | + author=preview_language_models.ChatSession.USER_AUTHOR, |
| 895 | + content="Question 1?", |
| 896 | + ), |
| 897 | + language_models.ChatMessage( |
| 898 | + author=preview_language_models.ChatSession.MODEL_AUTHOR, |
| 899 | + content="Answer 1.", |
| 900 | + ), |
| 901 | + ], |
| 902 | + temperature=0.0, |
| 903 | + ) |
| 904 | + |
| 905 | + gca_predict_response1 = gca_prediction_service.PredictResponse() |
| 906 | + gca_predict_response1.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1) |
| 907 | + |
| 908 | + with mock.patch.object( |
| 909 | + target=prediction_service_client.PredictionServiceClient, |
| 910 | + attribute="predict", |
| 911 | + return_value=gca_predict_response1, |
| 912 | + ): |
| 913 | + message_text1 = "Are my favorite movies based on a book series?" |
| 914 | + expected_response1 = _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0][ |
| 915 | + "content" |
| 916 | + ] |
| 917 | + response = chat.send_message(message_text1) |
| 918 | + assert response.text == expected_response1 |
| 919 | + assert len(chat.message_history) == 4 |
| 920 | + assert chat.message_history[2].author == chat.USER_AUTHOR |
| 921 | + assert chat.message_history[2].content == message_text1 |
| 922 | + assert chat.message_history[3].author == chat.MODEL_AUTHOR |
| 923 | + assert chat.message_history[3].content == expected_response1 |
| 924 | + |
| 925 | + gca_predict_response2 = gca_prediction_service.PredictResponse() |
| 926 | + gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2) |
| 927 | + |
795 | 928 | with mock.patch.object(
|
796 | 929 | target=prediction_service_client.PredictionServiceClient,
|
797 | 930 | attribute="predict",
|
|
0 commit comments