16
16
17
17
import abc
18
18
import dataclasses
19
+ import collections .abc
19
20
from typing import (
20
21
Any ,
21
22
AsyncIterator ,
@@ -975,6 +976,7 @@ class TuningEvaluationSpec:
975
976
enable_checkpoint_selection : Optional [bool ] = None
976
977
tensorboard : Optional [Union [aiplatform .Tensorboard , str ]] = None
977
978
979
+
978
980
# Evaluation spec fields that are not supported by RLHF tuning
979
981
_UNUSED_RLHF_EVAL_SPECS = (
980
982
"evaluation_interval" ,
@@ -2053,30 +2055,12 @@ def _prepare_text_embedding_request(
2053
2055
parameters = parameters ,
2054
2056
)
2055
2057
2056
- def _parse_text_embedding_response (
2057
- self ,
2058
- prediction_response : aiplatform .models .Prediction ,
2059
- prediction_idx : int = 0 ,
2060
- ) -> "TextEmbedding" :
2061
- """Parses the text embedding model response."""
2062
- prediction = prediction_response .predictions [prediction_idx ]
2063
- embeddings = prediction ["embeddings" ]
2064
- statistics = embeddings ["statistics" ]
2065
- return TextEmbedding (
2066
- values = embeddings ["values" ],
2067
- statistics = TextEmbeddingStatistics (
2068
- token_count = statistics ["token_count" ],
2069
- truncated = statistics ["truncated" ],
2070
- ),
2071
- _prediction_response = prediction_response ,
2072
- )
2073
-
2074
2058
def get_embeddings (
2075
2059
self ,
2076
2060
texts : List [Union [str , TextEmbeddingInput ]],
2077
2061
* ,
2078
2062
auto_truncate : bool = True ,
2079
- output_dimensionality : Optional [int ] = None
2063
+ output_dimensionality : Optional [int ] = None ,
2080
2064
) -> List ["TextEmbedding" ]:
2081
2065
"""Calculates embeddings for the given texts.
2082
2066
@@ -2099,15 +2083,12 @@ def get_embeddings(
2099
2083
parameters = prediction_request .parameters ,
2100
2084
)
2101
2085
2102
- results = []
2103
- for prediction_idx in range (len (prediction_response .predictions )):
2104
- result = self ._parse_text_embedding_response (
2105
- prediction_response = prediction_response ,
2106
- prediction_idx = prediction_idx ,
2086
+ return [
2087
+ TextEmbedding ._parse_text_embedding_response (
2088
+ prediction_response , i_prediction
2107
2089
)
2108
- results .append (result )
2109
-
2110
- return results
2090
+ for i_prediction , _ in enumerate (prediction_response .predictions )
2091
+ ]
2111
2092
2112
2093
async def get_embeddings_async (
2113
2094
self ,
@@ -2129,23 +2110,20 @@ async def get_embeddings_async(
2129
2110
prediction_request = self ._prepare_text_embedding_request (
2130
2111
texts = texts ,
2131
2112
auto_truncate = auto_truncate ,
2132
- output_dimensionality = output_dimensionality
2113
+ output_dimensionality = output_dimensionality ,
2133
2114
)
2134
2115
2135
2116
prediction_response = await self ._endpoint .predict_async (
2136
2117
instances = prediction_request .instances ,
2137
2118
parameters = prediction_request .parameters ,
2138
2119
)
2139
2120
2140
- results = []
2141
- for prediction_idx in range (len (prediction_response .predictions )):
2142
- result = self ._parse_text_embedding_response (
2143
- prediction_response = prediction_response ,
2144
- prediction_idx = prediction_idx ,
2121
+ return [
2122
+ TextEmbedding ._parse_text_embedding_response (
2123
+ prediction_response , i_prediction
2145
2124
)
2146
- results .append (result )
2147
-
2148
- return results
2125
+ for i_prediction , _ in enumerate (prediction_response .predictions )
2126
+ ]
2149
2127
2150
2128
2151
2129
class _PreviewTextEmbeddingModel (
@@ -2175,6 +2153,36 @@ class TextEmbedding:
2175
2153
statistics : Optional [TextEmbeddingStatistics ] = None
2176
2154
_prediction_response : Optional [aiplatform .models .Prediction ] = None
2177
2155
2156
@classmethod
def _parse_text_embedding_response(
    cls, prediction_response: aiplatform.models.Prediction, prediction_index: int
) -> "TextEmbedding":
    """Builds a `TextEmbedding` from one entry of a prediction response.

    Args:
        prediction_response: `aiplatform.models.Prediction` object whose
            `predictions` sequence holds the entry to convert.
        prediction_index: Position of that entry within
            `prediction_response.predictions`.

    Returns:
        `TextEmbedding` object. When the entry is a mapping, the values
        and statistics are read from its "embeddings" item; otherwise the
        entry itself is used as the values and no statistics are set.
    """
    raw_prediction = prediction_response.predictions[prediction_index]

    # A non-mapping entry is passed through unchanged as the embedding values.
    if not isinstance(raw_prediction, collections.abc.Mapping):
        return cls(values=raw_prediction, _prediction_response=prediction_response)

    # A mapping entry nests both the values and their statistics
    # under the "embeddings" key.
    embedding_payload = raw_prediction["embeddings"]
    stats_payload = embedding_payload["statistics"]
    embedding_values = embedding_payload["values"]
    parsed_statistics = TextEmbeddingStatistics(
        token_count=stats_payload["token_count"],
        truncated=stats_payload["truncated"],
    )
    return cls(
        values=embedding_values,
        statistics=parsed_statistics,
        _prediction_response=prediction_response,
    )
2185
+
2178
2186
2179
2187
@dataclasses .dataclass
2180
2188
class InputOutputTextPair :
@@ -3146,7 +3154,6 @@ class _CodeGenerationModel(_LanguageModel):
3146
3154
3147
3155
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml"
3148
3156
3149
-
3150
3157
def _create_prediction_request (
3151
3158
self ,
3152
3159
prefix : str ,
0 commit comments