Skip to content

Commit cb2f4aa

Browse files
meken authored and copybara-github committed
feat: LLM - Added the seed parameter to the TextGenerationModel's predict methods
Copybara import of the project: -- 6e23d68 by Murat Eken <[email protected]>: fix: missing request parameters -- cf2bff5 by Murat Eken <[email protected]>: removing the echo parameter, this fix only includes seed COPYBARA_INTEGRATE_REVIEW=#3186 from meken:main 32877b4 PiperOrigin-RevId: 638066958
1 parent 3e4fc18 commit cb2f4aa

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

tests/unit/aiplatform/test_language_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,7 @@ def test_text_generation_ga(self):
19201920
presence_penalty=1.0,
19211921
frequency_penalty=1.0,
19221922
logit_bias={1: 100.0, 2: -100.0},
1923+
seed=42,
19231924
)
19241925

19251926
expected_errors = (100,)
@@ -1933,6 +1934,7 @@ def test_text_generation_ga(self):
19331934
assert prediction_parameters["presencePenalty"] == 1.0
19341935
assert prediction_parameters["frequencyPenalty"] == 1.0
19351936
assert prediction_parameters["logitBias"] == {1: 100.0, 2: -100.0}
1937+
assert prediction_parameters["seed"] == 42
19361938
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
19371939
assert response.errors == expected_errors
19381940

vertexai/language_models/_language_models.py

