@@ -223,7 +223,7 @@ def __repr__(self):
         return self.text


-class TextGenerationModel(_LanguageModel):
+class _TextGenerationModel(_LanguageModel):
     """TextGenerationModel represents a general language model.

     Examples::
@@ -324,9 +324,6 @@ def _batch_predict(
         return results


-_TextGenerationModel = TextGenerationModel
-
-
 class _ModelWithBatchPredict(_LanguageModel):
     """Model that supports batch prediction."""

@@ -432,15 +429,19 @@ def batch_predict(
         )


+class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
+    pass
+
+
 class _PreviewTextGenerationModel(
-    TextGenerationModel, _TunableModelMixin, _PreviewModelWithBatchPredict
+    _TextGenerationModel, _TunableModelMixin, _PreviewModelWithBatchPredict
 ):
     """Preview text generation model."""

     _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


-class _ChatModel(TextGenerationModel):
+class _ChatModel(_TextGenerationModel):
     """ChatModel represents a language model that is capable of chat.

     Examples::
@@ -457,10 +458,10 @@ class _ChatModel(TextGenerationModel):

     def start_chat(
         self,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
+        top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
+        top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
     ) -> "_ChatSession":
         """Starts a chat session with the model.

@@ -491,10 +492,10 @@ class _ChatSession:
     def __init__(
         self,
         model: _ChatModel,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
+        top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
+        top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
     ):
         self._model = model
         self._history = []
@@ -635,10 +636,10 @@ def start_chat(
         *,
         context: Optional[str] = None,
         examples: Optional[List[InputOutputTextPair]] = None,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
+        top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
+        top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
         message_history: Optional[List[ChatMessage]] = None,
     ) -> "ChatSession":
         """Starts a chat session with the model.
@@ -754,10 +755,10 @@ def __init__(
         model: _ChatModelBase,
         context: Optional[str] = None,
         examples: Optional[List[InputOutputTextPair]] = None,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
+        top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
+        top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
         is_code_chat_session: bool = False,
         message_history: Optional[List[ChatMessage]] = None,
     ):
@@ -885,10 +886,10 @@ def __init__(
         model: ChatModel,
         context: Optional[str] = None,
         examples: Optional[List[InputOutputTextPair]] = None,
-        max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
-        temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
-        top_k: int = TextGenerationModel._DEFAULT_TOP_K,
-        top_p: float = TextGenerationModel._DEFAULT_TOP_P,
+        max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
+        temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
+        top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
+        top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
         message_history: Optional[List[ChatMessage]] = None,
     ):
         super().__init__(
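
What the diff converges on, in outline: the old public class is renamed to the private implementation base _TextGenerationModel, the module-level alias _TextGenerationModel = TextGenerationModel is deleted, and the public name is reintroduced as a composition of the private base and the batch-predict mixin; _PreviewTextGenerationModel likewise inherits from the private base, so preview-only mixins never leak into the GA class. Below is a minimal runnable sketch of that layout; the attribute values and method bodies are placeholders for illustration, not the SDK's real implementation:

    class _LanguageModel:
        """Shared base for all language models (stub)."""

    class _TextGenerationModel(_LanguageModel):
        """Private implementation class; owns the shared defaults."""

        _DEFAULT_MAX_OUTPUT_TOKENS = 128  # illustrative value
        _DEFAULT_TEMPERATURE = 0.0        # illustrative value
        _DEFAULT_TOP_K = 40               # illustrative value
        _DEFAULT_TOP_P = 0.95             # illustrative value

    class _ModelWithBatchPredict(_LanguageModel):
        """Mixin that contributes batch prediction (stub)."""

        def batch_predict(self, prompts):
            return list(prompts)  # placeholder, not the real batch RPC

    class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
        pass

    # The public class picks up both capabilities through the MRO:
    print([c.__name__ for c in TextGenerationModel.__mro__])
    # ['TextGenerationModel', '_TextGenerationModel',
    #  '_ModelWithBatchPredict', '_LanguageModel', 'object']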
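Every start_chat/__init__ signature in the diff flips its defaults from TextGenerationModel._DEFAULT_* to _TextGenerationModel._DEFAULT_*. The attributes live on the private base either way (the public class merely inherits them), and Python evaluates default expressions once, when the def executes, so each signature binds the value at import time. A small sketch of that binding behavior, using an assumed illustrative value:

    class _TextGenerationModel:
        _DEFAULT_TEMPERATURE = 0.0  # assumed value for illustration

    def start_chat(temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE):
        return temperature

    _TextGenerationModel._DEFAULT_TEMPERATURE = 1.0  # a later change...
    print(start_chat())  # ...still prints 0.0: the default was bound at def time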