Skip to content

Commit d3d69d6

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Adding Vertex AI Search Config for RAG corpuses to SDK
PiperOrigin-RevId: 700775020
1 parent 88ac48c commit d3d69d6

File tree

6 files changed

+285
-8
lines changed

6 files changed

+285
-8
lines changed

tests/unit/vertex_rag/test_rag_constants_preview.py

+40
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
JiraSource,
3535
JiraQuery,
3636
Weaviate,
37+
VertexAiSearchConfig,
3738
VertexVectorSearch,
3839
VertexFeatureStore,
3940
)
@@ -52,6 +53,7 @@
5253
RagContexts,
5354
RetrieveContextsResponse,
5455
RagVectorDbConfig,
56+
VertexAiSearchConfig as GapicVertexAiSearchConfig,
5557
)
5658
from google.cloud.aiplatform_v1beta1.types import api_auth
5759
from google.protobuf import timestamp_pb2
@@ -189,6 +191,44 @@
189191
vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG,
190192
)
191193
TEST_PAGE_TOKEN = "test-page-token"
194+
# Vertex AI Search Config
195+
TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/engines/test-engine/servingConfigs/test-serving-config"
196+
TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/dataStores/test-datastore/servingConfigs/test-serving-config"
197+
TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG = GapicRagCorpus(
198+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
199+
display_name=TEST_CORPUS_DISPLAY_NAME,
200+
vertex_ai_search_config=GapicVertexAiSearchConfig(
201+
serving_config=TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG,
202+
),
203+
)
204+
TEST_GAPIC_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG = GapicRagCorpus(
205+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
206+
display_name=TEST_CORPUS_DISPLAY_NAME,
207+
vertex_ai_search_config=GapicVertexAiSearchConfig(
208+
serving_config=TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG,
209+
),
210+
)
211+
TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE = VertexAiSearchConfig(
212+
serving_config=TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG,
213+
)
214+
TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE = VertexAiSearchConfig(
215+
serving_config=TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG,
216+
)
217+
TEST_VERTEX_AI_SEARCH_CONFIG_INVALID = VertexAiSearchConfig(
218+
serving_config="invalid-serving-config",
219+
)
220+
TEST_VERTEX_AI_SEARCH_CONFIG_EMPTY = VertexAiSearchConfig()
221+
222+
TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG = RagCorpus(
223+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
224+
display_name=TEST_CORPUS_DISPLAY_NAME,
225+
vertex_ai_search_config=TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
226+
)
227+
TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG = RagCorpus(
228+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
229+
display_name=TEST_CORPUS_DISPLAY_NAME,
230+
vertex_ai_search_config=TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
231+
)
192232

193233
# RagFiles
194234
TEST_PATH = "usr/home/my_file.txt"

tests/unit/vertex_rag/test_rag_data_preview.py

+141
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,57 @@ def create_rag_corpus_mock_pinecone():
113113
yield create_rag_corpus_mock_pinecone
114114

115115

