@@ -615,7 +615,7 @@ def test_text_generation_ga(self):
             target=prediction_service_client.PredictionServiceClient,
             attribute="predict",
             return_value=gca_predict_response,
-        ):
+        ) as mock_predict:
             response = model.predict(
                 "What is the best recipe for banana bread? Recipe:",
                 max_output_tokens=128,
@@ -624,8 +624,33 @@ def test_text_generation_ga(self):
                 top_k=5,
             )
 
+        prediction_parameters = mock_predict.call_args[1]["parameters"]
+        assert prediction_parameters["maxDecodeSteps"] == 128
+        assert prediction_parameters["temperature"] == 0
+        assert prediction_parameters["topP"] == 1
+        assert prediction_parameters["topK"] == 5
         assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
 
+        # Validating that unspecified parameters are not passed to the model
+        # (except `max_output_tokens`).
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ) as mock_predict:
+            model.predict(
+                "What is the best recipe for banana bread? Recipe:",
+            )
+
+        prediction_parameters = mock_predict.call_args[1]["parameters"]
+        assert (
+            prediction_parameters["maxDecodeSteps"]
+            == language_models.TextGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS
+        )
+        assert "temperature" not in prediction_parameters
+        assert "topP" not in prediction_parameters
+        assert "topK" not in prediction_parameters
+
     @pytest.mark.parametrize(
         "job_spec",
         [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
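
The pattern this diff exercises — patching `PredictionServiceClient.predict` as a context manager and then inspecting `mock_predict.call_args` — is standard `unittest.mock` usage: `call_args` is an `(args, kwargs)` pair, so `call_args[1]["parameters"]` reads the `parameters` keyword argument of the most recent call. Below is a minimal self-contained sketch of the same technique; the `Client` class and `call_model` function are hypothetical stand-ins, not the real Vertex AI prediction client or SDK code.

from unittest import mock


class Client:
    # Hypothetical stand-in for prediction_service_client.PredictionServiceClient.
    def predict(self, instances=None, parameters=None):
        raise NotImplementedError("real RPC; replaced by the mock in tests")


def call_model():
    # Stand-in for model.predict(): builds the request the way an SDK would.
    Client().predict(instances=[{"prompt": "hi"}], parameters={"topK": 5})


with mock.patch.object(
    target=Client,
    attribute="predict",
    return_value={"content": "mocked"},
) as mock_predict:
    call_model()

# call_args is an (args, kwargs) pair; index [1] is the kwargs dict, which is
# the same `call_args[1]["parameters"]` lookup the test above relies on.
assert mock_predict.call_args[1]["parameters"] == {"topK": 5}
assert "temperature" not in mock_predict.call_args[1]["parameters"]

Asserting on the captured request, rather than only on the mocked response, is what lets the new test verify that unspecified sampling parameters are genuinely omitted from the payload instead of being sent with default values.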