Skip to content

Commit 7acf0f7

Browse files
sasha-gitgcopybara-github
authored andcommitted
feat: GenAI - Add support for logprobs and response_logprobs.
PiperOrigin-RevId: 677832804
1 parent 86fc215 commit 7acf0f7

File tree

3 files changed

+14
-0
lines changed

3 files changed

+14
-0
lines changed

tests/system/vertexai/test_generative_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ def test_generate_content_with_parameters(self, api_endpoint_env_name):
257257
candidate_count=1,
258258
max_output_tokens=100,
259259
stop_sequences=["STOP!"],
260+
response_logprobs=True,
261+
logprobs=3,
260262
),
261263
safety_settings={
262264
generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,

tests/unit/vertexai/test_generative_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,8 @@ def test_generate_content(self, generative_models: generative_models):
585585
stop_sequences=["\n\n\n"],
586586
presence_penalty=0.0,
587587
frequency_penalty=0.0,
588+
logprobs=5,
589+
response_logprobs=True,
588590
),
589591
safety_settings=[
590592
generative_models.SafetySetting(

vertexai/generative_models/_generative_models.py

+10
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,8 @@ def __init__(
15771577
response_schema: Optional[Dict[str, Any]] = None,
15781578
seed: Optional[int] = None,
15791579
routing_config: Optional["RoutingConfig"] = None,
1580+
logprobs: Optional[int] = None,
1581+
response_logprobs: Optional[bool] = None,
15801582
):
15811583
r"""Constructs a GenerationConfig object.
15821584
@@ -1603,6 +1605,8 @@ def __init__(
16031605
response_schema: Output response schema of the genreated candidate text. Only valid when
16041606
response_mime_type is application/json.
16051607
routing_config: Model routing preference set in the request.
1608+
logprobs: Logit probabilities.
1609+
reponse_logprobs: If true, export the logprobs results in response.
16061610
16071611
Usage:
16081612
```
@@ -1637,6 +1641,8 @@ def __init__(
16371641
response_mime_type=response_mime_type,
16381642
response_schema=raw_schema,
16391643
seed=seed,
1644+
logprobs=logprobs,
1645+
response_logprobs=response_logprobs,
16401646
)
16411647
if routing_config is not None:
16421648
self._raw_generation_config.routing_config = (
@@ -2223,6 +2229,10 @@ def content(self) -> "Content":
22232229
def avg_logprobs(self) -> float:
22242230
return self._raw_candidate.avg_logprobs
22252231

2232+
@property
2233+
def logprobs_result(self) -> gapic_content_types.LogprobsResult:
2234+
return self._raw_candidate.logprobs_result
2235+
22262236
@property
22272237
def finish_reason(self) -> gapic_content_types.Candidate.FinishReason:
22282238
return self._raw_candidate.finish_reason

0 commit comments

Comments
 (0)