116+
@pytest.fixture
117+
def create_rag_corpus_mock_vertex_ai_engine_search_config():
118+
with mock.patch.object(
119+
VertexRagDataServiceClient,
120+
"create_rag_corpus",
121+
) as create_rag_corpus_mock_vertex_ai_engine_search_config:
122+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
123+
create_rag_corpus_lro_mock.done.return_value = True
124+
create_rag_corpus_lro_mock.result.return_value = (
125+
test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
126+
)
127+
create_rag_corpus_mock_vertex_ai_engine_search_config.return_value = (
128+
create_rag_corpus_lro_mock
129+
)
130+
yield create_rag_corpus_mock_vertex_ai_engine_search_config
131+
132+
133+
@pytest.fixture
134+
def create_rag_corpus_mock_vertex_ai_datastore_search_config():
135+
with mock.patch.object(
136+
VertexRagDataServiceClient,
137+
"create_rag_corpus",
138+
) as create_rag_corpus_mock_vertex_ai_datastore_search_config:
139+
create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
140+
create_rag_corpus_lro_mock.done.return_value = True
141+
create_rag_corpus_lro_mock.result.return_value = (
142+
test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG
143+
)
144+
create_rag_corpus_mock_vertex_ai_datastore_search_config.return_value = (
145+
create_rag_corpus_lro_mock
146+
)
147+
yield create_rag_corpus_mock_vertex_ai_datastore_search_config
148+
149+
150+
@pytest.fixture
151+
def update_rag_corpus_mock_vertex_ai_engine_search_config():
152+
with mock.patch.object(
153+
VertexRagDataServiceClient,
154+
"update_rag_corpus",
155+
) as update_rag_corpus_mock_vertex_ai_engine_search_config:
156+
update_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation)
157+
update_rag_corpus_lro_mock.done.return_value = True
158+
update_rag_corpus_lro_mock.result.return_value = (
159+
test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG
160+
)
161+
update_rag_corpus_mock_vertex_ai_engine_search_config.return_value = (
162+
update_rag_corpus_lro_mock
163+
)
164+
yield update_rag_corpus_mock_vertex_ai_engine_search_config
165+
166+
116167
@pytest.fixture
117168
def update_rag_corpus_mock_weaviate():
118169
with mock.patch.object(
@@ -280,6 +331,9 @@ def rag_corpus_eq(returned_corpus, expected_corpus):
280331
assert returned_corpus.name == expected_corpus.name
281332
assert returned_corpus.display_name == expected_corpus.display_name
282333
assert returned_corpus.vector_db.__eq__(expected_corpus.vector_db)
334+
assert returned_corpus.vertex_ai_search_config.__eq__(
335+
expected_corpus.vertex_ai_search_config
336+
)
283337

284338

285339
def rag_file_eq(returned_file, expected_file):
@@ -373,6 +427,70 @@ def test_create_corpus_pinecone_success(self):
373427

374428
rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE)
375429

430+
@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_engine_search_config")
431+
def test_create_corpus_vais_engine_search_config_success(self):
432+
rag_corpus = rag.create_corpus(
433+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
434+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
435+
)
436+
437+
rag_corpus_eq(
438+
rag_corpus,
439+
test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG,
440+
)
441+
442+
@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_datastore_search_config")
443+
def test_create_corpus_vais_datastore_search_config_success(self):
444+
rag_corpus = rag.create_corpus(
445+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
446+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
447+
)
448+
449+
rag_corpus_eq(
450+
rag_corpus,
451+
test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG,
452+
)
453+
454+
def test_create_corpus_vais_datastore_search_config_with_vector_db_failure(self):
455+
with pytest.raises(ValueError) as e:
456+
rag.create_corpus(
457+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
458+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
459+
vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
460+
)
461+
e.match("Only one of vertex_ai_search_config or vector_db can be set.")
462+
463+
def test_create_corpus_vais_datastore_search_config_with_embedding_model_config_failure(
464+
self,
465+
):
466+
with pytest.raises(ValueError) as e:
467+
rag.create_corpus(
468+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
469+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
470+
embedding_model_config=test_rag_constants_preview.TEST_EMBEDDING_MODEL_CONFIG,
471+
)
472+
e.match(
473+
"Only one of vertex_ai_search_config or embedding_model_config can be set."
474+
)
475+
476+
def test_set_vertex_ai_search_config_with_invalid_serving_config_failure(self):
477+
with pytest.raises(ValueError) as e:
478+
rag.create_corpus(
479+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
480+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_INVALID,
481+
)
482+
e.match(
483+
"serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
484+
)
485+
486+
def test_set_vertex_ai_search_config_with_empty_serving_config_failure(self):
487+
with pytest.raises(ValueError) as e:
488+
rag.create_corpus(
489+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
490+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_EMPTY,
491+
)
492+
e.match("serving_config must be set.")
493+
376494
@pytest.mark.usefixtures("rag_data_client_preview_mock_exception")
377495
def test_create_corpus_failure(self):
378496
with pytest.raises(RuntimeError) as e:
@@ -462,6 +580,29 @@ def test_update_corpus_failure(self):
462580
)
463581
e.match("Failed in RagCorpus update due to")
464582

