Skip to content

Commit b8b589c

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Make get_embeddings work both for foundational & tuned models.
PiperOrigin-RevId: 629254179
1 parent 3ce0126 commit b8b589c

File tree

1 file changed

+44
-37
lines changed

1 file changed

+44
-37
lines changed

vertexai/language_models/_language_models.py

+44-37
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import abc
1818
import dataclasses
19+
import collections.abc
1920
from typing import (
2021
Any,
2122
AsyncIterator,
@@ -975,6 +976,7 @@ class TuningEvaluationSpec:
975976
enable_checkpoint_selection: Optional[bool] = None
976977
tensorboard: Optional[Union[aiplatform.Tensorboard, str]] = None
977978

979+
978980
# Evaluation spec fields that are not supported by RLHF tuning
979981
_UNUSED_RLHF_EVAL_SPECS = (
980982
"evaluation_interval",
@@ -2053,30 +2055,12 @@ def _prepare_text_embedding_request(
20532055
parameters=parameters,
20542056
)
20552057

2056-
def _parse_text_embedding_response(
2057-
self,
2058-
prediction_response: aiplatform.models.Prediction,
2059-
prediction_idx: int = 0,
2060-
) -> "TextEmbedding":
2061-
"""Parses the text embedding model response."""
2062-
prediction = prediction_response.predictions[prediction_idx]
2063-
embeddings = prediction["embeddings"]
2064-
statistics = embeddings["statistics"]
2065-
return TextEmbedding(
2066-
values=embeddings["values"],
2067-
statistics=TextEmbeddingStatistics(
2068-
token_count=statistics["token_count"],
2069-
truncated=statistics["truncated"],
2070-
),
2071-
_prediction_response=prediction_response,
2072-
)
2073-
20742058
def get_embeddings(
20752059
self,
20762060
texts: List[Union[str, TextEmbeddingInput]],
20772061
*,
20782062
auto_truncate: bool = True,
2079-
output_dimensionality: Optional[int] = None
2063+
output_dimensionality: Optional[int] = None,
20802064
) -> List["TextEmbedding"]:
20812065
"""Calculates embeddings for the given texts.
20822066
@@ -2099,15 +2083,12 @@ def get_embeddings(
20992083
parameters=prediction_request.parameters,
21002084
)
21012085

2102-
results = []
2103-
for prediction_idx in range(len(prediction_response.predictions)):
2104-
result = self._parse_text_embedding_response(
2105-
prediction_response=prediction_response,
2106-
prediction_idx=prediction_idx,
2086+
return [
2087+
TextEmbedding._parse_text_embedding_response(
2088+
prediction_response, i_prediction
21072089
)
2108-
results.append(result)
2109-
2110-
return results
2090+
for i_prediction, _ in enumerate(prediction_response.predictions)
2091+
]
21112092

21122093
async def get_embeddings_async(
21132094
self,
@@ -2129,23 +2110,20 @@ async def get_embeddings_async(
21292110
prediction_request = self._prepare_text_embedding_request(
21302111
texts=texts,
21312112
auto_truncate=auto_truncate,
2132-
output_dimensionality=output_dimensionality
2113+
output_dimensionality=output_dimensionality,
21332114
)
21342115

21352116
prediction_response = await self._endpoint.predict_async(
21362117
instances=prediction_request.instances,
21372118
parameters=prediction_request.parameters,
21382119
)
21392120

2140-
results = []
2141-
for prediction_idx in range(len(prediction_response.predictions)):
2142-
result = self._parse_text_embedding_response(
2143-
prediction_response=prediction_response,
2144-
prediction_idx=prediction_idx,
2121+
return [
2122+
TextEmbedding._parse_text_embedding_response(
2123+
prediction_response, i_prediction
21452124
)
2146-
results.append(result)
2147-
2148-
return results
2125+
for i_prediction, _ in enumerate(prediction_response.predictions)
2126+
]
21492127

21502128

21512129
class _PreviewTextEmbeddingModel(
@@ -2175,6 +2153,36 @@ class TextEmbedding:
21752153
statistics: Optional[TextEmbeddingStatistics] = None
21762154
_prediction_response: Optional[aiplatform.models.Prediction] = None
21772155

2156+
@classmethod
def _parse_text_embedding_response(
    cls, prediction_response: aiplatform.models.Prediction, prediction_index: int
) -> "TextEmbedding":
    """Creates a `TextEmbedding` object from a single prediction.

    Supports both response formats: pretrained (foundational) models return
    each prediction as a mapping with an "embeddings" key (values plus token
    statistics), while tuned-model endpoints return the embedding values as a
    bare sequence with no statistics.

    Args:
        prediction_response: `aiplatform.models.Prediction` object.
        prediction_index: Index of the prediction to parse within
            `prediction_response.predictions`.

    Returns:
        `TextEmbedding` object.
    """
    prediction = prediction_response.predictions[prediction_index]
    # Pretrained-model predictions are dict-like; anything else is treated
    # as the raw list of embedding values produced by a tuned model.
    if isinstance(prediction, collections.abc.Mapping):
        embeddings = prediction["embeddings"]
        embedding_stats = embeddings["statistics"]
        return cls(
            values=embeddings["values"],
            statistics=TextEmbeddingStatistics(
                token_count=embedding_stats["token_count"],
                truncated=embedding_stats["truncated"],
            ),
            _prediction_response=prediction_response,
        )
    return cls(values=prediction, _prediction_response=prediction_response)
2185+
21782186

21792187
@dataclasses.dataclass
21802188
class InputOutputTextPair:
@@ -3146,7 +3154,6 @@ class _CodeGenerationModel(_LanguageModel):
31463154

31473155
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
31483156

3149-
31503157
def _create_prediction_request(
31513158
self,
31523159
prefix: str,

0 commit comments

Comments
 (0)