Skip to content

Commit 9402b3d

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Add vector_similarity_threshold support within RagRetrievalConfig in rag_store and rag_retrieval GA and preview versions
PiperOrigin-RevId: 700812116
1 parent 47a5a6d commit 9402b3d

10 files changed

+231
-33
lines changed

tests/unit/vertex_rag/test_rag_constants.py

+8
Original file line number | Diff line number | Diff line change
@@ -508,3 +508,11 @@
508508
top_k=2,
509509
filter=Filter(vector_distance_threshold=0.5),
510510
)
511+
TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG = RagRetrievalConfig(
512+
top_k=2,
513+
filter=Filter(vector_similarity_threshold=0.5),
514+
)
515+
TEST_RAG_RETRIEVAL_ERROR_CONFIG = RagRetrievalConfig(
516+
top_k=2,
517+
filter=Filter(vector_distance_threshold=0.5, vector_similarity_threshold=0.5),
518+
)

tests/unit/vertex_rag/test_rag_constants_preview.py

+9
Original file line number | Diff line number | Diff line change
@@ -581,3 +581,12 @@
581581
filter=Filter(vector_distance_threshold=0.5),
582582
hybrid_search=HybridSearch(alpha=0.5),
583583
)
584+
# Retrieval config that filters on vector *similarity* rather than distance.
# The similarity-specific tests pass this constant, so it must populate
# vector_similarity_threshold; using vector_distance_threshold here would make
# those tests silently re-run the distance path instead.
TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG = RagRetrievalConfig(
    top_k=2,
    filter=Filter(vector_similarity_threshold=0.5),
    hybrid_search=HybridSearch(alpha=0.5),
)
589+
TEST_RAG_RETRIEVAL_ERROR_CONFIG = RagRetrievalConfig(
590+
top_k=2,
591+
filter=Filter(vector_distance_threshold=0.5, vector_similarity_threshold=0.5),
592+
)

tests/unit/vertex_rag/test_rag_retrieval.py

+31
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,15 @@ def test_retrieval_query_rag_resources_success(self):
7878
)
7979
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
8080

81+
@pytest.mark.usefixtures("retrieve_contexts_mock")
82+
def test_retrieval_query_rag_resources_similarity_success(self):
83+
response = rag.retrieval_query(
84+
rag_resources=[tc.TEST_RAG_RESOURCE],
85+
text=tc.TEST_QUERY_TEXT,
86+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
87+
)
88+
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
89+
8190
@pytest.mark.usefixtures("rag_client_mock_exception")
8291
def test_retrieval_query_failure(self):
8392
with pytest.raises(RuntimeError) as e:
@@ -105,3 +114,25 @@ def test_retrieval_query_multiple_rag_resources(self):
105114
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
106115
)
107116
e.match("Currently only support 1 RagResource")
117+
118+
def test_retrieval_query_similarity_multiple_rag_resources(self):
119+
with pytest.raises(ValueError) as e:
120+
rag.retrieval_query(
121+
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
122+
text=tc.TEST_QUERY_TEXT,
123+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
124+
)
125+
e.match("Currently only support 1 RagResource")
126+
127+
def test_retrieval_query_invalid_config_filter(self):
128+
with pytest.raises(ValueError) as e:
129+
rag.retrieval_query(
130+
rag_resources=[tc.TEST_RAG_RESOURCE],
131+
text=tc.TEST_QUERY_TEXT,
132+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
133+
)
134+
e.match(
135+
"Only one of vector_distance_threshold or"
136+
" vector_similarity_threshold can be specified at a time"
137+
" in rag_retrieval_config."
138+
)

tests/unit/vertex_rag/test_rag_retrieval_preview.py

+36
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,17 @@ def test_retrieval_query_rag_resources_config_success(self):
9696
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
9797
)
9898

