Skip to content

Commit f3b25ab

Browse files
Ark-kun authored and copybara-github committed
fix: LLM - Fixed the TextGenerationModel.predict parameters
PiperOrigin-RevId: 555940714
1 parent af6e455 commit f3b25ab

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

tests/unit/aiplatform/test_language_models.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def test_text_generation_ga(self):
615615
target=prediction_service_client.PredictionServiceClient,
616616
attribute="predict",
617617
return_value=gca_predict_response,
618-
):
618+
) as mock_predict:
619619
response = model.predict(
620620
"What is the best recipe for banana bread? Recipe:",
621621
max_output_tokens=128,
@@ -624,8 +624,33 @@ def test_text_generation_ga(self):
624624
top_k=5,
625625
)
626626

627+
prediction_parameters = mock_predict.call_args[1]["parameters"]
628+
assert prediction_parameters["maxDecodeSteps"] == 128
629+
assert prediction_parameters["temperature"] == 0
630+
assert prediction_parameters["topP"] == 1
631+
assert prediction_parameters["topK"] == 5
627632
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
628633

634+
# Validating that unspecified parameters are not passed to the model
635+
# (except `max_output_tokens`).
636+
with mock.patch.object(
637+
target=prediction_service_client.PredictionServiceClient,
638+
attribute="predict",
639+
return_value=gca_predict_response,
640+
) as mock_predict:
641+
model.predict(
642+
"What is the best recipe for banana bread? Recipe:",
643+
)
644+
645+
prediction_parameters = mock_predict.call_args[1]["parameters"]
646+
assert (
647+
prediction_parameters["maxDecodeSteps"]
648+
== language_models.TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
649+
)
650+
assert "temperature" not in prediction_parameters
651+
assert "topP" not in prediction_parameters
652+
assert "topK" not in prediction_parameters
653+
629654
@pytest.mark.parametrize(
630655
"job_spec",
631656
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],

vertexai/language_models/_language_models.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,19 @@ def _batch_predict(
290290
A list of `TextGenerationResponse` objects that contain the texts produced by the model.
291291
"""
292292
instances = [{"content": str(prompt)} for prompt in prompts]
293-
prediction_parameters = {
294-
"temperature": temperature,
295-
"maxDecodeSteps": max_output_tokens,
296-
"topP": top_p,
297-
"topK": top_k,
298-
}
293+
prediction_parameters = {}
294+
295+
if max_output_tokens:
296+
prediction_parameters["maxDecodeSteps"] = max_output_tokens
297+
298+
if temperature is not None:
299+
prediction_parameters["temperature"] = temperature
300+
301+
if top_p:
302+
prediction_parameters["topP"] = top_p
303+
304+
if top_k:
305+
prediction_parameters["topK"] = top_k
299306

300307
prediction_response = self._endpoint.predict(
301308
instances=instances,

0 commit comments

Comments (0)