583+
@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_ai_engine_search_config")
584+
def test_update_corpus_vais_engine_search_config_success(self):
585+
rag_corpus = rag.update_corpus(
586+
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
587+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
588+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE,
589+
)
590+
591+
rag_corpus_eq(
592+
rag_corpus,
593+
test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG,
594+
)
595+
596+
def test_update_corpus_vais_datastore_search_config_with_vector_db_failure(self):
597+
with pytest.raises(ValueError) as e:
598+
rag.update_corpus(
599+
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
600+
display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME,
601+
vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE,
602+
vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG,
603+
)
604+
e.match("Only one of vertex_ai_search_config or vector_db can be set.")
605+
465606
@pytest.mark.usefixtures("rag_data_client_preview_mock")
466607
def test_get_corpus_success(self):
467608
rag_corpus = rag.get_corpus(

vertexai/preview/rag/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
SharePointSources,
5454
SlackChannel,
5555
SlackChannelsSource,
56+
VertexAiSearchConfig,
5657
VertexFeatureStore,
5758
VertexVectorSearch,
5859
Weaviate,
@@ -76,6 +77,7 @@
7677
"SharePointSources",
7778
"SlackChannel",
7879
"SlackChannelsSource",
80+
"VertexAiSearchConfig",
7981
"VertexFeatureStore",
8082
"VertexRagStore",
8183
"VertexVectorSearch",

vertexai/preview/rag/rag_data.py

+42-8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
RagManagedDb,
5252
SharePointSources,
5353
SlackChannelsSource,
54+
VertexAiSearchConfig,
5455
VertexFeatureStore,
5556
VertexVectorSearch,
5657
Weaviate,
@@ -64,6 +65,7 @@ def create_corpus(
6465
vector_db: Optional[
6566
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
6667
] = None,
68+
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
6769
) -> RagCorpus:
6870
"""Creates a new RagCorpus resource.
6971
@@ -87,6 +89,9 @@ def create_corpus(
8789
embedding_model_config: The embedding model config.
8890
vector_db: The vector db config of the RagCorpus. If unspecified, the
8991
default database Spanner is used.
92+
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
93+
Note: embedding_model_config or vector_db cannot be set if
94+
vertex_ai_search_config is specified.
9095
Returns:
9196
RagCorpus.
9297
Raises:
@@ -103,10 +108,25 @@ def create_corpus(
103108
embedding_model_config=embedding_model_config,
104109
rag_corpus=rag_corpus,
105110
)
106-
_gapic_utils.set_vector_db(
107-
vector_db=vector_db,
108-
rag_corpus=rag_corpus,
109-
)
111+
112+
if vertex_ai_search_config and embedding_model_config:
113+
raise ValueError(
114+
"Only one of vertex_ai_search_config or embedding_model_config can be set."
115+
)
116+
117+
if vertex_ai_search_config and vector_db:
118+
raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.")
119+
120+
if vertex_ai_search_config:
121+
_gapic_utils.set_vertex_ai_search_config(
122+
vertex_ai_search_config=vertex_ai_search_config,
123+
rag_corpus=rag_corpus,
124+
)
125+
else:
126+
_gapic_utils.set_vector_db(
127+
vector_db=vector_db,
128+
rag_corpus=rag_corpus,
129+
)
110130

111131
request = CreateRagCorpusRequest(
112132
parent=parent,
@@ -134,6 +154,7 @@ def update_corpus(
134154
RagManagedDb,
135155
]
136156
] = None,
157+
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
137158
) -> RagCorpus:
138159
"""Updates a RagCorpus resource.
139160
@@ -161,6 +182,10 @@ def update_corpus(
161182
description will not be updated.
162183
vector_db: The vector db config of the RagCorpus. If not provided, the
163184
vector db will not be updated.
185+
vertex_ai_search_config: The Vertex AI Search config of the RagCorpus.
186+
If not provided, the Vertex AI Search config will not be updated.
187+
Note: embedding_model_config or vector_db cannot be set if
188+
vertex_ai_search_config is specified.
164189
165190
Returns:
166191
RagCorpus.
@@ -180,10 +205,19 @@ def update_corpus(
180205
else:
181206
rag_corpus = GapicRagCorpus(name=corpus_name)
182207

183-
_gapic_utils.set_vector_db(
184-
vector_db=vector_db,
185-
rag_corpus=rag_corpus,
186-
)
208+
if vertex_ai_search_config and vector_db:
209+
raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.")
210+
211+
if vertex_ai_search_config:
212+
_gapic_utils.set_vertex_ai_search_config(
213+
vertex_ai_search_config=vertex_ai_search_config,
214+
rag_corpus=rag_corpus,
215+
)
216+
else:
217+
_gapic_utils.set_vector_db(
218+
vector_db=vector_db,
219+
rag_corpus=rag_corpus,
220+
)
187221

188222
request = UpdateRagCorpusRequest(
189223
rag_corpus=rag_corpus,

0 commit comments

Comments
 (0)