@@ -92,13 +92,15 @@ def _model_resource_name(self) -> str:
 @dataclasses.dataclass
 class _PredictionRequest:
     """A single-instance prediction request."""
+
     instance: Dict[str, Any]
     parameters: Optional[Dict[str, Any]] = None
 
 
 @dataclasses.dataclass
 class _MultiInstancePredictionRequest:
     """A multi-instance prediction request."""
+
     instances: List[Dict[str, Any]]
     parameters: Optional[Dict[str, Any]] = None
 
@@ -573,6 +575,62 @@ def tune_model(
         return job
 
 
+@dataclasses.dataclass
+class CountTokensResponse:
+    """The response from a count_tokens request.
+    Attributes:
+        total_tokens (int):
+            The total number of tokens counted across all
+            instances passed to the request.
+        total_billable_characters (int):
+            The total number of billable characters
+            counted across all instances from the request.
+    """
+
+    total_tokens: int
+    total_billable_characters: int
+    _count_tokens_response: Any
+
+
+class _CountTokensMixin(_LanguageModel):
+    """Mixin for models that support the CountTokens API."""
+
+    def count_tokens(
+        self,
+        prompts: List[str],
+    ) -> CountTokensResponse:
+        """Counts the tokens and billable characters for the given prompts.
+
+        Note: this does not make a request to the model; it only counts the tokens
+        in the request.
+
+        Args:
+            prompts (List[str]):
+                Required. A list of prompts to ask the model. For example: ["What should I do today?", "How's it going?"]
+
+        Returns:
+            A `CountTokensResponse` object that contains the number of tokens
+            in the text and the number of billable characters.
+        """
+        instances = []
+
+        for prompt in prompts:
+            instances.append({"content": prompt})
+
+        count_tokens_response = self._endpoint._prediction_client.select_version(
+            "v1beta1"
+        ).count_tokens(
+            endpoint=self._endpoint_name,
+            instances=instances,
+        )
+
+        return CountTokensResponse(
+            total_tokens=count_tokens_response.total_tokens,
+            total_billable_characters=count_tokens_response.total_billable_characters,
+            _count_tokens_response=count_tokens_response,
+        )
+
+
 @dataclasses.dataclass
 class TuningEvaluationSpec:
     """Specification for model evaluation to perform during tuning.
@@ -587,6 +645,7 @@ class TuningEvaluationSpec:
         tensorboard: Vertex Tensorboard where to write the evaluation metrics.
             The Tensorboard must be in the same location as the tuning job.
     """
+
     __module__ = "vertexai.language_models"
 
     evaluation_data: str
@@ -605,6 +664,7 @@ class TextGenerationResponse:
         Learn more about the safety attributes here:
         https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
     """
+
     __module__ = "vertexai.language_models"
 
     text: str
@@ -761,7 +821,9 @@ def predict_streaming(
         )
 
         prediction_service_client = self._endpoint._prediction_client
-        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+        for (
+            prediction_dict
+        ) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._endpoint_name,
             instance=prediction_request.instance,
@@ -955,6 +1017,7 @@ class _PreviewTextGenerationModel(
     _PreviewTunableTextModelMixin,
     _PreviewModelWithBatchPredict,
     _evaluatable_language_models._EvaluatableLanguageModel,
+    _CountTokensMixin,
 ):
     # Do not add docstring so that it's inherited from the base class.
     __name__ = "TextGenerationModel"
@@ -1094,6 +1157,7 @@ class TextEmbeddingInput:
             Specifies that the embeddings will be used for clustering.
         title: Optional identifier of the text content.
     """
+
     __module__ = "vertexai.language_models"
 
     text: str
@@ -1113,6 +1177,7 @@ class TextEmbeddingModel(_LanguageModel):
         vector = embedding.values
         print(len(vector))
     """
+
     __module__ = "vertexai.language_models"
 
     _LAUNCH_STAGE = _model_garden_models._SDK_GA_LAUNCH_STAGE
@@ -1173,7 +1238,8 @@ def _parse_text_embedding_response(
             _prediction_response=prediction_response,
         )
 
-    def get_embeddings(self,
+    def get_embeddings(
+        self,
         texts: List[Union[str, TextEmbeddingInput]],
         *,
         auto_truncate: bool = True,
@@ -1207,7 +1273,8 @@ def get_embeddings(self,
 
         return results
 
-    async def get_embeddings_async(self,
+    async def get_embeddings_async(
+        self,
         texts: List[Union[str, TextEmbeddingInput]],
         *,
         auto_truncate: bool = True,
@@ -1242,7 +1309,9 @@ async def get_embeddings_async(self,
         return results
 
 
-class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
+class _PreviewTextEmbeddingModel(
+    TextEmbeddingModel, _ModelWithBatchPredict, _CountTokensMixin
+):
     __name__ = "TextEmbeddingModel"
     __module__ = "vertexai.preview.language_models"
 
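With `_CountTokensMixin` added to `_PreviewTextEmbeddingModel` above, token counting also becomes available on the preview embedding model. A minimal sketch, assuming the `textembedding-gecko@001` model and an already-initialized SDK:

from vertexai.preview.language_models import TextEmbeddingModel

embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

# Same mixin method as on the text generation model: counts tokens and
# billable characters without computing any embeddings.
token_info = embedding_model.count_tokens(["Hello, world!"])
print(token_info.total_tokens, token_info.total_billable_characters)
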
@@ -1252,6 +1321,7 @@ class _PreviewTextEmbeddingModel(TextEmbeddingModel, _ModelWithBatchPredict):
 @dataclasses.dataclass
 class TextEmbeddingStatistics:
     """Text embedding statistics."""
+
     __module__ = "vertexai.language_models"
 
     token_count: int
@@ -1261,6 +1331,7 @@ class TextEmbeddingStatistics:
 @dataclasses.dataclass
 class TextEmbedding:
     """Text embedding vector and statistics."""
+
     __module__ = "vertexai.language_models"
 
     values: List[float]
@@ -1271,6 +1342,7 @@ class TextEmbedding:
 @dataclasses.dataclass
 class InputOutputTextPair:
     """InputOutputTextPair represents a pair of input and output texts."""
+
     __module__ = "vertexai.language_models"
 
     input_text: str
@@ -1285,6 +1357,7 @@ class ChatMessage:
         content: Content of the message.
         author: Author of the message.
     """
+
     __module__ = "vertexai.language_models"
 
     content: str
@@ -1362,6 +1435,7 @@ class ChatModel(_ChatModelBase, _TunableChatModelMixin):
 
         chat.send_message("Do you know any cool events this weekend?")
     """
+
     __module__ = "vertexai.language_models"
 
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
@@ -1388,6 +1462,7 @@ class CodeChatModel(_ChatModelBase):
 
         code_chat.send_message("Please help write a function to calculate the min of two numbers")
     """
+
     __module__ = "vertexai.language_models"
 
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml"
@@ -1739,7 +1814,9 @@ def send_message_streaming(
 
         full_response_text = ""
 
-        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+        for (
+            prediction_dict
+        ) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._model._endpoint_name,
             instance=prediction_request.instance,
@@ -1770,6 +1847,7 @@ class ChatSession(_ChatSessionBase):
 
     Within a chat session, the model keeps context and remembers the previous conversation.
     """
+
     __module__ = "vertexai.language_models"
 
     def __init__(
@@ -1802,6 +1880,7 @@ class CodeChatSession(_ChatSessionBase):
 
     Within a code chat session, the model keeps context and remembers the previous conversation.
     """
+
     __module__ = "vertexai.language_models"
 
     def __init__(
@@ -1924,6 +2003,7 @@ class CodeGenerationModel(_LanguageModel):
             prefix="def reverse_string(s):",
         ))
     """
+
     __module__ = "vertexai.language_models"
 
     _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
@@ -2074,7 +2154,9 @@ def predict_streaming(
         )
 
         prediction_service_client = self._endpoint._prediction_client
-        for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
+        for (
+            prediction_dict
+        ) in _streaming_prediction.predict_stream_of_dicts_from_single_dict(
             prediction_service_client=prediction_service_client,
             endpoint_name=self._endpoint_name,
             instance=prediction_request.instance,