+42
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,7 @@ def predict(
13551355
presence_penalty: Optional[float] = None,
13561356
frequency_penalty: Optional[float] = None,
13571357
logit_bias: Optional[Dict[int, float]] = None,
1358+
seed: Optional[int] = None,
13581359
) -> "MultiCandidateTextGenerationResponse":
13591360
"""Gets model response for a single prompt.
13601361
@@ -1387,6 +1388,12 @@ def predict(
13871388
Larger positive bias increases the probability of choosing the token.
13881389
Smaller negative bias decreases the probability of choosing the token.
13891390
Range: [-100.0, 100.0]
1391+
seed:
1392+
Decoder generates random noise with a pseudo random number generator, temperature * noise is added to
1393+
logits before sampling. The pseudo random number generator (prng) takes a seed as input, it generates
1394+
the same output with the same seed. If seed is not set, the seed used in decoder will not be
1395+
deterministic, thus the generated random noise will not be deterministic. If seed is set, the
1396+
generated random noise will be deterministic.
13901397
13911398
Returns:
13921399
A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
@@ -1404,6 +1411,7 @@ def predict(
14041411
presence_penalty=presence_penalty,
14051412
frequency_penalty=frequency_penalty,
14061413
logit_bias=logit_bias,
1414+
seed=seed,
14071415
)
14081416

14091417
prediction_response = self._endpoint.predict(
@@ -1436,6 +1444,7 @@ async def predict_async(
14361444
presence_penalty: Optional[float] = None,
14371445
frequency_penalty: Optional[float] = None,
14381446
logit_bias: Optional[Dict[int, float]] = None,
1447+
seed: Optional[int] = None,
14391448
) -> "MultiCandidateTextGenerationResponse":
14401449
"""Asynchronously gets model response for a single prompt.
14411450
@@ -1468,6 +1477,12 @@ async def predict_async(
14681477
Larger positive bias increases the probability of choosing the token.
14691478
Smaller negative bias decreases the probability of choosing the token.
14701479
Range: [-100.0, 100.0]
1480+
seed:
1481+
Decoder generates random noise with a pseudo random number generator, temperature * noise is added to
1482+
logits before sampling. The pseudo random number generator (prng) takes a seed as input, it generates
1483+
the same output with the same seed. If seed is not set, the seed used in decoder will not be
1484+
deterministic, thus the generated random noise will not be deterministic. If seed is set, the
1485+
generated random noise will be deterministic.
14711486
14721487
Returns:
14731488
A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
@@ -1485,6 +1500,7 @@ async def predict_async(
14851500
presence_penalty=presence_penalty,
14861501
frequency_penalty=frequency_penalty,
14871502
logit_bias=logit_bias,
1503+
seed=seed,
14881504
)
14891505

14901506
prediction_response = await self._endpoint.predict_async(
@@ -1509,6 +1525,7 @@ def predict_streaming(
15091525
presence_penalty: Optional[float] = None,
15101526
frequency_penalty: Optional[float] = None,
15111527
logit_bias: Optional[Dict[int, float]] = None,
1528+
seed: Optional[int] = None,
15121529
) -> Iterator[TextGenerationResponse]:
15131530
"""Gets a streaming model response for a single prompt.
15141531
@@ -1541,6 +1558,12 @@ def predict_streaming(
15411558
Larger positive bias increases the probability of choosing the token.
15421559
Smaller negative bias decreases the probability of choosing the token.
15431560
Range: [-100.0, 100.0]
1561+
seed:
1562+
Decoder generates random noise with a pseudo random number generator, temperature * noise is added to
1563+
logits before sampling. The pseudo random number generator (prng) takes a seed as input, it generates
1564+
the same output with the same seed. If seed is not set, the seed used in decoder will not be
1565+
deterministic, thus the generated random noise will not be deterministic. If seed is set, the
1566+
generated random noise will be deterministic.
15441567
15451568
Yields:
15461569
A stream of `TextGenerationResponse` objects that contain partial
@@ -1557,6 +1580,7 @@ def predict_streaming(
15571580
presence_penalty=presence_penalty,
15581581
frequency_penalty=frequency_penalty,
15591582
logit_bias=logit_bias,
1583+
seed=seed,
15601584
)
15611585

15621586
prediction_service_client = self._endpoint._prediction_client
@@ -1587,6 +1611,7 @@ async def predict_streaming_async(
15871611
presence_penalty: Optional[float] = None,
15881612
frequency_penalty: Optional[float] = None,
15891613
logit_bias: Optional[Dict[int, float]] = None,
1614+
seed: Optional[int] = None,
15901615
) -> AsyncIterator[TextGenerationResponse]:
15911616
"""Asynchronously gets a streaming model response for a single prompt.
15921617
@@ -1619,6 +1644,12 @@ async def predict_streaming_async(
16191644
Larger positive bias increases the probability of choosing the token.
16201645
Smaller negative bias decreases the probability of choosing the token.
16211646
Range: [-100.0, 100.0]
1647+
seed:
1648+
Decoder generates random noise with a pseudo random number generator, temperature * noise is added to
1649+
logits before sampling. The pseudo random number generator (prng) takes a seed as input, it generates
1650+
the same output with the same seed. If seed is not set, the seed used in decoder will not be
1651+
deterministic, thus the generated random noise will not be deterministic. If seed is set, the
1652+
generated random noise will be deterministic.
16221653
16231654
Yields:
16241655
A stream of `TextGenerationResponse` objects that contain partial
@@ -1635,6 +1666,7 @@ async def predict_streaming_async(
16351666
presence_penalty=presence_penalty,
16361667
frequency_penalty=frequency_penalty,
16371668
logit_bias=logit_bias,
1669+
seed=seed,
16381670
)
16391671

16401672
prediction_service_async_client = self._endpoint._prediction_async_client
@@ -1671,6 +1703,7 @@ def _create_text_generation_prediction_request(
16711703
presence_penalty: Optional[float] = None,
16721704
frequency_penalty: Optional[float] = None,
16731705
logit_bias: Optional[Dict[int, int]] = None,
1706+
seed: Optional[int] = None,
16741707
) -> "_PredictionRequest":
16751708
"""Prepares the text generation request for a single prompt.
16761709
@@ -1703,6 +1736,12 @@ def _create_text_generation_prediction_request(
17031736
Larger positive bias increases the probability of choosing the token.
17041737
Smaller negative bias decreases the probability of choosing the token.
17051738
Range: [-100.0, 100.0]
1739+
seed:
1740+
Decoder generates random noise with a pseudo random number generator, temperature * noise is added to
1741+
logits before sampling. The pseudo random number generator (prng) takes a seed as input, it generates
1742+
the same output with the same seed. If seed is not set, the seed used in decoder will not be
1743+
deterministic, thus the generated random noise will not be deterministic. If seed is set, the
1744+
generated random noise will be deterministic.
17061745
17071746
Returns:
17081747
A `_PredictionRequest` object that contains prediction instance and parameters.
@@ -1749,6 +1788,9 @@ def _create_text_generation_prediction_request(
17491788
if logit_bias is not None:
17501789
prediction_parameters["logitBias"] = logit_bias
17511790

1791+
if seed is not None:
1792+
prediction_parameters["seed"] = seed
1793+
17521794
return _PredictionRequest(
17531795
instance=instance,
17541796
parameters=prediction_parameters,

0 commit comments

Comments (0)