@@ -190,11 +190,12 @@ def retrieval_query(
190
190
else :
191
191
# If rag_retrieval_config is specified, check for missing parameters.
192
192
api_retrival_config = aiplatform_v1beta1 .RagRetrievalConfig ()
193
- api_retrival_config .top_k = (
194
- rag_retrieval_config .top_k
195
- if rag_retrieval_config .top_k
196
- else similarity_top_k
197
- )
193
+ # Set top_k to config value if specified
194
+ if rag_retrieval_config .top_k :
195
+ api_retrival_config .top_k = rag_retrieval_config .top_k
196
+ else :
197
+ api_retrival_config .top_k = similarity_top_k
198
+ # Set alpha to config value if specified
198
199
if (
199
200
rag_retrieval_config .hybrid_search
200
201
and rag_retrieval_config .hybrid_search .alpha
@@ -204,6 +205,19 @@ def retrieval_query(
204
205
)
205
206
else :
206
207
api_retrival_config .hybrid_search .alpha = vector_search_alpha
208
+ # Check if both vector_distance_threshold and vector_similarity_threshold
209
+ # are specified.
210
+ if (
211
+ rag_retrieval_config .filter
212
+ and rag_retrieval_config .filter .vector_distance_threshold
213
+ and rag_retrieval_config .filter .vector_similarity_threshold
214
+ ):
215
+ raise ValueError (
216
+ "Only one of vector_distance_threshold or"
217
+ " vector_similarity_threshold can be specified at a time"
218
+ " in rag_retrieval_config."
219
+ )
220
+ # Set vector_distance_threshold to config value if specified
207
221
if (
208
222
rag_retrieval_config .filter
209
223
and rag_retrieval_config .filter .vector_distance_threshold
@@ -215,6 +229,15 @@ def retrieval_query(
215
229
api_retrival_config .filter .vector_distance_threshold = (
216
230
vector_distance_threshold
217
231
)
232
+ # Set vector_similarity_threshold to config value if specified
233
+ if (
234
+ rag_retrieval_config .filter
235
+ and rag_retrieval_config .filter .vector_similarity_threshold
236
+ ):
237
+ api_retrival_config .filter .vector_similarity_threshold = (
238
+ rag_retrieval_config .filter .vector_similarity_threshold
239
+ )
240
+
218
241
query = aiplatform_v1beta1 .RagQuery (
219
242
text = text ,
220
243
rag_retrieval_config = api_retrival_config ,
0 commit comments