Skip to content

Commit a6b7de5

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Modify v1 sdk to support rerankers
PiperOrigin-RevId: 738100302
1 parent 1232132 commit a6b7de5

File tree

7 files changed

+170
-6
lines changed

7 files changed

+170
-6
lines changed

tests/unit/vertex_rag/test_rag_constants.py

+26
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
from vertexai.rag import (
2222
Filter,
2323
LayoutParserConfig,
24+
LlmRanker,
2425
Pinecone,
2526
RagCorpus,
2627
RagFile,
2728
RagResource,
2829
RagRetrievalConfig,
2930
RagVectorDbConfig,
31+
Ranking,
32+
RankService,
3033
SharePointSource,
3134
SharePointSources,
3235
SlackChannelsSource,
@@ -560,3 +563,26 @@
560563
top_k=2,
561564
filter=Filter(vector_distance_threshold=0.5, vector_similarity_threshold=0.5),
562565
)
566+
TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE = RagRetrievalConfig(
567+
top_k=2,
568+
filter=Filter(vector_distance_threshold=0.5),
569+
ranking=Ranking(rank_service=RankService(model_name="test-model-name")),
570+
)
571+
TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER = RagRetrievalConfig(
572+
top_k=2,
573+
filter=Filter(vector_distance_threshold=0.5),
574+
ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
575+
)
576+
TEST_RAG_RETRIEVAL_RANKING_CONFIG = RagRetrievalConfig(
577+
top_k=2,
578+
filter=Filter(vector_distance_threshold=0.5),
579+
ranking=Ranking(rank_service=RankService(model_name="test-rank-service")),
580+
)
581+
TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG = RagRetrievalConfig(
582+
top_k=2,
583+
filter=Filter(vector_distance_threshold=0.5),
584+
ranking=Ranking(
585+
rank_service=RankService(model_name="test-rank-service"),
586+
llm_ranker=LlmRanker(model_name="test-llm-ranker"),
587+
),
588+
)

tests/unit/vertex_rag/test_rag_retrieval.py

+18
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ def test_retrieval_query_rag_resources_similarity_success(self):
8787
)
8888
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
8989

90+
@pytest.mark.usefixtures("retrieve_contexts_mock")
91+
def test_retrieval_query_rag_corpora_config_rank_service_success(self):
92+
response = rag.retrieval_query(
93+
rag_resources=[tc.TEST_RAG_RESOURCE],
94+
text=tc.TEST_QUERY_TEXT,
95+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE,
96+
)
97+
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
98+
99+
@pytest.mark.usefixtures("retrieve_contexts_mock")
100+
def test_retrieval_query_rag_corpora_config_llm_ranker_success(self):
101+
response = rag.retrieval_query(
102+
rag_resources=[tc.TEST_RAG_RESOURCE],
103+
text=tc.TEST_QUERY_TEXT,
104+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER,
105+
)
106+
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
107+
90108
@pytest.mark.usefixtures("rag_client_mock_exception")
91109
def test_retrieval_query_failure(self):
92110
with pytest.raises(RuntimeError) as e:

tests/unit/vertex_rag/test_rag_store.py

+26
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ def test_retrieval_tool_no_rag_resources(self):
5555
)
5656
e.match("rag_resources must be specified.")
5757

58+
def test_retrieval_tool_ranking_config_success(self):
59+
tool = Tool.from_retrieval(
60+
retrieval=rag.Retrieval(
61+
source=rag.VertexRagStore(
62+
rag_resources=[tc.TEST_RAG_RESOURCE],
63+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_RANKING_CONFIG,
64+
),
65+
)
66+
)
67+
assert tool is not None
68+
5869
def test_retrieval_tool_invalid_name(self):
5970
with pytest.raises(ValueError) as e:
6071
Tool.from_retrieval(
@@ -94,3 +105,18 @@ def test_retrieval_tool_invalid_config_filter(self):
94105
" vector_similarity_threshold can be specified at a time"
95106
" in rag_retrieval_config."
96107
)
108+
109+
def test_retrieval_tool_invalid_ranking_config_filter(self):
110+
with pytest.raises(ValueError) as e:
111+
Tool.from_retrieval(
112+
retrieval=rag.Retrieval(
113+
source=rag.VertexRagStore(
114+
rag_resources=[tc.TEST_RAG_RESOURCE],
115+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG,
116+
)
117+
)
118+
)
119+
e.match(
120+
"Only one of rank_service or llm_ranker can be specified"
121+
" at a time in rag_retrieval_config."
122+
)

