Skip to content

Commit 701c3a2

Browse files
Ark-kun authored and copybara-github committed
feat: LLM - Released the BatchPrediction to GA for TextGenerationModel
PiperOrigin-RevId: 548634713
1 parent eabe720 commit 701c3a2

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

tests/unit/aiplatform/test_language_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ def test_batch_prediction(self):
13021302
_TEXT_BISON_PUBLISHER_MODEL_DICT
13031303
),
13041304
):
1305-
model = preview_language_models.TextGenerationModel.from_pretrained(
1305+
model = language_models.TextGenerationModel.from_pretrained(
13061306
"text-bison@001"
13071307
)
13081308

vertexai/language_models/_language_models.py

+27-26
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __repr__(self):
223223
return self.text
224224

225225

226-
class TextGenerationModel(_LanguageModel):
226+
class _TextGenerationModel(_LanguageModel):
227227
"""TextGenerationModel represents a general language model.
228228
229229
Examples::
@@ -324,9 +324,6 @@ def _batch_predict(
324324
return results
325325

326326

327-
_TextGenerationModel = TextGenerationModel
328-
329-
330327
class _ModelWithBatchPredict(_LanguageModel):
331328
"""Model that supports batch prediction."""
332329

@@ -432,15 +429,19 @@ def batch_predict(
432429
)
433430

434431

432+
class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):
433+
pass
434+
435+
435436
class _PreviewTextGenerationModel(
436-
TextGenerationModel, _TunableModelMixin, _PreviewModelWithBatchPredict
437+
_TextGenerationModel, _TunableModelMixin, _PreviewModelWithBatchPredict
437438
):
438439
"""Preview text generation model."""
439440

440441
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
441442

442443

443-
class _ChatModel(TextGenerationModel):
444+
class _ChatModel(_TextGenerationModel):
444445
"""ChatModel represents a language model that is capable of chat.
445446
446447
Examples::
@@ -457,10 +458,10 @@ class _ChatModel(TextGenerationModel):
457458

458459
def start_chat(
459460
self,
460-
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
461-
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
462-
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
463-
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
461+
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
462+
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
463+
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
464+
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
464465
) -> "_ChatSession":
465466
"""Starts a chat session with the model.
466467
@@ -491,10 +492,10 @@ class _ChatSession:
491492
def __init__(
492493
self,
493494
model: _ChatModel,
494-
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
495-
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
496-
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
497-
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
495+
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
496+
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
497+
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
498+
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
498499
):
499500
self._model = model
500501
self._history = []
@@ -635,10 +636,10 @@ def start_chat(
635636
*,
636637
context: Optional[str] = None,
637638
examples: Optional[List[InputOutputTextPair]] = None,
638-
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
639-
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
640-
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
641-
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
639+
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
640+
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
641+
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
642+
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
642643
message_history: Optional[List[ChatMessage]] = None,
643644
) -> "ChatSession":
644645
"""Starts a chat session with the model.
@@ -754,10 +755,10 @@ def __init__(
754755
model: _ChatModelBase,
755756
context: Optional[str] = None,
756757
examples: Optional[List[InputOutputTextPair]] = None,
757-
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
758-
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
759-
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
760-
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
758+
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
759+
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
760+
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
761+
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
761762
is_code_chat_session: bool = False,
762763
message_history: Optional[List[ChatMessage]] = None,
763764
):
@@ -885,10 +886,10 @@ def __init__(
885886
model: ChatModel,
886887
context: Optional[str] = None,
887888
examples: Optional[List[InputOutputTextPair]] = None,
888-
max_output_tokens: int = TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
889-
temperature: float = TextGenerationModel._DEFAULT_TEMPERATURE,
890-
top_k: int = TextGenerationModel._DEFAULT_TOP_K,
891-
top_p: float = TextGenerationModel._DEFAULT_TOP_P,
889+
max_output_tokens: int = _TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS,
890+
temperature: float = _TextGenerationModel._DEFAULT_TEMPERATURE,
891+
top_k: int = _TextGenerationModel._DEFAULT_TOP_K,
892+
top_p: float = _TextGenerationModel._DEFAULT_TOP_P,
892893
message_history: Optional[List[ChatMessage]] = None,
893894
):
894895
super().__init__(

0 commit comments

Comments (0)