@@ -43,6 +43,11 @@ def retrieval_query(
43
43
filter=vertexai.rag.rag_retrieval_config.filter(
44
44
vector_distance_threshold=0.5
45
45
),
46
+ ranking=vertexai.rag.Ranking(
47
+ llm_ranker=vertexai.rag.LlmRanker(
48
+ model_name="gemini-1.5-flash-002"
49
+ )
50
+ )
46
51
)
47
52
48
53
results = vertexai.rag.retrieval_query(
@@ -105,11 +110,11 @@ def retrieval_query(
105
110
106
111
# If rag_retrieval_config is not specified, set it to default values.
107
112
if not rag_retrieval_config :
108
- api_retrival_config = aiplatform_v1 .RagRetrievalConfig ()
113
+ api_retrieval_config = aiplatform_v1 .RagRetrievalConfig ()
109
114
else :
110
115
# If rag_retrieval_config is specified, check for missing parameters.
111
- api_retrival_config = aiplatform_v1 .RagRetrievalConfig ()
112
- api_retrival_config .top_k = rag_retrieval_config .top_k
116
+ api_retrieval_config = aiplatform_v1 .RagRetrievalConfig ()
117
+ api_retrieval_config .top_k = rag_retrieval_config .top_k
113
118
# Set vector_distance_threshold to config value if specified
114
119
if rag_retrieval_config .filter :
115
120
# Check if both vector_distance_threshold and vector_similarity_threshold
@@ -124,16 +129,30 @@ def retrieval_query(
124
129
" vector_similarity_threshold can be specified at a time"
125
130
" in rag_retrieval_config."
126
131
)
127
- api_retrival_config .filter .vector_distance_threshold = (
132
+ api_retrieval_config .filter .vector_distance_threshold = (
128
133
rag_retrieval_config .filter .vector_distance_threshold
129
134
)
130
- api_retrival_config .filter .vector_similarity_threshold = (
135
+ api_retrieval_config .filter .vector_similarity_threshold = (
131
136
rag_retrieval_config .filter .vector_similarity_threshold
132
137
)
138
+ if (
139
+ rag_retrieval_config .ranking
140
+ and rag_retrieval_config .ranking .rank_service
141
+ and rag_retrieval_config .ranking .llm_ranker
142
+ ):
143
+ raise ValueError ("Only one of rank_service and llm_ranker can be set." )
144
+ if rag_retrieval_config .ranking and rag_retrieval_config .ranking .rank_service :
145
+ api_retrieval_config .ranking .rank_service .model_name = (
146
+ rag_retrieval_config .ranking .rank_service .model_name
147
+ )
148
+ elif rag_retrieval_config .ranking and rag_retrieval_config .ranking .llm_ranker :
149
+ api_retrieval_config .ranking .llm_ranker .model_name = (
150
+ rag_retrieval_config .ranking .llm_ranker .model_name
151
+ )
133
152
134
153
query = aiplatform_v1 .RagQuery (
135
154
text = text ,
136
- rag_retrieval_config = api_retrival_config ,
155
+ rag_retrieval_config = api_retrieval_config ,
137
156
)
138
157
request = aiplatform_v1 .RetrieveContextsRequest (
139
158
vertex_rag_store = vertex_rag_store ,
0 commit comments