99+
@pytest.mark.usefixtures("retrieve_contexts_mock")
100+
def test_retrieval_query_rag_resources_similarity_config_success(self):
101+
response = rag.retrieval_query(
102+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
103+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
104+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
105+
)
106+
retrieve_contexts_eq(
107+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
108+
)
109+
99110
@pytest.mark.usefixtures("retrieve_contexts_mock")
100111
def test_retrieval_query_rag_resources_default_config_success(self):
101112
response = rag.retrieval_query(
@@ -223,3 +234,28 @@ def test_retrieval_query_multiple_rag_resources_config(self):
223234
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG,
224235
)
225236
e.match("Currently only support 1 RagResource")
237+
238+
def test_retrieval_query_multiple_rag_resources_similarity_config(self):
239+
with pytest.raises(ValueError) as e:
240+
rag.retrieval_query(
241+
rag_resources=[
242+
test_rag_constants_preview.TEST_RAG_RESOURCE,
243+
test_rag_constants_preview.TEST_RAG_RESOURCE,
244+
],
245+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
246+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
247+
)
248+
e.match("Currently only support 1 RagResource")
249+
250+
def test_retrieval_query_invalid_config_filter(self):
251+
with pytest.raises(ValueError) as e:
252+
rag.retrieval_query(
253+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
254+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
255+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
256+
)
257+
e.match(
258+
"Only one of vector_distance_threshold or"
259+
" vector_similarity_threshold can be specified at a time"
260+
" in rag_retrieval_config."
261+
)

tests/unit/vertex_rag/test_rag_store.py

+17-1
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ def test_retrieval_tool_invalid_name(self):
2828
retrieval=rag.Retrieval(
2929
source=rag.VertexRagStore(
3030
rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME],
31-
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
31+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
3232
),
3333
)
3434
)
@@ -45,3 +45,19 @@ def test_retrieval_tool_multiple_rag_resources(self):
4545
)
4646
)
4747
e.match("Currently only support 1 RagResource")
48+
49+
def test_retrieval_tool_invalid_config_filter(self):
50+
with pytest.raises(ValueError) as e:
51+
Tool.from_retrieval(
52+
retrieval=rag.Retrieval(
53+
source=rag.VertexRagStore(
54+
rag_resources=[tc.TEST_RAG_RESOURCE],
55+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
56+
)
57+
)
58+
)
59+
e.match(
60+
"Only one of vector_distance_threshold or"
61+
" vector_similarity_threshold can be specified at a time"
62+
" in rag_retrieval_config."
63+
)

tests/unit/vertex_rag/test_rag_store_preview.py

+29
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,19 @@ def test_retrieval_tool_config_success(self):
4747
)
4848
)
4949

50+
def test_retrieval_tool_similarity_config_success(self):
51+
with pytest.warns(DeprecationWarning):
52+
Tool.from_retrieval(
53+
retrieval=rag.Retrieval(
54+
source=rag.VertexRagStore(
55+
rag_corpora=[
56+
test_rag_constants_preview.TEST_RAG_CORPUS_ID,
57+
],
58+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
59+
),
60+
)
61+
)
62+
5063
def test_retrieval_tool_invalid_name(self):
5164
with pytest.raises(ValueError) as e:
5265
Tool.from_retrieval(
@@ -137,3 +150,19 @@ def test_retrieval_tool_multiple_rag_resources_config(self):
137150
)
138151
)
139152
e.match("Currently only support 1 RagResource")
153+
154+
def test_retrieval_tool_invalid_config_filter(self):
155+
with pytest.raises(ValueError) as e:
156+
Tool.from_retrieval(
157+
retrieval=rag.Retrieval(
158+
source=rag.VertexRagStore(
159+
rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE],
160+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
161+
)
162+
)
163+
)
164+
e.match(
165+
"Only one of vector_distance_threshold or"
166+
" vector_similarity_threshold can be specified at a time"
167+
" in rag_retrieval_config."
168+
)

vertexai/preview/rag/rag_retrieval.py

