53
53
model as gca_model ,
54
54
)
55
55
56
- from vertexai .preview import language_models
56
+ from vertexai .preview import (
57
+ language_models as preview_language_models ,
58
+ )
59
+ from vertexai import language_models
57
60
from google .cloud .aiplatform_v1 import Execution as GapicExecution
58
61
from google .cloud .aiplatform .compat .types import (
59
62
encryption_spec as gca_encryption_spec ,
@@ -456,7 +459,7 @@ def get_endpoint_mock():
456
459
@pytest .fixture
457
460
def mock_get_tuned_model (get_endpoint_mock ):
458
461
with mock .patch .object (
459
- language_models .TextGenerationModel , "get_tuned_model"
462
+ preview_language_models .TextGenerationModel , "get_tuned_model"
460
463
) as mock_text_generation_model :
461
464
mock_text_generation_model ._model_id = (
462
465
test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
@@ -519,6 +522,50 @@ def teardown_method(self):
519
522
initializer .global_pool .shutdown (wait = True )
520
523
521
524
def test_text_generation (self ):
525
+ """Tests the text generation model."""
526
+ aiplatform .init (
527
+ project = _TEST_PROJECT ,
528
+ location = _TEST_LOCATION ,
529
+ )
530
+ with mock .patch .object (
531
+ target = model_garden_service_client .ModelGardenServiceClient ,
532
+ attribute = "get_publisher_model" ,
533
+ return_value = gca_publisher_model .PublisherModel (
534
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
535
+ ),
536
+ ) as mock_get_publisher_model :
537
+ model = preview_language_models .TextGenerationModel .from_pretrained (
538
+ "text-bison@001"
539
+ )
540
+
541
+ mock_get_publisher_model .assert_called_once_with (
542
+ name = "publishers/google/models/text-bison@001" , retry = base ._DEFAULT_RETRY
543
+ )
544
+
545
+ assert (
546
+ model ._model_resource_name
547
+ == f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /publishers/google/models/text-bison@001"
548
+ )
549
+
550
+ gca_predict_response = gca_prediction_service .PredictResponse ()
551
+ gca_predict_response .predictions .append (_TEST_TEXT_GENERATION_PREDICTION )
552
+
553
+ with mock .patch .object (
554
+ target = prediction_service_client .PredictionServiceClient ,
555
+ attribute = "predict" ,
556
+ return_value = gca_predict_response ,
557
+ ):
558
+ response = model .predict (
559
+ "What is the best recipe for banana bread? Recipe:" ,
560
+ max_output_tokens = 128 ,
561
+ temperature = 0 ,
562
+ top_p = 1 ,
563
+ top_k = 5 ,
564
+ )
565
+
566
+ assert response .text == _TEST_TEXT_GENERATION_PREDICTION ["content" ]
567
+
568
+ def test_text_generation_ga (self ):
522
569
"""Tests the text generation model."""
523
570
aiplatform .init (
524
571
project = _TEST_PROJECT ,
@@ -596,7 +643,7 @@ def test_tune_model(
596
643
_TEXT_BISON_PUBLISHER_MODEL_DICT
597
644
),
598
645
):
599
- model = language_models .TextGenerationModel .from_pretrained (
646
+ model = preview_language_models .TextGenerationModel .from_pretrained (
600
647
"text-bison@001"
601
648
)
602
649
@@ -631,7 +678,7 @@ def test_get_tuned_model(
631
678
_TEXT_BISON_PUBLISHER_MODEL_DICT
632
679
),
633
680
):
634
- tuned_model = language_models .TextGenerationModel .get_tuned_model (
681
+ tuned_model = preview_language_models .TextGenerationModel .get_tuned_model (
635
682
test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
636
683
)
637
684
@@ -651,7 +698,7 @@ def get_tuned_model_raises_if_not_called_with_mg_model(self):
651
698
)
652
699
653
700
with pytest .raises (ValueError ):
654
- language_models .TextGenerationModel .get_tuned_model (
701
+ preview_language_models .TextGenerationModel .get_tuned_model (
655
702
test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
656
703
)
657
704
@@ -668,7 +715,7 @@ def test_chat(self):
668
715
_CHAT_BISON_PUBLISHER_MODEL_DICT
669
716
),
670
717
) as mock_get_publisher_model :
671
- model = language_models .ChatModel .from_pretrained ("chat-bison@001" )
718
+ model = preview_language_models .ChatModel .from_pretrained ("chat-bison@001" )
672
719
673
720
mock_get_publisher_model .assert_called_once_with (
674
721
name = "publishers/google/models/chat-bison@001" , retry = base ._DEFAULT_RETRY
@@ -681,11 +728,11 @@ def test_chat(self):
681
728
My favorite movies are Lord of the Rings and Hobbit.
682
729
""" ,
683
730
examples = [
684
- language_models .InputOutputTextPair (
731
+ preview_language_models .InputOutputTextPair (
685
732
input_text = "Who do you work for?" ,
686
733
output_text = "I work for Ned." ,
687
734
),
688
- language_models .InputOutputTextPair (
735
+ preview_language_models .InputOutputTextPair (
689
736
input_text = "What do I like?" ,
690
737
output_text = "Ned likes watching movies." ,
691
738
),
@@ -786,7 +833,7 @@ def test_code_chat(self):
786
833
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
787
834
),
788
835
) as mock_get_publisher_model :
789
- model = language_models .CodeChatModel .from_pretrained (
836
+ model = preview_language_models .CodeChatModel .from_pretrained (
790
837
"google/codechat-bison@001"
791
838
)
792
839
@@ -882,7 +929,7 @@ def test_code_generation(self):
882
929
_CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT
883
930
),
884
931
) as mock_get_publisher_model :
885
- model = language_models .CodeGenerationModel .from_pretrained (
932
+ model = preview_language_models .CodeGenerationModel .from_pretrained (
886
933
"google/code-bison@001"
887
934
)
888
935
@@ -909,9 +956,11 @@ def test_code_generation(self):
909
956
# Validating the parameters
910
957
predict_temperature = 0.1
911
958
predict_max_output_tokens = 100
912
- default_temperature = language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
959
+ default_temperature = (
960
+ preview_language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
961
+ )
913
962
default_max_output_tokens = (
914
- language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
963
+ preview_language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
915
964
)
916
965
917
966
with mock .patch .object (
@@ -948,7 +997,7 @@ def test_code_completion(self):
948
997
_CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
949
998
),
950
999
) as mock_get_publisher_model :
951
- model = language_models .CodeGenerationModel .from_pretrained (
1000
+ model = preview_language_models .CodeGenerationModel .from_pretrained (
952
1001
"google/code-gecko@001"
953
1002
)
954
1003
@@ -975,9 +1024,11 @@ def test_code_completion(self):
975
1024
# Validating the parameters
976
1025
predict_temperature = 0.1
977
1026
predict_max_output_tokens = 100
978
- default_temperature = language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
1027
+ default_temperature = (
1028
+ preview_language_models .CodeGenerationModel ._DEFAULT_TEMPERATURE
1029
+ )
979
1030
default_max_output_tokens = (
980
- language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
1031
+ preview_language_models .CodeGenerationModel ._DEFAULT_MAX_OUTPUT_TOKENS
981
1032
)
982
1033
983
1034
with mock .patch .object (
@@ -1002,6 +1053,43 @@ def test_code_completion(self):
1002
1053
assert prediction_parameters ["maxOutputTokens" ] == default_max_output_tokens
1003
1054
1004
1055
def test_text_embedding (self ):
1056
+ """Tests the text embedding model."""
1057
+ aiplatform .init (
1058
+ project = _TEST_PROJECT ,
1059
+ location = _TEST_LOCATION ,
1060
+ )
1061
+ with mock .patch .object (
1062
+ target = model_garden_service_client .ModelGardenServiceClient ,
1063
+ attribute = "get_publisher_model" ,
1064
+ return_value = gca_publisher_model .PublisherModel (
1065
+ _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
1066
+ ),
1067
+ ) as mock_get_publisher_model :
1068
+ model = preview_language_models .TextEmbeddingModel .from_pretrained (
1069
+ "textembedding-gecko@001"
1070
+ )
1071
+
1072
+ mock_get_publisher_model .assert_called_once_with (
1073
+ name = "publishers/google/models/textembedding-gecko@001" ,
1074
+ retry = base ._DEFAULT_RETRY ,
1075
+ )
1076
+
1077
+ gca_predict_response = gca_prediction_service .PredictResponse ()
1078
+ gca_predict_response .predictions .append (_TEST_TEXT_EMBEDDING_PREDICTION )
1079
+
1080
+ with mock .patch .object (
1081
+ target = prediction_service_client .PredictionServiceClient ,
1082
+ attribute = "predict" ,
1083
+ return_value = gca_predict_response ,
1084
+ ):
1085
+ embeddings = model .get_embeddings (["What is life?" ])
1086
+ assert embeddings
1087
+ for embedding in embeddings :
1088
+ vector = embedding .values
1089
+ assert len (vector ) == _TEXT_EMBEDDING_VECTOR_LENGTH
1090
+ assert vector == _TEST_TEXT_EMBEDDING_PREDICTION ["embeddings" ]["values" ]
1091
+
1092
+ def test_text_embedding_ga (self ):
1005
1093
"""Tests the text embedding model."""
1006
1094
aiplatform .init (
1007
1095
project = _TEST_PROJECT ,
0 commit comments