@@ -222,7 +222,9 @@ def tune_model(
         if eval_spec.evaluation_data:
             if isinstance(eval_spec.evaluation_data, str):
                 if eval_spec.evaluation_data.startswith("gs://"):
-                    tuning_parameters["evaluation_data_uri"] = eval_spec.evaluation_data
+                    tuning_parameters[
+                        "evaluation_data_uri"
+                    ] = eval_spec.evaluation_data
                 else:
                     raise ValueError("evaluation_data should be a GCS URI")
             else:
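
The guard above only accepts a Cloud Storage URI for the evaluation dataset. A minimal sketch of how this path is reached, assuming the preview `TuningEvaluationSpec` helper and `tune_model`'s `tuning_evaluation_spec` parameter (bucket paths are hypothetical)::

    from vertexai.preview.language_models import (
        TextGenerationModel,
        TuningEvaluationSpec,
    )

    # Assumes vertexai.init(project=..., location=...) was called earlier.
    model = TextGenerationModel.from_pretrained("text-bison@001")
    # Anything that is not a "gs://" URI here raises the ValueError above.
    eval_spec = TuningEvaluationSpec(evaluation_data="gs://my-bucket/eval.jsonl")
    model.tune_model(
        training_data="gs://my-bucket/train.jsonl",
        tuning_evaluation_spec=eval_spec,
    )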
@@ -627,7 +629,7 @@ def count_tokens(
     ) -> CountTokensResponse:
         """Counts the tokens and billable characters for a given prompt.

-        Note: this does not make a request to the model, it only counts the tokens
+        Note: this does not make a prediction request to the model; it only counts the tokens
         in the request.

         Args:
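
A hedged sketch of the documented behavior, assuming the preview `TextGenerationModel.count_tokens` accepts a list of prompts (check the current signature before relying on it)::

    from vertexai.preview.language_models import TextGenerationModel

    model = TextGenerationModel.from_pretrained("text-bison@001")
    # No prediction request is made; only token and character counts come back.
    response = model.count_tokens(["How many tokens is this sentence?"])
    print(response.total_tokens, response.total_billable_characters)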
@@ -802,7 +804,9 @@ def predict(
             parameters=prediction_request.parameters,
         )

-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )

     async def predict_async(
         self,
@@ -844,7 +848,9 @@ async def predict_async(
             parameters=prediction_request.parameters,
         )

-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )

     def predict_streaming(
         self,
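
Both the sync and async paths wrap the same multi-candidate parser. A hedged usage sketch, assuming `predict` accepts a `candidate_count` parameter and returns a response with a `candidates` list::

    from vertexai.language_models import TextGenerationModel

    model = TextGenerationModel.from_pretrained("text-bison@001")
    response = model.predict("Tell me a joke.", candidate_count=2)
    # Each candidate is an independent completion of the same prompt.
    for candidate in response.candidates:
        print(candidate.text)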
@@ -1587,6 +1593,47 @@ class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):

     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style.
+            examples: A list of structured messages that show the model how to respond to the conversation.
+                Provided as a list of `InputOutputTextPair` objects.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `_PreviewChatSession` object.
+        """
+        return _PreviewChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+

 class CodeChatModel(_ChatModelBase):
     """CodeChatModel represents a model that is capable of completing code.
@@ -1646,6 +1693,47 @@ class _PreviewCodeChatModel(CodeChatModel, _TunableChatModelMixin):

     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE

+    def start_chat(
+        self,
+        *,
+        context: Optional[str] = None,
+        examples: Optional[List[InputOutputTextPair]] = None,
+        max_output_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        message_history: Optional[List[ChatMessage]] = None,
+        stop_sequences: Optional[List[str]] = None,
+    ) -> "_PreviewCodeChatSession":
+        """Starts a chat session with the model.
+
+        Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or cannot use, topics to focus on or avoid, or the response format or style.
+            examples: A list of structured messages that show the model how to respond to the conversation.
+                Provided as a list of `InputOutputTextPair` objects.
+            max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
+            temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
+            top_k: The number of highest-probability vocabulary tokens to keep for top-k filtering. Range: [1, 40]. Default: 40.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
+            message_history: A list of previously sent and received messages.
+            stop_sequences: Customized stop sequences to stop the decoding process.
+
+        Returns:
+            A `_PreviewCodeChatSession` object.
+        """
+        return _PreviewCodeChatSession(
+            model=self,
+            context=context,
+            examples=examples,
+            max_output_tokens=max_output_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            message_history=message_history,
+            stop_sequences=stop_sequences,
+        )
+

 class _ChatSessionBase:
     """_ChatSessionBase is a base class for all chat sessions."""
@@ -2071,6 +2159,67 @@ async def send_message_streaming_async(
         )


+class _ChatSessionBaseWithCountTokensMixin(_ChatSessionBase):
+    """A mixin class that adds count_tokens to a chat session."""
+
+    def count_tokens(
+        self,
+        message: str,
+    ) -> CountTokensResponse:
+        """Counts the tokens and billable characters for the provided chat message and any message history,
+        context, or examples set on the chat session.
+
+        If you've called `send_message()` in the current chat session before calling `count_tokens()`, the
+        response will include the total tokens and characters for the previously sent message and the one in the
+        `count_tokens()` request. To count the tokens for a single message, call `count_tokens()` right after
+        calling `start_chat()` and before calling `send_message()`.
+
+        Note: this does not make a prediction request to the model; it only counts the tokens
+        in the request.
+
+        Examples::
+
+            model = ChatModel.from_pretrained("chat-bison@001")
+            chat_session = model.start_chat()
+            count_tokens_response = chat_session.count_tokens("How's it going?")
+
+            count_tokens_response.total_tokens
+            count_tokens_response.total_billable_characters
+
+        Args:
+            message (str):
+                Required. A chat message to count tokens for. For example: "How's it going?"
+        Returns:
+            A `CountTokensResponse` object that contains the number of tokens
+            in the text and the number of billable characters.
+        """
+
+        count_tokens_request = self._prepare_request(message=message)
+
+        count_tokens_response = self._model._endpoint._prediction_client.select_version(
+            "v1beta1"
+        ).count_tokens(
+            endpoint=self._model._endpoint_name,
+            instances=[count_tokens_request.instance],
+        )
+
+        return CountTokensResponse(
+            total_tokens=count_tokens_response.total_tokens,
+            total_billable_characters=count_tokens_response.total_billable_characters,
+            _count_tokens_response=count_tokens_response,
+        )
+
+
+class _PreviewChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
+class _PreviewCodeChatSession(_ChatSessionBaseWithCountTokensMixin):
+
+    __module__ = "vertexai.preview.language_models"
+
+
 class ChatSession(_ChatSessionBase):
     """ChatSession represents a chat session with a language model.
@@ -2361,7 +2510,9 @@ def predict(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )

     async def predict_async(
         self,
@@ -2400,7 +2551,9 @@ async def predict_async(
             instances=[prediction_request.instance],
             parameters=prediction_request.parameters,
         )
-        return _parse_text_generation_model_multi_candidate_response(prediction_response)
+        return _parse_text_generation_model_multi_candidate_response(
+            prediction_response
+        )

     def predict_streaming(
         self,