vertexai/rag/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
JiraQuery,
4444
JiraSource,
4545
LayoutParserConfig,
46+
LlmRanker,
4647
Pinecone,
4748
RagCorpus,
4849
RagEmbeddingModelConfig,
@@ -51,6 +52,8 @@
5152
RagResource,
5253
RagRetrievalConfig,
5354
RagVectorDbConfig,
55+
Ranking,
56+
RankService,
5457
SharePointSource,
5558
SharePointSources,
5659
SlackChannel,
@@ -67,6 +70,7 @@
6770
"JiraQuery",
6871
"JiraSource",
6972
"LayoutParserConfig",
73+
"LlmRanker",
7074
"Pinecone",
7175
"RagCorpus",
7276
"RagEmbeddingModelConfig",
@@ -75,6 +79,8 @@
7579
"RagResource",
7680
"RagRetrievalConfig",
7781
"RagVectorDbConfig",
82+
"Ranking",
83+
"RankService",
7884
"Retrieval",
7985
"SharePointSource",
8086
"SharePointSources",

vertexai/rag/rag_retrieval.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def retrieval_query(
4343
filter=vertexai.rag.rag_retrieval_config.filter(
4444
vector_distance_threshold=0.5
4545
),
46+
ranking=vertex.rag.Ranking(
47+
llm_ranker=vertexai.rag.LlmRanker(
48+
model_name="gemini-1.5-flash-002"
49+
)
50+
)
4651
)
4752
4853
results = vertexai.rag.retrieval_query(
@@ -105,11 +110,11 @@ def retrieval_query(
105110

106111
# If rag_retrieval_config is not specified, set it to default values.
107112
if not rag_retrieval_config:
108-
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
113+
api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
109114
else:
110115
# If rag_retrieval_config is specified, check for missing parameters.
111-
api_retrival_config = aiplatform_v1.RagRetrievalConfig()
112-
api_retrival_config.top_k = rag_retrieval_config.top_k
116+
api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
117+
api_retrieval_config.top_k = rag_retrieval_config.top_k
113118
# Set vector_distance_threshold to config value if specified
114119
if rag_retrieval_config.filter:
115120
# Check if both vector_distance_threshold and vector_similarity_threshold
@@ -124,16 +129,30 @@ def retrieval_query(
124129
" vector_similarity_threshold can be specified at a time"
125130
" in rag_retrieval_config."
126131
)
127-
api_retrival_config.filter.vector_distance_threshold = (
132+
api_retrieval_config.filter.vector_distance_threshold = (
128133
rag_retrieval_config.filter.vector_distance_threshold
129134
)
130-
api_retrival_config.filter.vector_similarity_threshold = (
135+
api_retrieval_config.filter.vector_similarity_threshold = (
131136
rag_retrieval_config.filter.vector_similarity_threshold
132137
)
138+
if (
139+
rag_retrieval_config.ranking
140+
and rag_retrieval_config.ranking.rank_service
141+
and rag_retrieval_config.ranking.llm_ranker
142+
):
143+
raise ValueError("Only one of rank_service and llm_ranker can be set.")
144+
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
145+
api_retrieval_config.ranking.rank_service.model_name = (
146+
rag_retrieval_config.ranking.rank_service.model_name
147+
)
148+
elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
149+
api_retrieval_config.ranking.llm_ranker.model_name = (
150+
rag_retrieval_config.ranking.llm_ranker.model_name
151+
)
133152

134153
query = aiplatform_v1.RagQuery(
135154
text=text,
136-
rag_retrieval_config=api_retrival_config,
155+
rag_retrieval_config=api_retrieval_config,
137156
)
138157
request = aiplatform_v1.RetrieveContextsRequest(
139158
vertex_rag_store=vertex_rag_store,

vertexai/rag/rag_store.py

+30
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def __init__(
6262
filter=vertexai.rag.RagRetrievalConfig.Filter(
6363
vector_distance_threshold=0.5
6464
),
65+
ranking=vertex.rag.Ranking(
66+
llm_ranker=vertexai.rag.LlmRanker(
67+
model_name="gemini-1.5-flash-002"
68+
)
69+
)
6570
)
6671
6772
tool = Tool.from_retrieval(
@@ -127,6 +132,31 @@ def __init__(
127132
api_retrieval_config.filter.vector_similarity_threshold = (
128133
rag_retrieval_config.filter.vector_similarity_threshold
129134
)
135+
# Check if both rank_service and llm_ranker are specified.
136+
if (
137+
rag_retrieval_config.ranking
138+
and rag_retrieval_config.ranking.rank_service
139+
and rag_retrieval_config.ranking.rank_service.model_name
140+
and rag_retrieval_config.ranking.llm_ranker
141+
and rag_retrieval_config.ranking.llm_ranker.model_name
142+
):
143+
raise ValueError(
144+
"Only one of rank_service or llm_ranker can be specified"
145+
" at a time in rag_retrieval_config."
146+
)
147+
# Set rank_service to config value if specified
148+
if (
149+
rag_retrieval_config.ranking
150+
and rag_retrieval_config.ranking.rank_service
151+
):
152+
api_retrieval_config.ranking.rank_service.model_name = (
153+
rag_retrieval_config.ranking.rank_service.model_name
154+
)
155+
# Set llm_ranker to config value if specified
156+
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
157+
api_retrieval_config.ranking.llm_ranker.model_name = (
158+
rag_retrieval_config.ranking.llm_ranker.model_name
159+
)
130160

131161
gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
132162
rag_corpus=rag_corpus_name,

vertexai/rag/utils/resources.py

+39
Original file line numberDiff line numberDiff line change
@@ -332,17 +332,56 @@ class Filter:
332332
metadata_filter: Optional[str] = None
333333

334334

335+
@dataclasses.dataclass
336+
class LlmRanker:
337+
"""LlmRanker.
338+
339+
Attributes:
340+
model_name: The model name used for ranking. Only Gemini models are
341+
supported for now.
342+
"""
343+
344+
model_name: Optional[str] = None
345+
346+
347+
@dataclasses.dataclass
348+
class RankService:
349+
"""RankService.
350+
351+
Attributes:
352+
model_name: The model name of the rank service. Format:
353+
``semantic-ranker-512@latest``
354+
"""
355+
356+
model_name: Optional[str] = None
357+
358+
359+
@dataclasses.dataclass
360+
class Ranking:
361+
"""Ranking.
362+
363+
Attributes:
364+
rank_service: Config for Rank Service.
365+
llm_ranker: Config for LlmRanker.
366+
"""
367+
368+
rank_service: Optional[RankService] = None
369+
llm_ranker: Optional[LlmRanker] = None
370+
371+
335372
@dataclasses.dataclass
336373
class RagRetrievalConfig:
337374
"""RagRetrievalConfig.
338375
339376
Attributes:
340377
top_k: The number of contexts to retrieve.
341378
filter: Config for filters.
379+
ranking: Config for ranking.
342380
"""
343381

344382
top_k: Optional[int] = None
345383
filter: Optional[Filter] = None
384+
ranking: Optional[Ranking] = None
346385

347386

348387
@dataclasses.dataclass

0 commit comments

Comments
 (0)