+28-5
Original file line number | Diff line number | Diff line change
@@ -190,11 +190,12 @@ def retrieval_query(
190190
else:
191191
# If rag_retrieval_config is specified, check for missing parameters.
192192
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
193-
api_retrival_config.top_k = (
194-
rag_retrieval_config.top_k
195-
if rag_retrieval_config.top_k
196-
else similarity_top_k
197-
)
193+
# Set top_k to config value if specified
194+
if rag_retrieval_config.top_k:
195+
api_retrival_config.top_k = rag_retrieval_config.top_k
196+
else:
197+
api_retrival_config.top_k = similarity_top_k
198+
# Set alpha to config value if specified
198199
if (
199200
rag_retrieval_config.hybrid_search
200201
and rag_retrieval_config.hybrid_search.alpha
@@ -204,6 +205,19 @@ def retrieval_query(
204205
)
205206
else:
206207
api_retrival_config.hybrid_search.alpha = vector_search_alpha
208+
# Check if both vector_distance_threshold and vector_similarity_threshold
209+
# are specified.
210+
if (
211+
rag_retrieval_config.filter
212+
and rag_retrieval_config.filter.vector_distance_threshold
213+
and rag_retrieval_config.filter.vector_similarity_threshold
214+
):
215+
raise ValueError(
216+
"Only one of vector_distance_threshold or"
217+
" vector_similarity_threshold can be specified at a time"
218+
" in rag_retrieval_config."
219+
)
220+
# Set vector_distance_threshold to config value if specified
207221
if (
208222
rag_retrieval_config.filter
209223
and rag_retrieval_config.filter.vector_distance_threshold
@@ -215,6 +229,15 @@ def retrieval_query(
215229
api_retrival_config.filter.vector_distance_threshold = (
216230
vector_distance_threshold
217231
)
232+
# Set vector_similarity_threshold to config value if specified
233+
if (
234+
rag_retrieval_config.filter
235+
and rag_retrieval_config.filter.vector_similarity_threshold
236+
):
237+
api_retrival_config.filter.vector_similarity_threshold = (
238+
rag_retrieval_config.filter.vector_similarity_threshold
239+
)
240+
218241
query = aiplatform_v1beta1.RagQuery(
219242
text=text,
220243
rag_retrieval_config=api_retrival_config,

vertexai/preview/rag/rag_store.py

+33-7
Original file line number | Diff line number | Diff line change
@@ -167,16 +167,42 @@ def __init__(
167167
else:
168168
# If rag_retrieval_config is specified, check for missing parameters.
169169
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
170-
if not rag_retrieval_config.top_k:
170+
# Set top_k to config value if specified
171+
if rag_retrieval_config.top_k:
172+
api_retrival_config.top_k = rag_retrieval_config.top_k
173+
else:
171174
api_retrival_config.top_k = similarity_top_k
175+
# Check if both vector_distance_threshold and vector_similarity_threshold
176+
# are specified.
172177
if (
173-
not rag_retrieval_config.filter
174-
or not rag_retrieval_config.filter.vector_distance_threshold
178+
rag_retrieval_config.filter
179+
and rag_retrieval_config.filter.vector_distance_threshold
180+
and rag_retrieval_config.filter.vector_similarity_threshold
175181
):
176-
api_retrival_config.filter = (
177-
aiplatform_v1beta1.RagRetrievalConfig.Filter(
178-
vector_distance_threshold=vector_distance_threshold
179-
),
182+
raise ValueError(
183+
"Only one of vector_distance_threshold or"
184+
" vector_similarity_threshold can be specified at a time"
185+
" in rag_retrieval_config."
186+
)
187+
# Set vector_distance_threshold to config value if specified
188+
if (
189+
rag_retrieval_config.filter
190+
and rag_retrieval_config.filter.vector_distance_threshold
191+
):
192+
api_retrival_config.filter.vector_distance_threshold = (
193+
rag_retrieval_config.filter.vector_distance_threshold
194+
)
195+
else:
196+
api_retrival_config.filter.vector_distance_threshold = (
197+
vector_distance_threshold
198+
)
199+
# Set vector_similarity_threshold to config value if specified
200+
if (
201+
rag_retrieval_config.filter
202+
and rag_retrieval_config.filter.vector_similarity_threshold
203+
):
204+
api_retrival_config.filter.vector_similarity_threshold = (
205+
rag_retrieval_config.filter.vector_similarity_threshold
180206
)
181207

182208
if rag_resources:

vertexai/rag/rag_retrieval.py

+19-8
Original file line number | Diff line number | Diff line change
@@ -108,16 +108,27 @@ def retrieval_query(
108108
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
109109
else:
110110
# If rag_retrieval_config is specified, check for missing parameters.
111-
api_retrival_config = aiplatform_v1.RagRetrievalConfig(
112-
top_k=rag_retrieval_config.top_k,
113-
)
111+
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
112+
api_retrival_config.top_k = rag_retrieval_config.top_k
113+
# Set vector_distance_threshold to config value if specified
114114
if rag_retrieval_config.filter:
115-
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
116-
vector_distance_threshold=rag_retrieval_config.filter.vector_distance_threshold
115+
# Check if both vector_distance_threshold and vector_similarity_threshold
116+
# are specified.
117+
if (
118+
rag_retrieval_config.filter
119+
and rag_retrieval_config.filter.vector_distance_threshold
120+
and rag_retrieval_config.filter.vector_similarity_threshold
121+
):
122+
raise ValueError(
123+
"Only one of vector_distance_threshold or"
124+
" vector_similarity_threshold can be specified at a time"
125+
" in rag_retrieval_config."
126+
)
127+
api_retrival_config.filter.vector_distance_threshold = (
128+
rag_retrieval_config.filter.vector_distance_threshold
117129
)
118-
else:
119-
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
120-
vector_distance_threshold=None
130+
api_retrival_config.filter.vector_similarity_threshold = (
131+
rag_retrieval_config.filter.vector_similarity_threshold
121132
)
122133

123134
query = aiplatform_v1.RagQuery(

vertexai/rag/rag_store.py

+21-12
Original file line number | Diff line number | Diff line change
@@ -103,20 +103,29 @@ def __init__(
103103
)
104104

105105
# If rag_retrieval_config is not specified, set it to default values.
106-
if not rag_retrieval_config:
107-
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
108-
else:
109-
# If rag_retrieval_config is specified, check for missing parameters.
110-
api_retrival_config = aiplatform_v1.RagRetrievalConfig(
111-
top_k=rag_retrieval_config.top_k,
112-
)
106+
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
107+
# If rag_retrieval_config is specified, populate the default config.
108+
if rag_retrieval_config:
109+
api_retrival_config.top_k = rag_retrieval_config.top_k
110+
# Set vector_distance_threshold to config value if specified
113111
if rag_retrieval_config.filter:
114-
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
115-
vector_distance_threshold=rag_retrieval_config.filter.vector_distance_threshold
112+
# Check if both vector_distance_threshold and
113+
# vector_similarity_threshold are specified.
114+
if (
115+
rag_retrieval_config.filter
116+
and rag_retrieval_config.filter.vector_distance_threshold
117+
and rag_retrieval_config.filter.vector_similarity_threshold
118+
):
119+
raise ValueError(
120+
"Only one of vector_distance_threshold or"
121+
" vector_similarity_threshold can be specified at a time"
122+
" in rag_retrieval_config."
123+
)
124+
api_retrival_config.filter.vector_distance_threshold = (
125+
rag_retrieval_config.filter.vector_distance_threshold
116126
)
117-
else:
118-
api_retrival_config.filter = aiplatform_v1.RagRetrievalConfig.Filter(
119-
vector_distance_threshold=None
127+
api_retrival_config.filter.vector_similarity_threshold = (
128+
rag_retrieval_config.filter.vector_similarity_threshold
120129
)
121130

122131
if rag_resources:

0 commit comments

Comments
 (0)